summaryrefslogtreecommitdiff
path: root/chromium/net/socket
diff options
context:
space:
mode:
authorZeno Albisser <zeno.albisser@theqtcompany.com>2014-12-05 15:04:29 +0100
committerAndras Becsi <andras.becsi@theqtcompany.com>2014-12-09 10:49:28 +0100
commitaf6588f8d723931a298c995fa97259bb7f7deb55 (patch)
tree060ca707847ba1735f01af2372e0d5e494dc0366 /chromium/net/socket
parent2fff84d821cc7b1c785f6404e0f8091333283e74 (diff)
downloadqtwebengine-chromium-af6588f8d723931a298c995fa97259bb7f7deb55.tar.gz
BASELINE: Update chromium to 40.0.2214.28 and ninja to 1.5.3.
Change-Id: I759465284fd64d59ad120219cbe257f7402c4181 Reviewed-by: Andras Becsi <andras.becsi@theqtcompany.com>
Diffstat (limited to 'chromium/net/socket')
-rw-r--r--chromium/net/socket/client_socket_factory.cc22
-rw-r--r--chromium/net/socket/client_socket_pool.cc4
-rw-r--r--chromium/net/socket/client_socket_pool.h2
-rw-r--r--chromium/net/socket/client_socket_pool_base.h13
-rw-r--r--chromium/net/socket/client_socket_pool_base_unittest.cc168
-rw-r--r--chromium/net/socket/client_socket_pool_manager.cc60
-rw-r--r--chromium/net/socket/client_socket_pool_manager_impl.cc70
-rw-r--r--chromium/net/socket/client_socket_pool_manager_impl.h40
-rw-r--r--chromium/net/socket/deterministic_socket_data_unittest.cc14
-rw-r--r--chromium/net/socket/mock_client_socket_pool_manager.h24
-rw-r--r--chromium/net/socket/next_proto.cc22
-rw-r--r--chromium/net/socket/next_proto.h19
-rw-r--r--chromium/net/socket/nss_ssl_util.cc23
-rw-r--r--chromium/net/socket/nss_ssl_util.h6
-rw-r--r--chromium/net/socket/openssl_ssl_util.cc156
-rw-r--r--chromium/net/socket/openssl_ssl_util.h32
-rw-r--r--chromium/net/socket/server_socket.cc30
-rw-r--r--chromium/net/socket/server_socket.h20
-rw-r--r--chromium/net/socket/socket_libevent.cc482
-rw-r--r--chromium/net/socket/socket_libevent.h132
-rw-r--r--chromium/net/socket/socket_test_util.cc179
-rw-r--r--chromium/net/socket/socket_test_util.h436
-rw-r--r--chromium/net/socket/socks5_client_socket.h50
-rw-r--r--chromium/net/socket/socks5_client_socket_unittest.cc2
-rw-r--r--chromium/net/socket/socks_client_socket.h50
-rw-r--r--chromium/net/socket/socks_client_socket_pool.h76
-rw-r--r--chromium/net/socket/socks_client_socket_pool_unittest.cc6
-rw-r--r--chromium/net/socket/socks_client_socket_unittest.cc57
-rw-r--r--chromium/net/socket/ssl_client_socket.cc143
-rw-r--r--chromium/net/socket/ssl_client_socket.h88
-rw-r--r--chromium/net/socket/ssl_client_socket_nss.cc369
-rw-r--r--chromium/net/socket/ssl_client_socket_nss.h80
-rw-r--r--chromium/net/socket/ssl_client_socket_openssl.cc955
-rw-r--r--chromium/net/socket/ssl_client_socket_openssl.h174
-rw-r--r--chromium/net/socket/ssl_client_socket_openssl_unittest.cc34
-rw-r--r--chromium/net/socket/ssl_client_socket_pool.cc313
-rw-r--r--chromium/net/socket/ssl_client_socket_pool.h249
-rw-r--r--chromium/net/socket/ssl_client_socket_pool_unittest.cc506
-rw-r--r--chromium/net/socket/ssl_client_socket_unittest.cc819
-rw-r--r--chromium/net/socket/ssl_error_params.cc31
-rw-r--r--chromium/net/socket/ssl_error_params.h18
-rw-r--r--chromium/net/socket/ssl_server_socket.h2
-rw-r--r--chromium/net/socket/ssl_server_socket_nss.cc7
-rw-r--r--chromium/net/socket/ssl_server_socket_nss.h58
-rw-r--r--chromium/net/socket/ssl_server_socket_openssl.cc38
-rw-r--r--chromium/net/socket/ssl_server_socket_openssl.h58
-rw-r--r--chromium/net/socket/ssl_server_socket_unittest.cc227
-rw-r--r--chromium/net/socket/ssl_session_cache_openssl.cc34
-rw-r--r--chromium/net/socket/ssl_session_cache_openssl.h3
-rw-r--r--chromium/net/socket/ssl_session_cache_openssl_unittest.cc14
-rw-r--r--chromium/net/socket/ssl_socket.h2
-rw-r--r--chromium/net/socket/stream_listen_socket.cc5
-rw-r--r--chromium/net/socket/stream_listen_socket.h6
-rw-r--r--chromium/net/socket/stream_socket.h7
-rw-r--r--chromium/net/socket/tcp_client_socket.cc9
-rw-r--r--chromium/net/socket/tcp_client_socket.h45
-rw-r--r--chromium/net/socket/tcp_listen_socket.cc5
-rw-r--r--chromium/net/socket/tcp_listen_socket.h10
-rw-r--r--chromium/net/socket/tcp_listen_socket_unittest.cc4
-rw-r--r--chromium/net/socket/tcp_listen_socket_unittest.h13
-rw-r--r--chromium/net/socket/tcp_server_socket.h10
-rw-r--r--chromium/net/socket/tcp_server_socket_unittest.cc8
-rw-r--r--chromium/net/socket/tcp_socket.cc82
-rw-r--r--chromium/net/socket/tcp_socket.h18
-rw-r--r--chromium/net/socket/tcp_socket_libevent.cc955
-rw-r--r--chromium/net/socket/tcp_socket_libevent.h211
-rw-r--r--chromium/net/socket/tcp_socket_win.cc35
-rw-r--r--chromium/net/socket/tcp_socket_win.h9
-rw-r--r--chromium/net/socket/transport_client_socket_pool.cc347
-rw-r--r--chromium/net/socket/transport_client_socket_pool.h251
-rw-r--r--chromium/net/socket/transport_client_socket_pool_test_util.cc424
-rw-r--r--chromium/net/socket/transport_client_socket_pool_test_util.h127
-rw-r--r--chromium/net/socket/transport_client_socket_pool_unittest.cc645
-rw-r--r--chromium/net/socket/transport_client_socket_unittest.cc10
-rw-r--r--chromium/net/socket/unix_domain_client_socket_posix.cc171
-rw-r--r--chromium/net/socket/unix_domain_client_socket_posix.h87
-rw-r--r--chromium/net/socket/unix_domain_client_socket_posix_unittest.cc446
-rw-r--r--chromium/net/socket/unix_domain_listen_socket_posix.cc167
-rw-r--r--chromium/net/socket/unix_domain_listen_socket_posix.h122
-rw-r--r--chromium/net/socket/unix_domain_listen_socket_posix_unittest.cc (renamed from chromium/net/socket/unix_domain_socket_posix_unittest.cc)133
-rw-r--r--chromium/net/socket/unix_domain_server_socket_posix.cc186
-rw-r--r--chromium/net/socket/unix_domain_server_socket_posix.h91
-rw-r--r--chromium/net/socket/unix_domain_server_socket_posix_unittest.cc125
-rw-r--r--chromium/net/socket/unix_domain_socket_posix.cc196
-rw-r--r--chromium/net/socket/unix_domain_socket_posix.h126
-rw-r--r--chromium/net/socket/websocket_endpoint_lock_manager.cc132
-rw-r--r--chromium/net/socket/websocket_endpoint_lock_manager.h121
-rw-r--r--chromium/net/socket/websocket_endpoint_lock_manager_unittest.cc224
-rw-r--r--chromium/net/socket/websocket_transport_client_socket_pool.cc651
-rw-r--r--chromium/net/socket/websocket_transport_client_socket_pool.h246
-rw-r--r--chromium/net/socket/websocket_transport_client_socket_pool_unittest.cc1143
-rw-r--r--chromium/net/socket/websocket_transport_connect_sub_job.cc170
-rw-r--r--chromium/net/socket/websocket_transport_connect_sub_job.h90
93 files changed, 10318 insertions, 3962 deletions
diff --git a/chromium/net/socket/client_socket_factory.cc b/chromium/net/socket/client_socket_factory.cc
index 953914581fc..51aea715f4d 100644
--- a/chromium/net/socket/client_socket_factory.cc
+++ b/chromium/net/socket/client_socket_factory.cc
@@ -50,45 +50,45 @@ class DefaultClientSocketFactory : public ClientSocketFactory,
CertDatabase::GetInstance()->AddObserver(this);
}
- virtual ~DefaultClientSocketFactory() {
+ ~DefaultClientSocketFactory() override {
// Note: This code never runs, as the factory is defined as a Leaky
// singleton.
CertDatabase::GetInstance()->RemoveObserver(this);
}
- virtual void OnCertAdded(const X509Certificate* cert) OVERRIDE {
+ void OnCertAdded(const X509Certificate* cert) override {
ClearSSLSessionCache();
}
- virtual void OnCACertChanged(const X509Certificate* cert) OVERRIDE {
+ void OnCACertChanged(const X509Certificate* cert) override {
// Per wtc, we actually only need to flush when trust is reduced.
// Always flush now because OnCACertChanged does not tell us this.
// See comments in ClientSocketPoolManager::OnCACertChanged.
ClearSSLSessionCache();
}
- virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
+ scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
- const NetLog::Source& source) OVERRIDE {
+ const NetLog::Source& source) override {
return scoped_ptr<DatagramClientSocket>(
new UDPClientSocket(bind_type, rand_int_cb, net_log, source));
}
- virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
+ scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* net_log,
- const NetLog::Source& source) OVERRIDE {
+ const NetLog::Source& source) override {
return scoped_ptr<StreamSocket>(
new TCPClientSocket(addresses, net_log, source));
}
- virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
- const SSLClientSocketContext& context) OVERRIDE {
+ const SSLClientSocketContext& context) override {
// nss_thread_task_runner_ may be NULL if g_use_dedicated_nss_thread is
// false or if the dedicated NSS thread failed to start. If so, cause NSS
// functions to execute on the current task runner.
@@ -120,9 +120,7 @@ class DefaultClientSocketFactory : public ClientSocketFactory,
#endif
}
- virtual void ClearSSLSessionCache() OVERRIDE {
- SSLClientSocket::ClearSessionCache();
- }
+ void ClearSSLSessionCache() override { SSLClientSocket::ClearSessionCache(); }
private:
scoped_refptr<base::SequencedWorkerPool> worker_pool_;
diff --git a/chromium/net/socket/client_socket_pool.cc b/chromium/net/socket/client_socket_pool.cc
index 0eebd11b9fb..261e87fc7fe 100644
--- a/chromium/net/socket/client_socket_pool.cc
+++ b/chromium/net/socket/client_socket_pool.cc
@@ -12,10 +12,10 @@ namespace {
// alive.
// TODO(ziadh): Change this timeout after getting histogram data on how long it
// should be.
-int g_unused_idle_socket_timeout_s = 10;
+int64 g_unused_idle_socket_timeout_s = 10;
// The maximum duration, in seconds, to keep used idle persistent sockets alive.
-int g_used_idle_socket_timeout_s = 300; // 5 minutes
+int64 g_used_idle_socket_timeout_s = 300; // 5 minutes
} // namespace
diff --git a/chromium/net/socket/client_socket_pool.h b/chromium/net/socket/client_socket_pool.h
index 715cddb94d4..2a2be36c8cd 100644
--- a/chromium/net/socket/client_socket_pool.h
+++ b/chromium/net/socket/client_socket_pool.h
@@ -182,7 +182,7 @@ class NET_EXPORT ClientSocketPool : public LowerLayeredPool {
protected:
ClientSocketPool();
- virtual ~ClientSocketPool();
+ ~ClientSocketPool() override;
// Return the connection timeout for this pool.
virtual base::TimeDelta ConnectionTimeout() const = 0;
diff --git a/chromium/net/socket/client_socket_pool_base.h b/chromium/net/socket/client_socket_pool_base.h
index 8079cd4e5c7..ec4e33cc0ed 100644
--- a/chromium/net/socket/client_socket_pool_base.h
+++ b/chromium/net/socket/client_socket_pool_base.h
@@ -219,7 +219,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
base::TimeDelta used_idle_socket_timeout,
ConnectJobFactory* connect_job_factory);
- virtual ~ClientSocketPoolBaseHelper();
+ ~ClientSocketPoolBaseHelper() override;
// Adds a lower layered pool to |this|, and adds |this| as a higher layered
// pool on top of |lower_pool|.
@@ -327,10 +327,10 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper
void EnableConnectBackupJobs();
// ConnectJob::Delegate methods:
- virtual void OnConnectJobComplete(int result, ConnectJob* job) OVERRIDE;
+ void OnConnectJobComplete(int result, ConnectJob* job) override;
// NetworkChangeNotifier::IPAddressObserver methods:
- virtual void OnIPAddressChanged() OVERRIDE;
+ void OnIPAddressChanged() override;
private:
friend class base::RefCounted<ClientSocketPoolBaseHelper>;
@@ -741,10 +741,7 @@ class ClientSocketPoolBase {
internal::ClientSocketPoolBaseHelper::NORMAL,
params->ignore_limits(),
params, net_log));
- return helper_.RequestSocket(
- group_name,
- request.template PassAs<
- const internal::ClientSocketPoolBaseHelper::Request>());
+ return helper_.RequestSocket(group_name, request.Pass());
}
// RequestSockets bundles up the parameters into a Request and then forwards
@@ -856,7 +853,7 @@ class ClientSocketPoolBase {
virtual scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const internal::ClientSocketPoolBaseHelper::Request& request,
- ConnectJob::Delegate* delegate) const OVERRIDE {
+ ConnectJob::Delegate* delegate) const override {
const Request& casted_request = static_cast<const Request&>(request);
return connect_job_factory_->NewConnectJob(
group_name, casted_request, delegate);
diff --git a/chromium/net/socket/client_socket_pool_base_unittest.cc b/chromium/net/socket/client_socket_pool_base_unittest.cc
index 5a672b44a4e..c4a28459a1e 100644
--- a/chromium/net/socket/client_socket_pool_base_unittest.cc
+++ b/chromium/net/socket/client_socket_pool_base_unittest.cc
@@ -128,9 +128,9 @@ class MockClientSocket : public StreamSocket {
}
// Socket implementation.
- virtual int Read(
- IOBuffer* /* buf */, int len,
- const CompletionCallback& /* callback */) OVERRIDE {
+ int Read(IOBuffer* /* buf */,
+ int len,
+ const CompletionCallback& /* callback */) override {
if (has_unread_data_ && len > 0) {
has_unread_data_ = false;
was_used_to_convey_data_ = true;
@@ -139,54 +139,44 @@ class MockClientSocket : public StreamSocket {
return ERR_UNEXPECTED;
}
- virtual int Write(
- IOBuffer* /* buf */, int len,
- const CompletionCallback& /* callback */) OVERRIDE {
+ int Write(IOBuffer* /* buf */,
+ int len,
+ const CompletionCallback& /* callback */) override {
was_used_to_convey_data_ = true;
return len;
}
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; }
- virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; }
+ int SetReceiveBufferSize(int32 size) override { return OK; }
+ int SetSendBufferSize(int32 size) override { return OK; }
// StreamSocket implementation.
- virtual int Connect(const CompletionCallback& callback) OVERRIDE {
+ int Connect(const CompletionCallback& callback) override {
connected_ = true;
return OK;
}
- virtual void Disconnect() OVERRIDE { connected_ = false; }
- virtual bool IsConnected() const OVERRIDE { return connected_; }
- virtual bool IsConnectedAndIdle() const OVERRIDE {
+ void Disconnect() override { connected_ = false; }
+ bool IsConnected() const override { return connected_; }
+ bool IsConnectedAndIdle() const override {
return connected_ && !has_unread_data_;
}
- virtual int GetPeerAddress(IPEndPoint* /* address */) const OVERRIDE {
+ int GetPeerAddress(IPEndPoint* /* address */) const override {
return ERR_UNEXPECTED;
}
- virtual int GetLocalAddress(IPEndPoint* /* address */) const OVERRIDE {
+ int GetLocalAddress(IPEndPoint* /* address */) const override {
return ERR_UNEXPECTED;
}
- virtual const BoundNetLog& NetLog() const OVERRIDE {
- return net_log_;
- }
+ const BoundNetLog& NetLog() const override { return net_log_; }
- virtual void SetSubresourceSpeculation() OVERRIDE {}
- virtual void SetOmniboxSpeculation() OVERRIDE {}
- virtual bool WasEverUsed() const OVERRIDE {
- return was_used_to_convey_data_;
- }
- virtual bool UsingTCPFastOpen() const OVERRIDE { return false; }
- virtual bool WasNpnNegotiated() const OVERRIDE {
- return false;
- }
- virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
- return kProtoUnknown;
- }
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
- return false;
- }
+ void SetSubresourceSpeculation() override {}
+ void SetOmniboxSpeculation() override {}
+ bool WasEverUsed() const override { return was_used_to_convey_data_; }
+ bool UsingTCPFastOpen() const override { return false; }
+ bool WasNpnNegotiated() const override { return false; }
+ NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
+ bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
private:
bool connected_;
@@ -203,35 +193,33 @@ class MockClientSocketFactory : public ClientSocketFactory {
public:
MockClientSocketFactory() : allocation_count_(0) {}
- virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
+ scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
- const NetLog::Source& source) OVERRIDE {
+ const NetLog::Source& source) override {
NOTREACHED();
return scoped_ptr<DatagramClientSocket>();
}
- virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
+ scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* /* net_log */,
- const NetLog::Source& /*source*/) OVERRIDE {
+ const NetLog::Source& /*source*/) override {
allocation_count_++;
return scoped_ptr<StreamSocket>();
}
- virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
- const SSLClientSocketContext& context) OVERRIDE {
+ const SSLClientSocketContext& context) override {
NOTIMPLEMENTED();
return scoped_ptr<SSLClientSocket>();
}
- virtual void ClearSSLSessionCache() OVERRIDE {
- NOTIMPLEMENTED();
- }
+ void ClearSSLSessionCache() override { NOTIMPLEMENTED(); }
void WaitForSignal(TestConnectJob* job) { waiting_jobs_.push_back(job); }
@@ -291,9 +279,9 @@ class TestConnectJob : public ConnectJob {
// From ConnectJob:
- virtual LoadState GetLoadState() const OVERRIDE { return load_state_; }
+ LoadState GetLoadState() const override { return load_state_; }
- virtual void GetAdditionalErrorState(ClientSocketHandle* handle) OVERRIDE {
+ void GetAdditionalErrorState(ClientSocketHandle* handle) override {
if (store_additional_error_state_) {
// Set all of the additional error state fields in some way.
handle->set_is_ssl_error(true);
@@ -306,7 +294,7 @@ class TestConnectJob : public ConnectJob {
private:
// From ConnectJob:
- virtual int ConnectInternal() OVERRIDE {
+ int ConnectInternal() override {
AddressList ignored;
client_socket_factory_->CreateTransportClientSocket(
ignored, NULL, net::NetLog::Source());
@@ -439,7 +427,7 @@ class TestConnectJobFactory
net_log_(net_log) {
}
- virtual ~TestConnectJobFactory() {}
+ ~TestConnectJobFactory() override {}
void set_job_type(TestConnectJob::JobType job_type) { job_type_ = job_type; }
@@ -454,10 +442,10 @@ class TestConnectJobFactory
// ConnectJobFactory implementation.
- virtual scoped_ptr<ConnectJob> NewConnectJob(
+ scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const TestClientSocketPoolBase::Request& request,
- ConnectJob::Delegate* delegate) const OVERRIDE {
+ ConnectJob::Delegate* delegate) const override {
EXPECT_TRUE(!job_types_ || !job_types_->empty());
TestConnectJob::JobType job_type = job_type_;
if (job_types_ && !job_types_->empty()) {
@@ -473,7 +461,7 @@ class TestConnectJobFactory
net_log_));
}
- virtual base::TimeDelta ConnectionTimeout() const OVERRIDE {
+ base::TimeDelta ConnectionTimeout() const override {
return timeout_duration_;
}
@@ -502,92 +490,78 @@ class TestClientSocketPool : public ClientSocketPool {
unused_idle_socket_timeout, used_idle_socket_timeout,
connect_job_factory) {}
- virtual ~TestClientSocketPool() {}
+ ~TestClientSocketPool() override {}
- virtual int RequestSocket(
- const std::string& group_name,
- const void* params,
- net::RequestPriority priority,
- ClientSocketHandle* handle,
- const CompletionCallback& callback,
- const BoundNetLog& net_log) OVERRIDE {
+ int RequestSocket(const std::string& group_name,
+ const void* params,
+ net::RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) override {
const scoped_refptr<TestSocketParams>* casted_socket_params =
static_cast<const scoped_refptr<TestSocketParams>*>(params);
return base_.RequestSocket(group_name, *casted_socket_params, priority,
handle, callback, net_log);
}
- virtual void RequestSockets(const std::string& group_name,
- const void* params,
- int num_sockets,
- const BoundNetLog& net_log) OVERRIDE {
+ void RequestSockets(const std::string& group_name,
+ const void* params,
+ int num_sockets,
+ const BoundNetLog& net_log) override {
const scoped_refptr<TestSocketParams>* casted_params =
static_cast<const scoped_refptr<TestSocketParams>*>(params);
base_.RequestSockets(group_name, *casted_params, num_sockets, net_log);
}
- virtual void CancelRequest(
- const std::string& group_name,
- ClientSocketHandle* handle) OVERRIDE {
+ void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) override {
base_.CancelRequest(group_name, handle);
}
- virtual void ReleaseSocket(
- const std::string& group_name,
- scoped_ptr<StreamSocket> socket,
- int id) OVERRIDE {
+ void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) override {
base_.ReleaseSocket(group_name, socket.Pass(), id);
}
- virtual void FlushWithError(int error) OVERRIDE {
- base_.FlushWithError(error);
- }
+ void FlushWithError(int error) override { base_.FlushWithError(error); }
- virtual bool IsStalled() const OVERRIDE {
- return base_.IsStalled();
- }
+ bool IsStalled() const override { return base_.IsStalled(); }
- virtual void CloseIdleSockets() OVERRIDE {
- base_.CloseIdleSockets();
- }
+ void CloseIdleSockets() override { base_.CloseIdleSockets(); }
- virtual int IdleSocketCount() const OVERRIDE {
- return base_.idle_socket_count();
- }
+ int IdleSocketCount() const override { return base_.idle_socket_count(); }
- virtual int IdleSocketCountInGroup(
- const std::string& group_name) const OVERRIDE {
+ int IdleSocketCountInGroup(const std::string& group_name) const override {
return base_.IdleSocketCountInGroup(group_name);
}
- virtual LoadState GetLoadState(
- const std::string& group_name,
- const ClientSocketHandle* handle) const OVERRIDE {
+ LoadState GetLoadState(const std::string& group_name,
+ const ClientSocketHandle* handle) const override {
return base_.GetLoadState(group_name, handle);
}
- virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE {
+ void AddHigherLayeredPool(HigherLayeredPool* higher_pool) override {
base_.AddHigherLayeredPool(higher_pool);
}
- virtual void RemoveHigherLayeredPool(
- HigherLayeredPool* higher_pool) OVERRIDE {
+ void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) override {
base_.RemoveHigherLayeredPool(higher_pool);
}
- virtual base::DictionaryValue* GetInfoAsValue(
+ base::DictionaryValue* GetInfoAsValue(
const std::string& name,
const std::string& type,
- bool include_nested_pools) const OVERRIDE {
+ bool include_nested_pools) const override {
return base_.GetInfoAsValue(name, type);
}
- virtual base::TimeDelta ConnectionTimeout() const OVERRIDE {
+ base::TimeDelta ConnectionTimeout() const override {
return base_.ConnectionTimeout();
}
- virtual ClientSocketPoolHistograms* histograms() const OVERRIDE {
+ ClientSocketPoolHistograms* histograms() const override {
return base_.histograms();
}
@@ -651,9 +625,9 @@ class TestConnectJobDelegate : public ConnectJob::Delegate {
public:
TestConnectJobDelegate()
: have_result_(false), waiting_for_result_(false), result_(OK) {}
- virtual ~TestConnectJobDelegate() {}
+ ~TestConnectJobDelegate() override {}
- virtual void OnConnectJobComplete(int result, ConnectJob* job) OVERRIDE {
+ void OnConnectJobComplete(int result, ConnectJob* job) override {
result_ = result;
scoped_ptr<ConnectJob> owned_job(job);
scoped_ptr<StreamSocket> socket = owned_job->PassSocket();
@@ -693,7 +667,7 @@ class ClientSocketPoolBaseTest : public testing::Test {
internal::ClientSocketPoolBaseHelper::cleanup_timer_enabled();
}
- virtual ~ClientSocketPoolBaseTest() {
+ ~ClientSocketPoolBaseTest() override {
internal::ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(
connect_backup_jobs_enabled_);
internal::ClientSocketPoolBaseHelper::set_cleanup_timer_enabled(
@@ -1473,7 +1447,7 @@ class RequestSocketCallback : public TestCompletionCallbackBase {
base::Unretained(this))) {
}
- virtual ~RequestSocketCallback() {}
+ ~RequestSocketCallback() override {}
const CompletionCallback& callback() const { return callback_; }
@@ -2591,7 +2565,7 @@ class TestReleasingSocketRequest : public TestCompletionCallbackBase {
base::Unretained(this))) {
}
- virtual ~TestReleasingSocketRequest() {}
+ ~TestReleasingSocketRequest() override {}
ClientSocketHandle* handle() { return &handle_; }
@@ -2716,7 +2690,7 @@ class ConnectWithinCallback : public TestCompletionCallbackBase {
base::Unretained(this))) {
}
- virtual ~ConnectWithinCallback() {}
+ ~ConnectWithinCallback() override {}
int WaitForNestedResult() {
return nested_callback_.WaitForResult();
diff --git a/chromium/net/socket/client_socket_pool_manager.cc b/chromium/net/socket/client_socket_pool_manager.cc
index f81b4bc1db5..b99612718e4 100644
--- a/chromium/net/socket/client_socket_pool_manager.cc
+++ b/chromium/net/socket/client_socket_pool_manager.cc
@@ -84,7 +84,6 @@ int InitSocketPoolHelper(const GURL& request_url,
HttpNetworkSession::SocketPoolType socket_pool_type,
const OnHostResolutionCallback& resolution_callback,
const CompletionCallback& callback) {
- scoped_refptr<TransportSocketParams> tcp_params;
scoped_refptr<HttpProxySocketParams> http_proxy_params;
scoped_refptr<SOCKSSocketParams> socks_params;
scoped_ptr<HostPortPair> proxy_host_port;
@@ -132,7 +131,7 @@ int InitSocketPoolHelper(const GURL& request_url,
// should be the same for all connections, whereas version_max may
// change for version fallbacks.
std::string prefix = "ssl/";
- if (ssl_config_for_origin.version_max != net::kDefaultSSLVersionMax) {
+ if (ssl_config_for_origin.version_max != kDefaultSSLVersionMax) {
switch (ssl_config_for_origin.version_max) {
case SSL_PROTOCOL_VERSION_TLS1_2:
prefix = "ssl(max:3.3)/";
@@ -155,19 +154,16 @@ int InitSocketPoolHelper(const GURL& request_url,
}
bool ignore_limits = (request_load_flags & LOAD_IGNORE_LIMITS) != 0;
- if (proxy_info.is_direct()) {
- tcp_params = new TransportSocketParams(origin_host_port,
- disable_resolver_cache,
- ignore_limits,
- resolution_callback);
- } else {
+ if (!proxy_info.is_direct()) {
ProxyServer proxy_server = proxy_info.proxy_server();
proxy_host_port.reset(new HostPortPair(proxy_server.host_port_pair()));
scoped_refptr<TransportSocketParams> proxy_tcp_params(
- new TransportSocketParams(*proxy_host_port,
- disable_resolver_cache,
- ignore_limits,
- resolution_callback));
+ new TransportSocketParams(
+ *proxy_host_port,
+ disable_resolver_cache,
+ ignore_limits,
+ resolution_callback,
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT));
if (proxy_info.is_http() || proxy_info.is_https()) {
std::string user_agent;
@@ -175,6 +171,18 @@ int InitSocketPoolHelper(const GURL& request_url,
&user_agent);
scoped_refptr<SSLSocketParams> ssl_params;
if (proxy_info.is_https()) {
+ // Combine connect and write for SSL sockets in TCP FastOpen
+ // field trial.
+ TransportSocketParams::CombineConnectAndWritePolicy
+ combine_connect_and_write =
+ session->params().enable_tcp_fast_open_for_ssl ?
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DESIRED :
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT;
+ proxy_tcp_params = new TransportSocketParams(*proxy_host_port,
+ disable_resolver_cache,
+ ignore_limits,
+ resolution_callback,
+ combine_connect_and_write);
// Set ssl_params, and unset proxy_tcp_params
ssl_params = new SSLSocketParams(proxy_tcp_params,
NULL,
@@ -197,7 +205,8 @@ int InitSocketPoolHelper(const GURL& request_url,
session->http_auth_cache(),
session->http_auth_handler_factory(),
session->spdy_session_pool(),
- force_tunnel || using_ssl);
+ force_tunnel || using_ssl,
+ session->params().proxy_delegate);
} else {
DCHECK(proxy_info.is_socks());
char socks_version;
@@ -220,8 +229,23 @@ int InitSocketPoolHelper(const GURL& request_url,
// Deal with SSL - which layers on top of any given proxy.
if (using_ssl) {
+ scoped_refptr<TransportSocketParams> ssl_tcp_params;
+ if (proxy_info.is_direct()) {
+ // Setup TCP params if non-proxied SSL connection.
+ // Combine connect and write for SSL sockets in TCP FastOpen field trial.
+ TransportSocketParams::CombineConnectAndWritePolicy
+ combine_connect_and_write =
+ session->params().enable_tcp_fast_open_for_ssl ?
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DESIRED :
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT;
+ ssl_tcp_params = new TransportSocketParams(origin_host_port,
+ disable_resolver_cache,
+ ignore_limits,
+ resolution_callback,
+ combine_connect_and_write);
+ }
scoped_refptr<SSLSocketParams> ssl_params =
- new SSLSocketParams(tcp_params,
+ new SSLSocketParams(ssl_tcp_params,
socks_params,
http_proxy_params,
origin_host_port,
@@ -280,7 +304,13 @@ int InitSocketPoolHelper(const GURL& request_url,
}
DCHECK(proxy_info.is_direct());
-
+ scoped_refptr<TransportSocketParams> tcp_params =
+ new TransportSocketParams(
+ origin_host_port,
+ disable_resolver_cache,
+ ignore_limits,
+ resolution_callback,
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT);
TransportClientSocketPool* pool =
session->GetTransportSocketPool(socket_pool_type);
if (num_preconnect_streams) {
diff --git a/chromium/net/socket/client_socket_pool_manager_impl.cc b/chromium/net/socket/client_socket_pool_manager_impl.cc
index 991278d7341..5ed31fc9a25 100644
--- a/chromium/net/socket/client_socket_pool_manager_impl.cc
+++ b/chromium/net/socket/client_socket_pool_manager_impl.cc
@@ -11,6 +11,7 @@
#include "net/socket/socks_client_socket_pool.h"
#include "net/socket/ssl_client_socket_pool.h"
#include "net/socket/transport_client_socket_pool.h"
+#include "net/socket/websocket_transport_client_socket_pool.h"
#include "net/ssl/ssl_config_service.h"
namespace net {
@@ -38,54 +39,68 @@ ClientSocketPoolManagerImpl::ClientSocketPoolManagerImpl(
ClientSocketFactory* socket_factory,
HostResolver* host_resolver,
CertVerifier* cert_verifier,
- ServerBoundCertService* server_bound_cert_service,
+ ChannelIDService* channel_id_service,
TransportSecurityState* transport_security_state,
CTVerifier* cert_transparency_verifier,
const std::string& ssl_session_cache_shard,
ProxyService* proxy_service,
SSLConfigService* ssl_config_service,
+ bool enable_ssl_connect_job_waiting,
+ ProxyDelegate* proxy_delegate,
HttpNetworkSession::SocketPoolType pool_type)
: net_log_(net_log),
socket_factory_(socket_factory),
host_resolver_(host_resolver),
cert_verifier_(cert_verifier),
- server_bound_cert_service_(server_bound_cert_service),
+ channel_id_service_(channel_id_service),
transport_security_state_(transport_security_state),
cert_transparency_verifier_(cert_transparency_verifier),
ssl_session_cache_shard_(ssl_session_cache_shard),
proxy_service_(proxy_service),
ssl_config_service_(ssl_config_service),
+ enable_ssl_connect_job_waiting_(enable_ssl_connect_job_waiting),
pool_type_(pool_type),
transport_pool_histograms_("TCP"),
- transport_socket_pool_(new TransportClientSocketPool(
- max_sockets_per_pool(pool_type), max_sockets_per_group(pool_type),
- &transport_pool_histograms_,
- host_resolver,
- socket_factory_,
- net_log)),
+ transport_socket_pool_(
+ pool_type == HttpNetworkSession::WEBSOCKET_SOCKET_POOL
+ ? new WebSocketTransportClientSocketPool(
+ max_sockets_per_pool(pool_type),
+ max_sockets_per_group(pool_type),
+ &transport_pool_histograms_,
+ host_resolver,
+ socket_factory_,
+ net_log)
+ : new TransportClientSocketPool(max_sockets_per_pool(pool_type),
+ max_sockets_per_group(pool_type),
+ &transport_pool_histograms_,
+ host_resolver,
+ socket_factory_,
+ net_log)),
ssl_pool_histograms_("SSL2"),
- ssl_socket_pool_(new SSLClientSocketPool(
- max_sockets_per_pool(pool_type), max_sockets_per_group(pool_type),
- &ssl_pool_histograms_,
- host_resolver,
- cert_verifier,
- server_bound_cert_service,
- transport_security_state,
- cert_transparency_verifier,
- ssl_session_cache_shard,
- socket_factory,
- transport_socket_pool_.get(),
- NULL /* no socks proxy */,
- NULL /* no http proxy */,
- ssl_config_service,
- net_log)),
+ ssl_socket_pool_(new SSLClientSocketPool(max_sockets_per_pool(pool_type),
+ max_sockets_per_group(pool_type),
+ &ssl_pool_histograms_,
+ host_resolver,
+ cert_verifier,
+ channel_id_service,
+ transport_security_state,
+ cert_transparency_verifier,
+ ssl_session_cache_shard,
+ socket_factory,
+ transport_socket_pool_.get(),
+ NULL /* no socks proxy */,
+ NULL /* no http proxy */,
+ ssl_config_service,
+ enable_ssl_connect_job_waiting,
+ net_log)),
transport_for_socks_pool_histograms_("TCPforSOCKS"),
socks_pool_histograms_("SOCK"),
transport_for_http_proxy_pool_histograms_("TCPforHTTPProxy"),
transport_for_https_proxy_pool_histograms_("TCPforHTTPSProxy"),
ssl_for_https_proxy_pool_histograms_("SSLforHTTPSProxy"),
http_proxy_pool_histograms_("HTTPProxy"),
- ssl_socket_pool_for_proxies_histograms_("SSLForProxies") {
+ ssl_socket_pool_for_proxies_histograms_("SSLForProxies"),
+ proxy_delegate_(proxy_delegate) {
CertDatabase::GetInstance()->AddObserver(this);
}
@@ -287,7 +302,7 @@ ClientSocketPoolManagerImpl::GetSocketPoolForHTTPProxy(
&ssl_for_https_proxy_pool_histograms_,
host_resolver_,
cert_verifier_,
- server_bound_cert_service_,
+ channel_id_service_,
transport_security_state_,
cert_transparency_verifier_,
ssl_session_cache_shard_,
@@ -296,6 +311,7 @@ ClientSocketPoolManagerImpl::GetSocketPoolForHTTPProxy(
NULL /* no socks proxy */,
NULL /* no http proxy */,
ssl_config_service_.get(),
+ enable_ssl_connect_job_waiting_,
net_log_)));
DCHECK(tcp_https_ret.second);
@@ -310,6 +326,7 @@ ClientSocketPoolManagerImpl::GetSocketPoolForHTTPProxy(
host_resolver_,
tcp_http_ret.first->second,
ssl_https_ret.first->second,
+ proxy_delegate_,
net_log_)));
return ret.first->second;
@@ -328,7 +345,7 @@ SSLClientSocketPool* ClientSocketPoolManagerImpl::GetSocketPoolForSSLWithProxy(
&ssl_pool_histograms_,
host_resolver_,
cert_verifier_,
- server_bound_cert_service_,
+ channel_id_service_,
transport_security_state_,
cert_transparency_verifier_,
ssl_session_cache_shard_,
@@ -337,6 +354,7 @@ SSLClientSocketPool* ClientSocketPoolManagerImpl::GetSocketPoolForSSLWithProxy(
GetSocketPoolForSOCKSProxy(proxy_server),
GetSocketPoolForHTTPProxy(proxy_server),
ssl_config_service_.get(),
+ enable_ssl_connect_job_waiting_,
net_log_);
std::pair<SSLSocketPoolMap::iterator, bool> ret =
diff --git a/chromium/net/socket/client_socket_pool_manager_impl.h b/chromium/net/socket/client_socket_pool_manager_impl.h
index 06d4d244a36..f9f8d3b8f15 100644
--- a/chromium/net/socket/client_socket_pool_manager_impl.h
+++ b/chromium/net/socket/client_socket_pool_manager_impl.h
@@ -21,13 +21,14 @@
namespace net {
class CertVerifier;
+class ChannelIDService;
class ClientSocketFactory;
class ClientSocketPoolHistograms;
class CTVerifier;
class HttpProxyClientSocketPool;
class HostResolver;
class NetLog;
-class ServerBoundCertService;
+class ProxyDelegate;
class ProxyService;
class SOCKSClientSocketPool;
class SSLClientSocketPool;
@@ -61,38 +62,40 @@ class ClientSocketPoolManagerImpl : public base::NonThreadSafe,
ClientSocketFactory* socket_factory,
HostResolver* host_resolver,
CertVerifier* cert_verifier,
- ServerBoundCertService* server_bound_cert_service,
+ ChannelIDService* channel_id_service,
TransportSecurityState* transport_security_state,
CTVerifier* cert_transparency_verifier,
const std::string& ssl_session_cache_shard,
ProxyService* proxy_service,
SSLConfigService* ssl_config_service,
+ bool enable_ssl_connect_job_waiting,
+ ProxyDelegate* proxy_delegate,
HttpNetworkSession::SocketPoolType pool_type);
- virtual ~ClientSocketPoolManagerImpl();
+ ~ClientSocketPoolManagerImpl() override;
- virtual void FlushSocketPoolsWithError(int error) OVERRIDE;
- virtual void CloseIdleSockets() OVERRIDE;
+ void FlushSocketPoolsWithError(int error) override;
+ void CloseIdleSockets() override;
- virtual TransportClientSocketPool* GetTransportSocketPool() OVERRIDE;
+ TransportClientSocketPool* GetTransportSocketPool() override;
- virtual SSLClientSocketPool* GetSSLSocketPool() OVERRIDE;
+ SSLClientSocketPool* GetSSLSocketPool() override;
- virtual SOCKSClientSocketPool* GetSocketPoolForSOCKSProxy(
- const HostPortPair& socks_proxy) OVERRIDE;
+ SOCKSClientSocketPool* GetSocketPoolForSOCKSProxy(
+ const HostPortPair& socks_proxy) override;
- virtual HttpProxyClientSocketPool* GetSocketPoolForHTTPProxy(
- const HostPortPair& http_proxy) OVERRIDE;
+ HttpProxyClientSocketPool* GetSocketPoolForHTTPProxy(
+ const HostPortPair& http_proxy) override;
- virtual SSLClientSocketPool* GetSocketPoolForSSLWithProxy(
- const HostPortPair& proxy_server) OVERRIDE;
+ SSLClientSocketPool* GetSocketPoolForSSLWithProxy(
+ const HostPortPair& proxy_server) override;
// Creates a Value summary of the state of the socket pools. The caller is
// responsible for deleting the returned value.
- virtual base::Value* SocketPoolInfoToValue() const OVERRIDE;
+ base::Value* SocketPoolInfoToValue() const override;
// CertDatabase::Observer methods:
- virtual void OnCertAdded(const X509Certificate* cert) OVERRIDE;
- virtual void OnCACertChanged(const X509Certificate* cert) OVERRIDE;
+ void OnCertAdded(const X509Certificate* cert) override;
+ void OnCACertChanged(const X509Certificate* cert) override;
private:
typedef internal::OwnedPoolMap<HostPortPair, TransportClientSocketPool*>
@@ -108,12 +111,13 @@ class ClientSocketPoolManagerImpl : public base::NonThreadSafe,
ClientSocketFactory* const socket_factory_;
HostResolver* const host_resolver_;
CertVerifier* const cert_verifier_;
- ServerBoundCertService* const server_bound_cert_service_;
+ ChannelIDService* const channel_id_service_;
TransportSecurityState* const transport_security_state_;
CTVerifier* const cert_transparency_verifier_;
const std::string ssl_session_cache_shard_;
ProxyService* const proxy_service_;
const scoped_refptr<SSLConfigService> ssl_config_service_;
+ bool enable_ssl_connect_job_waiting_;
const HttpNetworkSession::SocketPoolType pool_type_;
// Note: this ordering is important.
@@ -145,6 +149,8 @@ class ClientSocketPoolManagerImpl : public base::NonThreadSafe,
ClientSocketPoolHistograms ssl_socket_pool_for_proxies_histograms_;
SSLSocketPoolMap ssl_socket_pools_for_proxies_;
+ const ProxyDelegate* proxy_delegate_;
+
DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolManagerImpl);
};
diff --git a/chromium/net/socket/deterministic_socket_data_unittest.cc b/chromium/net/socket/deterministic_socket_data_unittest.cc
index c51427e25a7..bdeba2bef55 100644
--- a/chromium/net/socket/deterministic_socket_data_unittest.cc
+++ b/chromium/net/socket/deterministic_socket_data_unittest.cc
@@ -16,7 +16,7 @@ namespace {
static const char kMsg1[] = "\0hello!\xff";
static const int kLen1 = arraysize(kMsg1);
-static const char kMsg2[] = "\012345678\0";
+static const char kMsg2[] = "\0a2345678\0";
static const int kLen2 = arraysize(kMsg2);
static const char kMsg3[] = "bye!";
static const int kLen3 = arraysize(kMsg3);
@@ -29,7 +29,7 @@ class DeterministicSocketDataTest : public PlatformTest {
public:
DeterministicSocketDataTest();
- virtual void TearDown();
+ void TearDown() override;
void ReentrantReadCallback(int len, int rv);
void ReentrantWriteCallback(const char* data, int len, int rv);
@@ -71,10 +71,12 @@ DeterministicSocketDataTest::DeterministicSocketDataTest()
read_buf_(NULL),
connect_data_(SYNCHRONOUS, OK),
endpoint_("www.google.com", 443),
- tcp_params_(new TransportSocketParams(endpoint_,
- false,
- false,
- OnHostResolutionCallback())),
+ tcp_params_(new TransportSocketParams(
+ endpoint_,
+ false,
+ false,
+ OnHostResolutionCallback(),
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)),
histograms_(std::string()),
socket_pool_(10, 10, &histograms_, &socket_factory_) {}
diff --git a/chromium/net/socket/mock_client_socket_pool_manager.h b/chromium/net/socket/mock_client_socket_pool_manager.h
index c2c3792a4f6..76d7704316a 100644
--- a/chromium/net/socket/mock_client_socket_pool_manager.h
+++ b/chromium/net/socket/mock_client_socket_pool_manager.h
@@ -14,7 +14,7 @@ namespace net {
class MockClientSocketPoolManager : public ClientSocketPoolManager {
public:
MockClientSocketPoolManager();
- virtual ~MockClientSocketPoolManager();
+ ~MockClientSocketPoolManager() override;
// Sets "override" socket pools that get used instead.
void SetTransportSocketPool(TransportClientSocketPool* pool);
@@ -27,17 +27,17 @@ class MockClientSocketPoolManager : public ClientSocketPoolManager {
SSLClientSocketPool* pool);
// ClientSocketPoolManager methods:
- virtual void FlushSocketPoolsWithError(int error) OVERRIDE;
- virtual void CloseIdleSockets() OVERRIDE;
- virtual TransportClientSocketPool* GetTransportSocketPool() OVERRIDE;
- virtual SSLClientSocketPool* GetSSLSocketPool() OVERRIDE;
- virtual SOCKSClientSocketPool* GetSocketPoolForSOCKSProxy(
- const HostPortPair& socks_proxy) OVERRIDE;
- virtual HttpProxyClientSocketPool* GetSocketPoolForHTTPProxy(
- const HostPortPair& http_proxy) OVERRIDE;
- virtual SSLClientSocketPool* GetSocketPoolForSSLWithProxy(
- const HostPortPair& proxy_server) OVERRIDE;
- virtual base::Value* SocketPoolInfoToValue() const OVERRIDE;
+ void FlushSocketPoolsWithError(int error) override;
+ void CloseIdleSockets() override;
+ TransportClientSocketPool* GetTransportSocketPool() override;
+ SSLClientSocketPool* GetSSLSocketPool() override;
+ SOCKSClientSocketPool* GetSocketPoolForSOCKSProxy(
+ const HostPortPair& socks_proxy) override;
+ HttpProxyClientSocketPool* GetSocketPoolForHTTPProxy(
+ const HostPortPair& http_proxy) override;
+ SSLClientSocketPool* GetSocketPoolForSSLWithProxy(
+ const HostPortPair& proxy_server) override;
+ base::Value* SocketPoolInfoToValue() const override;
private:
typedef internal::OwnedPoolMap<HostPortPair, TransportClientSocketPool*>
diff --git a/chromium/net/socket/next_proto.cc b/chromium/net/socket/next_proto.cc
index c9172365c38..1dcfb5d58f3 100644
--- a/chromium/net/socket/next_proto.cc
+++ b/chromium/net/socket/next_proto.cc
@@ -15,7 +15,6 @@ NextProtoVector NextProtosHttpOnly() {
NextProtoVector NextProtosDefaults() {
NextProtoVector next_protos;
next_protos.push_back(kProtoHTTP11);
- next_protos.push_back(kProtoSPDY3);
next_protos.push_back(kProtoSPDY31);
return next_protos;
}
@@ -27,35 +26,15 @@ NextProtoVector NextProtosWithSpdyAndQuic(bool spdy_enabled,
if (quic_enabled)
next_protos.push_back(kProtoQUIC1SPDY3);
if (spdy_enabled) {
- next_protos.push_back(kProtoSPDY3);
next_protos.push_back(kProtoSPDY31);
}
return next_protos;
}
-NextProtoVector NextProtosSpdy3() {
- NextProtoVector next_protos;
- next_protos.push_back(kProtoHTTP11);
- next_protos.push_back(kProtoQUIC1SPDY3);
- next_protos.push_back(kProtoSPDY3);
- return next_protos;
-}
-
NextProtoVector NextProtosSpdy31() {
NextProtoVector next_protos;
next_protos.push_back(kProtoHTTP11);
next_protos.push_back(kProtoQUIC1SPDY3);
- next_protos.push_back(kProtoSPDY3);
- next_protos.push_back(kProtoSPDY31);
- return next_protos;
-}
-
-NextProtoVector NextProtosSpdy31WithSpdy2() {
- NextProtoVector next_protos;
- next_protos.push_back(kProtoHTTP11);
- next_protos.push_back(kProtoQUIC1SPDY3);
- next_protos.push_back(kProtoDeprecatedSPDY2);
- next_protos.push_back(kProtoSPDY3);
next_protos.push_back(kProtoSPDY31);
return next_protos;
}
@@ -64,7 +43,6 @@ NextProtoVector NextProtosSpdy4Http2() {
NextProtoVector next_protos;
next_protos.push_back(kProtoHTTP11);
next_protos.push_back(kProtoQUIC1SPDY3);
- next_protos.push_back(kProtoSPDY3);
next_protos.push_back(kProtoSPDY31);
next_protos.push_back(kProtoSPDY4);
return next_protos;
diff --git a/chromium/net/socket/next_proto.h b/chromium/net/socket/next_proto.h
index 19ff55e0bc0..4df6e9b9cd5 100644
--- a/chromium/net/socket/next_proto.h
+++ b/chromium/net/socket/next_proto.h
@@ -14,20 +14,23 @@ namespace net {
// Next Protocol Negotiation (NPN), if successful, results in agreement on an
// application-level string that specifies the application level protocol to
// use over the TLS connection. NextProto enumerates the application level
-// protocols that we recognise.
+// protocols that we recognize. Do not change or reuse values, because they
+// are used to collect statistics on UMA. Also, values must be in [0,499),
+// because of the way TLS protocol negotiation extension information is added to
+// UMA histogram.
enum NextProto {
kProtoUnknown = 0,
- kProtoHTTP11,
+ kProtoHTTP11 = 1,
kProtoMinimumVersion = kProtoHTTP11,
- kProtoDeprecatedSPDY2,
+ kProtoDeprecatedSPDY2 = 100,
kProtoSPDYMinimumVersion = kProtoDeprecatedSPDY2,
- kProtoSPDY3,
- kProtoSPDY31,
- kProtoSPDY4, // SPDY4 is HTTP/2.
+ kProtoSPDY3 = 101,
+ kProtoSPDY31 = 102,
+ kProtoSPDY4 = 103, // SPDY4 is HTTP/2.
kProtoSPDYMaximumVersion = kProtoSPDY4,
- kProtoQUIC1SPDY3,
+ kProtoQUIC1SPDY3 = 200,
kProtoMaximumVersion = kProtoQUIC1SPDY3,
};
@@ -47,9 +50,7 @@ NET_EXPORT NextProtoVector NextProtosWithSpdyAndQuic(bool spdy_enabled,
bool quic_enabled);
// All of these also enable QUIC.
-NET_EXPORT NextProtoVector NextProtosSpdy3();
NET_EXPORT NextProtoVector NextProtosSpdy31();
-NET_EXPORT NextProtoVector NextProtosSpdy31WithSpdy2();
NET_EXPORT NextProtoVector NextProtosSpdy4Http2();
} // namespace net
diff --git a/chromium/net/socket/nss_ssl_util.cc b/chromium/net/socket/nss_ssl_util.cc
index 7b068545a55..a238a25d2d4 100644
--- a/chromium/net/socket/nss_ssl_util.cc
+++ b/chromium/net/socket/nss_ssl_util.cc
@@ -29,6 +29,8 @@
#include "base/win/windows_version.h"
#endif
+namespace net {
+
namespace {
// CiphersRemove takes a zero-terminated array of cipher suite ids in
@@ -77,9 +79,15 @@ size_t CiphersCopy(const uint16* in, uint16* out) {
}
}
-} // anonymous namespace
-
-namespace net {
+base::Value* NetLogSSLErrorCallback(int net_error,
+ int ssl_lib_error,
+ NetLog::LogLevel /* log_level */) {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ dict->SetInteger("net_error", net_error);
+ if (ssl_lib_error)
+ dict->SetInteger("ssl_lib_error", ssl_lib_error);
+ return dict;
+}
class NSSSSLInitSingleton {
public:
@@ -201,9 +209,11 @@ class NSSSSLInitSingleton {
PRFileDesc* model_fd_;
};
-static base::LazyInstance<NSSSSLInitSingleton>::Leaky g_nss_ssl_init_singleton =
+base::LazyInstance<NSSSSLInitSingleton>::Leaky g_nss_ssl_init_singleton =
LAZY_INSTANCE_INITIALIZER;
+} // anonymous namespace
+
// Initialize the NSS SSL library if it isn't already initialized. This must
// be called before any other NSS SSL functions. This function is
// thread-safe, and the NSS SSL library will only ever be initialized once.
@@ -399,4 +409,9 @@ void LogFailedNSSFunction(const BoundNetLog& net_log,
function, param, PR_GetError()));
}
+NetLog::ParametersCallback CreateNetLogSSLErrorCallback(int net_error,
+ int ssl_lib_error) {
+ return base::Bind(&NetLogSSLErrorCallback, net_error, ssl_lib_error);
+}
+
} // namespace net
diff --git a/chromium/net/socket/nss_ssl_util.h b/chromium/net/socket/nss_ssl_util.h
index 3aed7bf6b4a..7b046ffd282 100644
--- a/chromium/net/socket/nss_ssl_util.h
+++ b/chromium/net/socket/nss_ssl_util.h
@@ -12,6 +12,7 @@
#include <prio.h>
#include "net/base/net_export.h"
+#include "net/base/net_log.h"
namespace net {
@@ -35,6 +36,11 @@ PRFileDesc* GetNSSModelSocket();
// Map NSS error code to network error code.
int MapNSSError(PRErrorCode err);
+// Creates a NetLog callback for an SSL error.
+NetLog::ParametersCallback CreateNetLogSSLErrorCallback(int net_error,
+ int ssl_lib_error);
+
+
} // namespace net
#endif // NET_SOCKET_NSS_SSL_UTIL_H_
diff --git a/chromium/net/socket/openssl_ssl_util.cc b/chromium/net/socket/openssl_ssl_util.cc
deleted file mode 100644
index 36b8e6ca0e3..00000000000
--- a/chromium/net/socket/openssl_ssl_util.cc
+++ /dev/null
@@ -1,156 +0,0 @@
-// Copyright 2014 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "net/socket/openssl_ssl_util.h"
-
-#include <openssl/err.h>
-#include <openssl/ssl.h>
-
-#include "base/logging.h"
-#include "crypto/openssl_util.h"
-#include "net/base/net_errors.h"
-
-namespace net {
-
-SslSetClearMask::SslSetClearMask()
- : set_mask(0),
- clear_mask(0) {
-}
-
-void SslSetClearMask::ConfigureFlag(long flag, bool state) {
- (state ? set_mask : clear_mask) |= flag;
- // Make sure we haven't got any intersection in the set & clear options.
- DCHECK_EQ(0, set_mask & clear_mask) << flag << ":" << state;
-}
-
-namespace {
-
-int MapOpenSSLErrorSSL() {
- // Walk down the error stack to find the SSLerr generated reason.
- unsigned long error_code;
- do {
- error_code = ERR_get_error();
- if (error_code == 0)
- return ERR_SSL_PROTOCOL_ERROR;
- } while (ERR_GET_LIB(error_code) != ERR_LIB_SSL);
-
- DVLOG(1) << "OpenSSL SSL error, reason: " << ERR_GET_REASON(error_code)
- << ", name: " << ERR_error_string(error_code, NULL);
- switch (ERR_GET_REASON(error_code)) {
- case SSL_R_READ_TIMEOUT_EXPIRED:
- return ERR_TIMED_OUT;
- case SSL_R_BAD_RESPONSE_ARGUMENT:
- return ERR_INVALID_ARGUMENT;
- case SSL_R_UNKNOWN_CERTIFICATE_TYPE:
- case SSL_R_UNKNOWN_CIPHER_TYPE:
- case SSL_R_UNKNOWN_KEY_EXCHANGE_TYPE:
- case SSL_R_UNKNOWN_PKEY_TYPE:
- case SSL_R_UNKNOWN_REMOTE_ERROR_TYPE:
- case SSL_R_UNKNOWN_SSL_VERSION:
- return ERR_NOT_IMPLEMENTED;
- case SSL_R_UNSUPPORTED_SSL_VERSION:
- case SSL_R_NO_CIPHER_MATCH:
- case SSL_R_NO_SHARED_CIPHER:
- case SSL_R_TLSV1_ALERT_INSUFFICIENT_SECURITY:
- case SSL_R_TLSV1_ALERT_PROTOCOL_VERSION:
- case SSL_R_UNSUPPORTED_PROTOCOL:
- return ERR_SSL_VERSION_OR_CIPHER_MISMATCH;
- case SSL_R_SSLV3_ALERT_BAD_CERTIFICATE:
- case SSL_R_SSLV3_ALERT_UNSUPPORTED_CERTIFICATE:
- case SSL_R_SSLV3_ALERT_CERTIFICATE_REVOKED:
- case SSL_R_SSLV3_ALERT_CERTIFICATE_EXPIRED:
- case SSL_R_SSLV3_ALERT_CERTIFICATE_UNKNOWN:
- case SSL_R_TLSV1_ALERT_ACCESS_DENIED:
- case SSL_R_TLSV1_ALERT_UNKNOWN_CA:
- return ERR_BAD_SSL_CLIENT_AUTH_CERT;
- case SSL_R_BAD_DECOMPRESSION:
- case SSL_R_SSLV3_ALERT_DECOMPRESSION_FAILURE:
- return ERR_SSL_DECOMPRESSION_FAILURE_ALERT;
- case SSL_R_SSLV3_ALERT_BAD_RECORD_MAC:
- return ERR_SSL_BAD_RECORD_MAC_ALERT;
- case SSL_R_TLSV1_ALERT_DECRYPT_ERROR:
- return ERR_SSL_DECRYPT_ERROR_ALERT;
- case SSL_R_TLSV1_UNRECOGNIZED_NAME:
- return ERR_SSL_UNRECOGNIZED_NAME_ALERT;
- case SSL_R_UNSAFE_LEGACY_RENEGOTIATION_DISABLED:
- return ERR_SSL_UNSAFE_NEGOTIATION;
- case SSL_R_WRONG_NUMBER_OF_KEY_BITS:
- return ERR_SSL_WEAK_SERVER_EPHEMERAL_DH_KEY;
- // SSL_R_UNKNOWN_PROTOCOL is reported if premature application data is
- // received (see http://crbug.com/42538), and also if all the protocol
- // versions supported by the server were disabled in this socket instance.
- // Mapped to ERR_SSL_PROTOCOL_ERROR for compatibility with other SSL sockets
- // in the former scenario.
- case SSL_R_UNKNOWN_PROTOCOL:
- case SSL_R_SSL_HANDSHAKE_FAILURE:
- case SSL_R_DECRYPTION_FAILED:
- case SSL_R_DECRYPTION_FAILED_OR_BAD_RECORD_MAC:
- case SSL_R_DH_PUBLIC_VALUE_LENGTH_IS_WRONG:
- case SSL_R_DIGEST_CHECK_FAILED:
- case SSL_R_DUPLICATE_COMPRESSION_ID:
- case SSL_R_ECGROUP_TOO_LARGE_FOR_CIPHER:
- case SSL_R_ENCRYPTED_LENGTH_TOO_LONG:
- case SSL_R_ERROR_IN_RECEIVED_CIPHER_LIST:
- case SSL_R_EXCESSIVE_MESSAGE_SIZE:
- case SSL_R_EXTRA_DATA_IN_MESSAGE:
- case SSL_R_GOT_A_FIN_BEFORE_A_CCS:
- case SSL_R_ILLEGAL_PADDING:
- case SSL_R_INVALID_CHALLENGE_LENGTH:
- case SSL_R_INVALID_COMMAND:
- case SSL_R_INVALID_PURPOSE:
- case SSL_R_INVALID_STATUS_RESPONSE:
- case SSL_R_INVALID_TICKET_KEYS_LENGTH:
- case SSL_R_KEY_ARG_TOO_LONG:
- case SSL_R_READ_WRONG_PACKET_TYPE:
- // SSL_do_handshake reports this error when the server responds to a
- // ClientHello with a fatal close_notify alert.
- case SSL_AD_REASON_OFFSET + SSL_AD_CLOSE_NOTIFY:
- case SSL_R_SSLV3_ALERT_UNEXPECTED_MESSAGE:
- // TODO(joth): SSL_R_SSLV3_ALERT_HANDSHAKE_FAILURE may be returned from the
- // server after receiving ClientHello if there's no common supported cipher.
- // Ideally we'd map that specific case to ERR_SSL_VERSION_OR_CIPHER_MISMATCH
- // to match the NSS implementation. See also http://goo.gl/oMtZW
- case SSL_R_SSLV3_ALERT_HANDSHAKE_FAILURE:
- case SSL_R_SSLV3_ALERT_NO_CERTIFICATE:
- case SSL_R_SSLV3_ALERT_ILLEGAL_PARAMETER:
- case SSL_R_TLSV1_ALERT_DECODE_ERROR:
- case SSL_R_TLSV1_ALERT_DECRYPTION_FAILED:
- case SSL_R_TLSV1_ALERT_EXPORT_RESTRICTION:
- case SSL_R_TLSV1_ALERT_INTERNAL_ERROR:
- case SSL_R_TLSV1_ALERT_NO_RENEGOTIATION:
- case SSL_R_TLSV1_ALERT_RECORD_OVERFLOW:
- case SSL_R_TLSV1_ALERT_USER_CANCELLED:
- return ERR_SSL_PROTOCOL_ERROR;
- case SSL_R_CERTIFICATE_VERIFY_FAILED:
- // The only way that the certificate verify callback can fail is if
- // the leaf certificate changed during a renegotiation.
- return ERR_SSL_SERVER_CERT_CHANGED;
- default:
- LOG(WARNING) << "Unmapped error reason: " << ERR_GET_REASON(error_code);
- return ERR_FAILED;
- }
-}
-
-} // namespace
-
-int MapOpenSSLError(int err, const crypto::OpenSSLErrStackTracer& tracer) {
- switch (err) {
- case SSL_ERROR_WANT_READ:
- case SSL_ERROR_WANT_WRITE:
- return ERR_IO_PENDING;
- case SSL_ERROR_SYSCALL:
- LOG(ERROR) << "OpenSSL SYSCALL error, earliest error code in "
- "error queue: " << ERR_peek_error() << ", errno: "
- << errno;
- return ERR_SSL_PROTOCOL_ERROR;
- case SSL_ERROR_SSL:
- return MapOpenSSLErrorSSL();
- default:
- // TODO(joth): Implement full mapping.
- LOG(WARNING) << "Unknown OpenSSL error " << err;
- return ERR_SSL_PROTOCOL_ERROR;
- }
-}
-
-} // namespace net
diff --git a/chromium/net/socket/openssl_ssl_util.h b/chromium/net/socket/openssl_ssl_util.h
deleted file mode 100644
index e459a445ef3..00000000000
--- a/chromium/net/socket/openssl_ssl_util.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2014 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef NET_SOCKET_OPENSSL_SSL_UTIL_H_
-#define NET_SOCKET_OPENSSL_SSL_UTIL_H_
-
-namespace crypto {
-class OpenSSLErrStackTracer;
-}
-
-namespace net {
-
-// Utility to construct the appropriate set & clear masks for use the OpenSSL
-// options and mode configuration functions. (SSL_set_options etc)
-struct SslSetClearMask {
- SslSetClearMask();
- void ConfigureFlag(long flag, bool state);
-
- long set_mask;
- long clear_mask;
-};
-
-// Converts an OpenSSL error code into a net error code, walking the OpenSSL
-// error stack if needed. Note that |tracer| is not currently used in the
-// implementation, but is passed in anyway as this ensures the caller will clear
-// any residual codes left on the error stack.
-int MapOpenSSLError(int err, const crypto::OpenSSLErrStackTracer& tracer);
-
-} // namespace net
-
-#endif // NET_SOCKET_OPENSSL_SSL_UTIL_H_
diff --git a/chromium/net/socket/server_socket.cc b/chromium/net/socket/server_socket.cc
new file mode 100644
index 00000000000..da89b4645f8
--- /dev/null
+++ b/chromium/net/socket/server_socket.cc
@@ -0,0 +1,30 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/server_socket.h"
+
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+
+namespace net {
+
+ServerSocket::ServerSocket() {
+}
+
+ServerSocket::~ServerSocket() {
+}
+
+int ServerSocket::ListenWithAddressAndPort(const std::string& address_string,
+ int port,
+ int backlog) {
+ IPAddressNumber address_number;
+ if (!ParseIPLiteralToNumber(address_string, &address_number)) {
+ return ERR_ADDRESS_INVALID;
+ }
+
+ return Listen(IPEndPoint(address_number, port), backlog);
+}
+
+} // namespace net
diff --git a/chromium/net/socket/server_socket.h b/chromium/net/socket/server_socket.h
index 11151eea153..4b9ca8e39cf 100644
--- a/chromium/net/socket/server_socket.h
+++ b/chromium/net/socket/server_socket.h
@@ -5,6 +5,8 @@
#ifndef NET_SOCKET_SERVER_SOCKET_H_
#define NET_SOCKET_SERVER_SOCKET_H_
+#include <string>
+
#include "base/memory/scoped_ptr.h"
#include "net/base/completion_callback.h"
#include "net/base/net_export.h"
@@ -16,17 +18,25 @@ class StreamSocket;
class NET_EXPORT ServerSocket {
public:
- ServerSocket() { }
- virtual ~ServerSocket() { }
+ ServerSocket();
+ virtual ~ServerSocket();
- // Bind the socket and start listening. Destroy the socket to stop
+ // Binds the socket and starts listening. Destroys the socket to stop
// listening.
- virtual int Listen(const net::IPEndPoint& address, int backlog) = 0;
+ virtual int Listen(const IPEndPoint& address, int backlog) = 0;
+
+ // Binds the socket with address and port, and starts listening. It expects
+ // a valid IPv4 or IPv6 address. Otherwise, it returns ERR_ADDRESS_INVALID.
+ // Subclasses may override this function if |address_string| is in a different
+ // format, for example, unix domain socket path.
+ virtual int ListenWithAddressAndPort(const std::string& address_string,
+ int port,
+ int backlog);
// Gets current address the socket is bound to.
virtual int GetLocalAddress(IPEndPoint* address) const = 0;
- // Accept connection. Callback is called when new connection is
+ // Accepts connection. Callback is called when new connection is
// accepted.
virtual int Accept(scoped_ptr<StreamSocket>* socket,
const CompletionCallback& callback) = 0;
diff --git a/chromium/net/socket/socket_libevent.cc b/chromium/net/socket/socket_libevent.cc
new file mode 100644
index 00000000000..5f16c929ab2
--- /dev/null
+++ b/chromium/net/socket/socket_libevent.cc
@@ -0,0 +1,482 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/socket_libevent.h"
+
+#include <errno.h>
+#include <netinet/in.h>
+#include <sys/socket.h>
+
+#include "base/callback_helpers.h"
+#include "base/logging.h"
+#include "base/posix/eintr_wrapper.h"
+#include "net/base/io_buffer.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+
+namespace net {
+
+namespace {
+
+int MapAcceptError(int os_error) {
+ switch (os_error) {
+ // If the client aborts the connection before the server calls accept,
+ // POSIX specifies accept should fail with ECONNABORTED. The server can
+ // ignore the error and just call accept again, so we map the error to
+ // ERR_IO_PENDING. See UNIX Network Programming, Vol. 1, 3rd Ed., Sec.
+ // 5.11, "Connection Abort before accept Returns".
+ case ECONNABORTED:
+ return ERR_IO_PENDING;
+ default:
+ return MapSystemError(os_error);
+ }
+}
+
+int MapConnectError(int os_error) {
+ switch (os_error) {
+ case EINPROGRESS:
+ return ERR_IO_PENDING;
+ case EACCES:
+ return ERR_NETWORK_ACCESS_DENIED;
+ case ETIMEDOUT:
+ return ERR_CONNECTION_TIMED_OUT;
+ default: {
+ int net_error = MapSystemError(os_error);
+ if (net_error == ERR_FAILED)
+ return ERR_CONNECTION_FAILED; // More specific than ERR_FAILED.
+ return net_error;
+ }
+ }
+}
+
+} // namespace
+
+SocketLibevent::SocketLibevent()
+ : socket_fd_(kInvalidSocket),
+ read_buf_len_(0),
+ write_buf_len_(0),
+ waiting_connect_(false) {
+}
+
+SocketLibevent::~SocketLibevent() {
+ Close();
+}
+
+int SocketLibevent::Open(int address_family) {
+ DCHECK(thread_checker_.CalledOnValidThread());
+ DCHECK_EQ(kInvalidSocket, socket_fd_);
+ DCHECK(address_family == AF_INET ||
+ address_family == AF_INET6 ||
+ address_family == AF_UNIX);
+
+ socket_fd_ = CreatePlatformSocket(
+ address_family,
+ SOCK_STREAM,
+ address_family == AF_UNIX ? 0 : IPPROTO_TCP);
+ if (socket_fd_ < 0) {
+ PLOG(ERROR) << "CreatePlatformSocket() returned an error, errno=" << errno;
+ return MapSystemError(errno);
+ }
+
+ if (SetNonBlocking(socket_fd_)) {
+ int rv = MapSystemError(errno);
+ Close();
+ return rv;
+ }
+
+ return OK;
+}
+
+int SocketLibevent::AdoptConnectedSocket(SocketDescriptor socket,
+ const SockaddrStorage& address) {
+ DCHECK(thread_checker_.CalledOnValidThread());
+ DCHECK_EQ(kInvalidSocket, socket_fd_);
+
+ socket_fd_ = socket;
+
+ if (SetNonBlocking(socket_fd_)) {
+ int rv = MapSystemError(errno);
+ Close();
+ return rv;
+ }
+
+ SetPeerAddress(address);
+ return OK;
+}
+
+SocketDescriptor SocketLibevent::ReleaseConnectedSocket() {
+ StopWatchingAndCleanUp();
+ SocketDescriptor socket_fd = socket_fd_;
+ socket_fd_ = kInvalidSocket;
+ return socket_fd;
+}
+
+int SocketLibevent::Bind(const SockaddrStorage& address) {
+ DCHECK(thread_checker_.CalledOnValidThread());
+ DCHECK_NE(kInvalidSocket, socket_fd_);
+
+ int rv = bind(socket_fd_, address.addr, address.addr_len);
+ if (rv < 0) {
+ PLOG(ERROR) << "bind() returned an error, errno=" << errno;
+ return MapSystemError(errno);
+ }
+
+ return OK;
+}
+
+int SocketLibevent::Listen(int backlog) {
+ DCHECK(thread_checker_.CalledOnValidThread());
+ DCHECK_NE(kInvalidSocket, socket_fd_);
+ DCHECK_LT(0, backlog);
+
+ int rv = listen(socket_fd_, backlog);
+ if (rv < 0) {
+ PLOG(ERROR) << "listen() returned an error, errno=" << errno;
+ return MapSystemError(errno);
+ }
+
+ return OK;
+}
+
+int SocketLibevent::Accept(scoped_ptr<SocketLibevent>* socket,
+ const CompletionCallback& callback) {
+ DCHECK(thread_checker_.CalledOnValidThread());
+ DCHECK_NE(kInvalidSocket, socket_fd_);
+ DCHECK(accept_callback_.is_null());
+ DCHECK(socket);
+ DCHECK(!callback.is_null());
+
+ int rv = DoAccept(socket);
+ if (rv != ERR_IO_PENDING)
+ return rv;
+
+ if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
+ socket_fd_, true, base::MessageLoopForIO::WATCH_READ,
+ &accept_socket_watcher_, this)) {
+ PLOG(ERROR) << "WatchFileDescriptor failed on accept, errno " << errno;
+ return MapSystemError(errno);
+ }
+
+ accept_socket_ = socket;
+ accept_callback_ = callback;
+ return ERR_IO_PENDING;
+}
+
+int SocketLibevent::Connect(const SockaddrStorage& address,
+ const CompletionCallback& callback) {
+ DCHECK(thread_checker_.CalledOnValidThread());
+ DCHECK_NE(kInvalidSocket, socket_fd_);
+ DCHECK(!waiting_connect_);
+ DCHECK(!callback.is_null());
+
+ SetPeerAddress(address);
+
+ int rv = DoConnect();
+ if (rv != ERR_IO_PENDING)
+ return rv;
+
+ if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
+ socket_fd_, true, base::MessageLoopForIO::WATCH_WRITE,
+ &write_socket_watcher_, this)) {
+ PLOG(ERROR) << "WatchFileDescriptor failed on connect, errno " << errno;
+ return MapSystemError(errno);
+ }
+
+ write_callback_ = callback;
+ waiting_connect_ = true;
+ return ERR_IO_PENDING;
+}
+
+bool SocketLibevent::IsConnected() const {
+ DCHECK(thread_checker_.CalledOnValidThread());
+
+ if (socket_fd_ == kInvalidSocket || waiting_connect_)
+ return false;
+
+ // Checks if connection is alive.
+ char c;
+ int rv = HANDLE_EINTR(recv(socket_fd_, &c, 1, MSG_PEEK));
+ if (rv == 0)
+ return false;
+ if (rv == -1 && errno != EAGAIN && errno != EWOULDBLOCK)
+ return false;
+
+ return true;
+}
+
+bool SocketLibevent::IsConnectedAndIdle() const {
+ DCHECK(thread_checker_.CalledOnValidThread());
+
+ if (socket_fd_ == kInvalidSocket || waiting_connect_)
+ return false;
+
+ // Check if connection is alive and we haven't received any data
+ // unexpectedly.
+ char c;
+ int rv = HANDLE_EINTR(recv(socket_fd_, &c, 1, MSG_PEEK));
+ if (rv >= 0)
+ return false;
+ if (errno != EAGAIN && errno != EWOULDBLOCK)
+ return false;
+
+ return true;
+}
+
+int SocketLibevent::Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(thread_checker_.CalledOnValidThread());
+ DCHECK_NE(kInvalidSocket, socket_fd_);
+ DCHECK(!waiting_connect_);
+ CHECK(read_callback_.is_null());
+ // Synchronous operation not supported
+ DCHECK(!callback.is_null());
+ DCHECK_LT(0, buf_len);
+
+ int rv = DoRead(buf, buf_len);
+ if (rv != ERR_IO_PENDING)
+ return rv;
+
+ if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
+ socket_fd_, true, base::MessageLoopForIO::WATCH_READ,
+ &read_socket_watcher_, this)) {
+ PLOG(ERROR) << "WatchFileDescriptor failed on read, errno " << errno;
+ return MapSystemError(errno);
+ }
+
+ read_buf_ = buf;
+ read_buf_len_ = buf_len;
+ read_callback_ = callback;
+ return ERR_IO_PENDING;
+}
+
+int SocketLibevent::Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(thread_checker_.CalledOnValidThread());
+ DCHECK_NE(kInvalidSocket, socket_fd_);
+ DCHECK(!waiting_connect_);
+ CHECK(write_callback_.is_null());
+ // Synchronous operation not supported
+ DCHECK(!callback.is_null());
+ DCHECK_LT(0, buf_len);
+
+ int rv = DoWrite(buf, buf_len);
+ if (rv == ERR_IO_PENDING)
+ rv = WaitForWrite(buf, buf_len, callback);
+ return rv;
+}
+
+int SocketLibevent::WaitForWrite(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(thread_checker_.CalledOnValidThread());
+ DCHECK_NE(kInvalidSocket, socket_fd_);
+ DCHECK(write_callback_.is_null());
+ // Synchronous operation not supported
+ DCHECK(!callback.is_null());
+ DCHECK_LT(0, buf_len);
+
+ if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
+ socket_fd_, true, base::MessageLoopForIO::WATCH_WRITE,
+ &write_socket_watcher_, this)) {
+ PLOG(ERROR) << "WatchFileDescriptor failed on write, errno " << errno;
+ return MapSystemError(errno);
+ }
+
+ write_buf_ = buf;
+ write_buf_len_ = buf_len;
+ write_callback_ = callback;
+ return ERR_IO_PENDING;
+}
+
+int SocketLibevent::GetLocalAddress(SockaddrStorage* address) const {
+ DCHECK(thread_checker_.CalledOnValidThread());
+ DCHECK(address);
+
+ if (getsockname(socket_fd_, address->addr, &address->addr_len) < 0)
+ return MapSystemError(errno);
+ return OK;
+}
+
+int SocketLibevent::GetPeerAddress(SockaddrStorage* address) const {
+ DCHECK(thread_checker_.CalledOnValidThread());
+ DCHECK(address);
+
+ if (!HasPeerAddress())
+ return ERR_SOCKET_NOT_CONNECTED;
+
+ *address = *peer_address_;
+ return OK;
+}
+
+void SocketLibevent::SetPeerAddress(const SockaddrStorage& address) {
+ DCHECK(thread_checker_.CalledOnValidThread());
+ // |peer_address_| will be non-NULL if Connect() has been called. Unless
+ // Close() is called to reset the internal state, a second call to Connect()
+ // is not allowed.
+ // Please note that we don't allow a second Connect() even if the previous
+ // Connect() has failed. Connecting the same |socket_| again after a
+ // connection attempt failed results in unspecified behavior according to
+ // POSIX.
+ DCHECK(!peer_address_);
+ peer_address_.reset(new SockaddrStorage(address));
+}
+
+bool SocketLibevent::HasPeerAddress() const {
+ DCHECK(thread_checker_.CalledOnValidThread());
+ return peer_address_ != NULL;
+}
+
+void SocketLibevent::Close() {
+ DCHECK(thread_checker_.CalledOnValidThread());
+
+ StopWatchingAndCleanUp();
+
+ if (socket_fd_ != kInvalidSocket) {
+ if (IGNORE_EINTR(close(socket_fd_)) < 0)
+ PLOG(ERROR) << "close() returned an error, errno=" << errno;
+ socket_fd_ = kInvalidSocket;
+ }
+}
+
+void SocketLibevent::OnFileCanReadWithoutBlocking(int fd) {
+ DCHECK(!accept_callback_.is_null() || !read_callback_.is_null());
+ if (!accept_callback_.is_null()) {
+ AcceptCompleted();
+ } else { // !read_callback_.is_null()
+ ReadCompleted();
+ }
+}
+
+void SocketLibevent::OnFileCanWriteWithoutBlocking(int fd) {
+ DCHECK(!write_callback_.is_null());
+ if (waiting_connect_) {
+ ConnectCompleted();
+ } else {
+ WriteCompleted();
+ }
+}
+
+int SocketLibevent::DoAccept(scoped_ptr<SocketLibevent>* socket) {
+ SockaddrStorage new_peer_address;
+ int new_socket = HANDLE_EINTR(accept(socket_fd_,
+ new_peer_address.addr,
+ &new_peer_address.addr_len));
+ if (new_socket < 0)
+ return MapAcceptError(errno);
+
+ scoped_ptr<SocketLibevent> accepted_socket(new SocketLibevent);
+ int rv = accepted_socket->AdoptConnectedSocket(new_socket, new_peer_address);
+ if (rv != OK)
+ return rv;
+
+ *socket = accepted_socket.Pass();
+ return OK;
+}
+
+void SocketLibevent::AcceptCompleted() {
+ DCHECK(accept_socket_);
+ int rv = DoAccept(accept_socket_);
+ if (rv == ERR_IO_PENDING)
+ return;
+
+ bool ok = accept_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+ accept_socket_ = NULL;
+ base::ResetAndReturn(&accept_callback_).Run(rv);
+}
+
+int SocketLibevent::DoConnect() {
+ int rv = HANDLE_EINTR(connect(socket_fd_,
+ peer_address_->addr,
+ peer_address_->addr_len));
+ DCHECK_GE(0, rv);
+ return rv == 0 ? OK : MapConnectError(errno);
+}
+
+void SocketLibevent::ConnectCompleted() {
+ // Get the error that connect() completed with.
+ int os_error = 0;
+ socklen_t len = sizeof(os_error);
+ if (getsockopt(socket_fd_, SOL_SOCKET, SO_ERROR, &os_error, &len) == 0) {
+ // TCPSocketLibevent expects errno to be set.
+ errno = os_error;
+ }
+
+ int rv = MapConnectError(errno);
+ if (rv == ERR_IO_PENDING)
+ return;
+
+ bool ok = write_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+ waiting_connect_ = false;
+ base::ResetAndReturn(&write_callback_).Run(rv);
+}
+
+int SocketLibevent::DoRead(IOBuffer* buf, int buf_len) {
+ int rv = HANDLE_EINTR(read(socket_fd_, buf->data(), buf_len));
+ return rv >= 0 ? rv : MapSystemError(errno);
+}
+
+void SocketLibevent::ReadCompleted() {
+ int rv = DoRead(read_buf_.get(), read_buf_len_);
+ if (rv == ERR_IO_PENDING)
+ return;
+
+ bool ok = read_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+ read_buf_ = NULL;
+ read_buf_len_ = 0;
+ base::ResetAndReturn(&read_callback_).Run(rv);
+}
+
+int SocketLibevent::DoWrite(IOBuffer* buf, int buf_len) {
+ int rv = HANDLE_EINTR(write(socket_fd_, buf->data(), buf_len));
+ return rv >= 0 ? rv : MapSystemError(errno);
+}
+
+void SocketLibevent::WriteCompleted() {
+ int rv = DoWrite(write_buf_.get(), write_buf_len_);
+ if (rv == ERR_IO_PENDING)
+ return;
+
+ bool ok = write_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+ write_buf_ = NULL;
+ write_buf_len_ = 0;
+ base::ResetAndReturn(&write_callback_).Run(rv);
+}
+
+void SocketLibevent::StopWatchingAndCleanUp() {
+ bool ok = accept_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+ ok = read_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+ ok = write_socket_watcher_.StopWatchingFileDescriptor();
+ DCHECK(ok);
+
+ if (!accept_callback_.is_null()) {
+ accept_socket_ = NULL;
+ accept_callback_.Reset();
+ }
+
+ if (!read_callback_.is_null()) {
+ read_buf_ = NULL;
+ read_buf_len_ = 0;
+ read_callback_.Reset();
+ }
+
+ if (!write_callback_.is_null()) {
+ write_buf_ = NULL;
+ write_buf_len_ = 0;
+ write_callback_.Reset();
+ }
+
+ waiting_connect_ = false;
+ peer_address_.reset();
+}
+
+} // namespace net
diff --git a/chromium/net/socket/socket_libevent.h b/chromium/net/socket/socket_libevent.h
new file mode 100644
index 00000000000..00a0ca6fa7c
--- /dev/null
+++ b/chromium/net/socket/socket_libevent.h
@@ -0,0 +1,132 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_SOCKET_LIBEVENT_H_
+#define NET_SOCKET_SOCKET_LIBEVENT_H_
+
+#include "base/basictypes.h"
+#include "base/compiler_specific.h"
+#include "base/macros.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/message_loop/message_loop.h"
+#include "base/threading/thread_checker.h"
+#include "net/base/completion_callback.h"
+#include "net/base/net_util.h"
+#include "net/socket/socket_descriptor.h"
+
+namespace net {
+
+class IOBuffer;
+class IPEndPoint;
+
+// Socket class to provide asynchronous read/write operations on top of the
+// posix socket api. It supports AF_INET, AF_INET6, and AF_UNIX addresses.
+class NET_EXPORT_PRIVATE SocketLibevent
+ : public base::MessageLoopForIO::Watcher {
+ public:
+ SocketLibevent();
+ ~SocketLibevent() override;
+
+ // Opens a socket and returns net::OK if |address_family| is AF_INET, AF_INET6
+ // or AF_UNIX. Otherwise, it does DCHECK() and returns a net error.
+ int Open(int address_family);
+ // Takes ownership of |socket|.
+ int AdoptConnectedSocket(SocketDescriptor socket,
+ const SockaddrStorage& peer_address);
+ // Releases ownership of |socket_fd_| to caller.
+ SocketDescriptor ReleaseConnectedSocket();
+
+ int Bind(const SockaddrStorage& address);
+
+ int Listen(int backlog);
+ int Accept(scoped_ptr<SocketLibevent>* socket,
+ const CompletionCallback& callback);
+
+ // Connects socket. On non-ERR_IO_PENDING error, sets errno and returns a net
+ // error code. On ERR_IO_PENDING, |callback| is called with a net error code,
+ // not errno, though errno is set if connect event happens with error.
+ // TODO(byungchul): Need more robust way to pass system errno.
+ int Connect(const SockaddrStorage& address,
+ const CompletionCallback& callback);
+ bool IsConnected() const;
+ bool IsConnectedAndIdle() const;
+
+ // Multiple outstanding requests of the same type are not supported.
+ // Full duplex mode (reading and writing at the same time) is supported.
+ // On error which is not ERR_IO_PENDING, sets errno and returns a net error
+ // code. On ERR_IO_PENDING, |callback| is called with a net error code, not
+ // errno, though errno is set if read or write events happen with error.
+ // TODO(byungchul): Need more robust way to pass system errno.
+ int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback);
+ int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback);
+
+ // Waits for next write event. This is called by TCPsocketLibevent for TCP
+ // fastopen after sending first data. Returns ERR_IO_PENDING if it starts
+ // waiting for write event successfully. Otherwise, returns a net error code.
+ // It must not be called after Write() because Write() calls it internally.
+ int WaitForWrite(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback);
+
+ int GetLocalAddress(SockaddrStorage* address) const;
+ int GetPeerAddress(SockaddrStorage* address) const;
+ void SetPeerAddress(const SockaddrStorage& address);
+ // Returns true if peer address has been set regardless of socket state.
+ bool HasPeerAddress() const;
+
+ void Close();
+
+ SocketDescriptor socket_fd() const { return socket_fd_; }
+
+ private:
+ // base::MessageLoopForIO::Watcher methods.
+ void OnFileCanReadWithoutBlocking(int fd) override;
+ void OnFileCanWriteWithoutBlocking(int fd) override;
+
+ int DoAccept(scoped_ptr<SocketLibevent>* socket);
+ void AcceptCompleted();
+
+ int DoConnect();
+ void ConnectCompleted();
+
+ int DoRead(IOBuffer* buf, int buf_len);
+ void ReadCompleted();
+
+ int DoWrite(IOBuffer* buf, int buf_len);
+ void WriteCompleted();
+
+ void StopWatchingAndCleanUp();
+
+ SocketDescriptor socket_fd_;
+
+ base::MessageLoopForIO::FileDescriptorWatcher accept_socket_watcher_;
+ scoped_ptr<SocketLibevent>* accept_socket_;
+ CompletionCallback accept_callback_;
+
+ base::MessageLoopForIO::FileDescriptorWatcher read_socket_watcher_;
+ scoped_refptr<IOBuffer> read_buf_;
+ int read_buf_len_;
+ // External callback; called when read is complete.
+ CompletionCallback read_callback_;
+
+ base::MessageLoopForIO::FileDescriptorWatcher write_socket_watcher_;
+ scoped_refptr<IOBuffer> write_buf_;
+ int write_buf_len_;
+ // External callback; called when write or connect is complete.
+ CompletionCallback write_callback_;
+
+ // A connect operation is pending. In this case, |write_callback_| needs to be
+ // called when connect is complete.
+ bool waiting_connect_;
+
+ scoped_ptr<SockaddrStorage> peer_address_;
+
+ base::ThreadChecker thread_checker_;
+
+ DISALLOW_COPY_AND_ASSIGN(SocketLibevent);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_SOCKET_LIBEVENT_H_
diff --git a/chromium/net/socket/socket_test_util.cc b/chromium/net/socket/socket_test_util.cc
index f993801adeb..5ae4eeeedb8 100644
--- a/chromium/net/socket/socket_test_util.cc
+++ b/chromium/net/socket/socket_test_util.cc
@@ -10,6 +10,7 @@
#include "base/basictypes.h"
#include "base/bind.h"
#include "base/bind_helpers.h"
+#include "base/callback_helpers.h"
#include "base/compiler_specific.h"
#include "base/message_loop/message_loop.h"
#include "base/run_loop.h"
@@ -277,7 +278,9 @@ SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result)
client_cert_sent(false),
cert_request_info(NULL),
channel_id_sent(false),
- connection_status(0) {
+ connection_status(0),
+ should_pause_on_connect(false),
+ is_in_session_cache(false) {
SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_2,
&connection_status);
// Set to TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305
@@ -680,7 +683,7 @@ MockClientSocketFactory::CreateDatagramClientSocket(
data_provider->set_socket(socket.get());
if (bind_type == DatagramSocket::RANDOM_BIND)
socket->set_source_port(rand_int_cb.Run(1025, 65535));
- return socket.PassAs<DatagramClientSocket>();
+ return socket.Pass();
}
scoped_ptr<StreamSocket> MockClientSocketFactory::CreateTransportClientSocket(
@@ -691,7 +694,7 @@ scoped_ptr<StreamSocket> MockClientSocketFactory::CreateTransportClientSocket(
scoped_ptr<MockTCPClientSocket> socket(
new MockTCPClientSocket(addresses, net_log, data_provider));
data_provider->set_socket(socket.get());
- return socket.PassAs<StreamSocket>();
+ return socket.Pass();
}
scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket(
@@ -699,10 +702,13 @@ scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket(
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context) {
- return scoped_ptr<SSLClientSocket>(
+ scoped_ptr<MockSSLClientSocket> socket(
new MockSSLClientSocket(transport_socket.Pass(),
- host_and_port, ssl_config,
+ host_and_port,
+ ssl_config,
mock_ssl_data_.GetNext()));
+ ssl_client_sockets_.push_back(socket.get());
+ return socket.Pass();
}
void MockClientSocketFactory::ClearSSLSessionCache() {
@@ -758,6 +764,20 @@ const BoundNetLog& MockClientSocket::NetLog() const {
return net_log_;
}
+std::string MockClientSocket::GetSessionCacheKey() const {
+ NOTIMPLEMENTED();
+ return std::string();
+}
+
+bool MockClientSocket::InSessionCache() const {
+ NOTIMPLEMENTED();
+ return false;
+}
+
+void MockClientSocket::SetHandshakeCompletionCallback(const base::Closure& cb) {
+ NOTIMPLEMENTED();
+}
+
void MockClientSocket::GetSSLCertRequestInfo(
SSLCertRequestInfo* cert_request_info) {
}
@@ -776,15 +796,14 @@ int MockClientSocket::GetTLSUniqueChannelBinding(std::string* out) {
return OK;
}
-ServerBoundCertService* MockClientSocket::GetServerBoundCertService() const {
+ChannelIDService* MockClientSocket::GetChannelIDService() const {
NOTREACHED();
return NULL;
}
SSLClientSocket::NextProtoStatus
-MockClientSocket::GetNextProto(std::string* proto, std::string* server_protos) {
+MockClientSocket::GetNextProto(std::string* proto) {
proto->clear();
- server_protos->clear();
return SSLClientSocket::kNextProtoUnsupported;
}
@@ -838,7 +857,7 @@ int MockTCPClientSocket::Read(IOBuffer* buf, int buf_len,
return ERR_UNEXPECTED;
// If the buffer is already in use, a read is already in progress!
- DCHECK(pending_buf_ == NULL);
+ DCHECK(pending_buf_.get() == NULL);
// Store our async IO data.
pending_buf_ = buf;
@@ -946,7 +965,7 @@ bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
// There must be a read pending.
- DCHECK(pending_buf_);
+ DCHECK(pending_buf_.get());
// You can't complete a read with another ERR_IO_PENDING status code.
DCHECK_NE(ERR_IO_PENDING, data.result);
// Since we've been waiting for data, need_read_data_ should be true.
@@ -970,7 +989,7 @@ void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) {
}
int MockTCPClientSocket::CompleteRead() {
- DCHECK(pending_buf_);
+ DCHECK(pending_buf_.get());
DCHECK(pending_buf_len_ > 0);
was_used_to_convey_data_ = true;
@@ -1298,31 +1317,25 @@ void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {}
void DeterministicMockTCPClientSocket::OnConnectComplete(
const MockConnect& data) {}
-// static
-void MockSSLClientSocket::ConnectCallback(
- MockSSLClientSocket* ssl_client_socket,
- const CompletionCallback& callback,
- int rv) {
- if (rv == OK)
- ssl_client_socket->connected_ = true;
- callback.Run(rv);
-}
-
MockSSLClientSocket::MockSSLClientSocket(
scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_port_pair,
const SSLConfig& ssl_config,
SSLSocketDataProvider* data)
: MockClientSocket(
- // Have to use the right BoundNetLog for LoadTimingInfo regression
- // tests.
- transport_socket->socket()->NetLog()),
+ // Have to use the right BoundNetLog for LoadTimingInfo regression
+ // tests.
+ transport_socket->socket()->NetLog()),
transport_(transport_socket.Pass()),
+ host_port_pair_(host_port_pair),
data_(data),
is_npn_state_set_(false),
new_npn_value_(false),
is_protocol_negotiated_set_(false),
- protocol_negotiated_(kProtoUnknown) {
+ protocol_negotiated_(kProtoUnknown),
+ next_connect_state_(STATE_NONE),
+ reached_connect_(false),
+ weak_factory_(this) {
DCHECK(data_);
peer_addr_ = data->connect.peer_addr;
}
@@ -1342,28 +1355,23 @@ int MockSSLClientSocket::Write(IOBuffer* buf, int buf_len,
}
int MockSSLClientSocket::Connect(const CompletionCallback& callback) {
- int rv = transport_->socket()->Connect(
- base::Bind(&ConnectCallback, base::Unretained(this), callback));
- if (rv == OK) {
- if (data_->connect.result == OK)
- connected_ = true;
- if (data_->connect.mode == ASYNC) {
- RunCallbackAsync(callback, data_->connect.result);
- return ERR_IO_PENDING;
- }
- return data_->connect.result;
- }
+ next_connect_state_ = STATE_SSL_CONNECT;
+ reached_connect_ = true;
+ int rv = DoConnectLoop(OK);
+ if (rv == ERR_IO_PENDING)
+ connect_callback_ = callback;
return rv;
}
void MockSSLClientSocket::Disconnect() {
+ weak_factory_.InvalidateWeakPtrs();
MockClientSocket::Disconnect();
if (transport_->socket() != NULL)
transport_->socket()->Disconnect();
}
bool MockSSLClientSocket::IsConnected() const {
- return transport_->socket()->IsConnected();
+ return transport_->socket()->IsConnected() && connected_;
}
bool MockSSLClientSocket::WasEverUsed() const {
@@ -1387,6 +1395,21 @@ bool MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
return true;
}
+std::string MockSSLClientSocket::GetSessionCacheKey() const {
+ // For the purposes of these tests, |host_and_port| will serve as the
+ // cache key.
+ return host_port_pair_.ToString();
+}
+
+bool MockSSLClientSocket::InSessionCache() const {
+ return data_->is_in_session_cache;
+}
+
+void MockSSLClientSocket::SetHandshakeCompletionCallback(
+ const base::Closure& cb) {
+ handshake_completion_callback_ = cb;
+}
+
void MockSSLClientSocket::GetSSLCertRequestInfo(
SSLCertRequestInfo* cert_request_info) {
DCHECK(cert_request_info);
@@ -1400,9 +1423,8 @@ void MockSSLClientSocket::GetSSLCertRequestInfo(
}
SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto(
- std::string* proto, std::string* server_protos) {
+ std::string* proto) {
*proto = data_->next_proto;
- *server_protos = data_->server_protos;
return data_->next_proto_status;
}
@@ -1437,8 +1459,8 @@ void MockSSLClientSocket::set_channel_id_sent(bool channel_id_sent) {
data_->channel_id_sent = channel_id_sent;
}
-ServerBoundCertService* MockSSLClientSocket::GetServerBoundCertService() const {
- return data_->server_bound_cert_service;
+ChannelIDService* MockSSLClientSocket::GetChannelIDService() const {
+ return data_->channel_id_service;
}
void MockSSLClientSocket::OnReadComplete(const MockRead& data) {
@@ -1449,6 +1471,69 @@ void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) {
NOTIMPLEMENTED();
}
+void MockSSLClientSocket::RestartPausedConnect() {
+ DCHECK(data_->should_pause_on_connect);
+ DCHECK_EQ(next_connect_state_, STATE_SSL_CONNECT_COMPLETE);
+ OnIOComplete(data_->connect.result);
+}
+
+void MockSSLClientSocket::OnIOComplete(int result) {
+ int rv = DoConnectLoop(result);
+ if (rv != ERR_IO_PENDING)
+ base::ResetAndReturn(&connect_callback_).Run(rv);
+}
+
+int MockSSLClientSocket::DoConnectLoop(int result) {
+ DCHECK_NE(next_connect_state_, STATE_NONE);
+
+ int rv = result;
+ do {
+ ConnectState state = next_connect_state_;
+ next_connect_state_ = STATE_NONE;
+ switch (state) {
+ case STATE_SSL_CONNECT:
+ rv = DoSSLConnect();
+ break;
+ case STATE_SSL_CONNECT_COMPLETE:
+ rv = DoSSLConnectComplete(rv);
+ break;
+ default:
+ NOTREACHED() << "bad state";
+ rv = ERR_UNEXPECTED;
+ break;
+ }
+ } while (rv != ERR_IO_PENDING && next_connect_state_ != STATE_NONE);
+
+ return rv;
+}
+
+int MockSSLClientSocket::DoSSLConnect() {
+ next_connect_state_ = STATE_SSL_CONNECT_COMPLETE;
+
+ if (data_->should_pause_on_connect)
+ return ERR_IO_PENDING;
+
+ if (data_->connect.mode == ASYNC) {
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE,
+ base::Bind(&MockSSLClientSocket::OnIOComplete,
+ weak_factory_.GetWeakPtr(),
+ data_->connect.result));
+ return ERR_IO_PENDING;
+ }
+
+ return data_->connect.result;
+}
+
+int MockSSLClientSocket::DoSSLConnectComplete(int result) {
+ if (result == OK)
+ connected_ = true;
+
+ if (!handshake_completion_callback_.is_null())
+ base::ResetAndReturn(&handshake_completion_callback_).Run();
+ return result;
+}
+
MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data,
net::NetLog* net_log)
: connected_(false),
@@ -1475,7 +1560,7 @@ int MockUDPClientSocket::Read(IOBuffer* buf,
return ERR_UNEXPECTED;
// If the buffer is already in use, a read is already in progress!
- DCHECK(pending_buf_ == NULL);
+ DCHECK(pending_buf_.get() == NULL);
// Store our async IO data.
pending_buf_ = buf;
@@ -1552,7 +1637,7 @@ int MockUDPClientSocket::Connect(const IPEndPoint& address) {
void MockUDPClientSocket::OnReadComplete(const MockRead& data) {
// There must be a read pending.
- DCHECK(pending_buf_);
+ DCHECK(pending_buf_.get());
// You can't complete a read with another ERR_IO_PENDING status code.
DCHECK_NE(ERR_IO_PENDING, data.result);
// Since we've been waiting for data, need_read_data_ should be true.
@@ -1575,7 +1660,7 @@ void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) {
}
int MockUDPClientSocket::CompleteRead() {
- DCHECK(pending_buf_);
+ DCHECK(pending_buf_.get());
DCHECK(pending_buf_len_ > 0);
// Save the pending async IO data and reset our |pending_| state.
@@ -1833,7 +1918,7 @@ DeterministicMockClientSocketFactory::CreateDatagramClientSocket(
udp_client_sockets().push_back(socket.get());
if (bind_type == DatagramSocket::RANDOM_BIND)
socket->set_source_port(rand_int_cb.Run(1025, 65535));
- return socket.PassAs<DatagramClientSocket>();
+ return socket.Pass();
}
scoped_ptr<StreamSocket>
@@ -1846,7 +1931,7 @@ DeterministicMockClientSocketFactory::CreateTransportClientSocket(
new DeterministicMockTCPClientSocket(net_log, data_provider));
data_provider->set_delegate(socket->AsWeakPtr());
tcp_client_sockets().push_back(socket.get());
- return socket.PassAs<StreamSocket>();
+ return socket.Pass();
}
scoped_ptr<SSLClientSocket>
@@ -1860,7 +1945,7 @@ DeterministicMockClientSocketFactory::CreateSSLClientSocket(
host_and_port, ssl_config,
mock_ssl_data_.GetNext()));
ssl_client_sockets_.push_back(socket.get());
- return socket.PassAs<SSLClientSocket>();
+ return socket.Pass();
}
void DeterministicMockClientSocketFactory::ClearSSLSessionCache() {
diff --git a/chromium/net/socket/socket_test_util.h b/chromium/net/socket/socket_test_util.h
index 2918aad2dc5..7bccdaed727 100644
--- a/chromium/net/socket/socket_test_util.h
+++ b/chromium/net/socket/socket_test_util.h
@@ -47,8 +47,8 @@ enum {
};
class AsyncSocket;
+class ChannelIDService;
class MockClientSocket;
-class ServerBoundCertService;
class SSLClientSocket;
class StreamSocket;
@@ -243,7 +243,7 @@ class StaticSocketDataProvider : public SocketDataProvider {
size_t reads_count,
MockWrite* writes,
size_t writes_count);
- virtual ~StaticSocketDataProvider();
+ ~StaticSocketDataProvider() override;
// These functions get access to the next available read and write data.
const MockRead& PeekRead() const;
@@ -262,9 +262,9 @@ class StaticSocketDataProvider : public SocketDataProvider {
virtual void CompleteRead() {}
// SocketDataProvider implementation.
- virtual MockRead GetNextRead() OVERRIDE;
- virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
- virtual void Reset() OVERRIDE;
+ MockRead GetNextRead() override;
+ MockWriteResult OnWrite(const std::string& data) override;
+ void Reset() override;
private:
MockRead* reads_;
@@ -284,7 +284,7 @@ class StaticSocketDataProvider : public SocketDataProvider {
class DynamicSocketDataProvider : public SocketDataProvider {
public:
DynamicSocketDataProvider();
- virtual ~DynamicSocketDataProvider();
+ ~DynamicSocketDataProvider() override;
int short_read_limit() const { return short_read_limit_; }
void set_short_read_limit(int limit) { short_read_limit_ = limit; }
@@ -292,9 +292,9 @@ class DynamicSocketDataProvider : public SocketDataProvider {
void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; }
// SocketDataProvider implementation.
- virtual MockRead GetNextRead() OVERRIDE;
+ MockRead GetNextRead() override;
virtual MockWriteResult OnWrite(const std::string& data) = 0;
- virtual void Reset() OVERRIDE;
+ void Reset() override;
protected:
// The next time there is a read from this socket, it will return |data|.
@@ -326,15 +326,20 @@ struct SSLSocketDataProvider {
MockConnect connect;
SSLClientSocket::NextProtoStatus next_proto_status;
std::string next_proto;
- std::string server_protos;
bool was_npn_negotiated;
NextProto protocol_negotiated;
bool client_cert_sent;
SSLCertRequestInfo* cert_request_info;
scoped_refptr<X509Certificate> cert;
bool channel_id_sent;
- ServerBoundCertService* server_bound_cert_service;
+ ChannelIDService* channel_id_service;
int connection_status;
+ // Indicates that the socket should pause in the Connect method.
+ bool should_pause_on_connect;
+ // Whether or not the Socket should behave like there is a pre-existing
+ // session to resume. Whether or not such a session is reported as
+ // resumed is controlled by |connection_status|.
+ bool is_in_session_cache;
};
// A DataProvider where the client must write a request before the reads (e.g.
@@ -366,15 +371,15 @@ class DelayedSocketData : public StaticSocketDataProvider {
size_t reads_count,
MockWrite* writes,
size_t writes_count);
- virtual ~DelayedSocketData();
+ ~DelayedSocketData() override;
void ForceNextRead();
// StaticSocketDataProvider:
- virtual MockRead GetNextRead() OVERRIDE;
- virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
- virtual void Reset() OVERRIDE;
- virtual void CompleteRead() OVERRIDE;
+ MockRead GetNextRead() override;
+ MockWriteResult OnWrite(const std::string& data) override;
+ void Reset() override;
+ void CompleteRead() override;
private:
int write_delay_;
@@ -407,7 +412,7 @@ class OrderedSocketData : public StaticSocketDataProvider {
size_t reads_count,
MockWrite* writes,
size_t writes_count);
- virtual ~OrderedSocketData();
+ ~OrderedSocketData() override;
// |connect| the result for the connect phase.
// |reads| the list of MockRead completions.
@@ -425,10 +430,10 @@ class OrderedSocketData : public StaticSocketDataProvider {
void EndLoop();
// StaticSocketDataProvider:
- virtual MockRead GetNextRead() OVERRIDE;
- virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
- virtual void Reset() OVERRIDE;
- virtual void CompleteRead() OVERRIDE;
+ MockRead GetNextRead() override;
+ MockWriteResult OnWrite(const std::string& data) override;
+ void Reset() override;
+ void CompleteRead() override;
private:
int sequence_number_;
@@ -531,7 +536,7 @@ class DeterministicSocketData : public StaticSocketDataProvider {
size_t reads_count,
MockWrite* writes,
size_t writes_count);
- virtual ~DeterministicSocketData();
+ ~DeterministicSocketData() override;
// Consume all the data up to the give stop point (via SetStop()).
void Run();
@@ -555,14 +560,14 @@ class DeterministicSocketData : public StaticSocketDataProvider {
// When the socket calls Read(), that calls GetNextRead(), and expects either
// ERR_IO_PENDING or data.
- virtual MockRead GetNextRead() OVERRIDE;
+ MockRead GetNextRead() override;
// When the socket calls Write(), it always completes synchronously. OnWrite()
// checks to make sure the written data matches the expected data. The
// callback will not be invoked until its sequence number is reached.
- virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
- virtual void Reset() OVERRIDE;
- virtual void CompleteRead() OVERRIDE {}
+ MockWriteResult OnWrite(const std::string& data) override;
+ void Reset() override;
+ void CompleteRead() override {}
private:
// Invoke the read and write callbacks, if the timing is appropriate.
@@ -628,7 +633,7 @@ class MockSSLClientSocket;
class MockClientSocketFactory : public ClientSocketFactory {
public:
MockClientSocketFactory();
- virtual ~MockClientSocketFactory();
+ ~MockClientSocketFactory() override;
void AddSocketDataProvider(SocketDataProvider* socket);
void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
@@ -638,26 +643,33 @@ class MockClientSocketFactory : public ClientSocketFactory {
return mock_data_;
}
+ // Note: this method is unsafe; the elements of the returned vector
+ // are not necessarily valid.
+ const std::vector<MockSSLClientSocket*>& ssl_client_sockets() const {
+ return ssl_client_sockets_;
+ }
+
// ClientSocketFactory
- virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
+ scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
- const NetLog::Source& source) OVERRIDE;
- virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
+ const NetLog::Source& source) override;
+ scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* net_log,
- const NetLog::Source& source) OVERRIDE;
- virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ const NetLog::Source& source) override;
+ scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
- const SSLClientSocketContext& context) OVERRIDE;
- virtual void ClearSSLSessionCache() OVERRIDE;
+ const SSLClientSocketContext& context) override;
+ void ClearSSLSessionCache() override;
private:
SocketDataProviderArray<SocketDataProvider> mock_data_;
SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
+ std::vector<MockSSLClientSocket*> ssl_client_sockets_;
};
class MockClientSocket : public SSLClientSocket {
@@ -676,41 +688,42 @@ class MockClientSocket : public SSLClientSocket {
virtual int Write(IOBuffer* buf,
int buf_len,
const CompletionCallback& callback) = 0;
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE;
- virtual int SetSendBufferSize(int32 size) OVERRIDE;
+ int SetReceiveBufferSize(int32 size) override;
+ int SetSendBufferSize(int32 size) override;
// StreamSocket implementation.
virtual int Connect(const CompletionCallback& callback) = 0;
- virtual void Disconnect() OVERRIDE;
- virtual bool IsConnected() const OVERRIDE;
- virtual bool IsConnectedAndIdle() const OVERRIDE;
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
- virtual const BoundNetLog& NetLog() const OVERRIDE;
- virtual void SetSubresourceSpeculation() OVERRIDE {}
- virtual void SetOmniboxSpeculation() OVERRIDE {}
+ void Disconnect() override;
+ bool IsConnected() const override;
+ bool IsConnectedAndIdle() const override;
+ int GetPeerAddress(IPEndPoint* address) const override;
+ int GetLocalAddress(IPEndPoint* address) const override;
+ const BoundNetLog& NetLog() const override;
+ void SetSubresourceSpeculation() override {}
+ void SetOmniboxSpeculation() override {}
// SSLClientSocket implementation.
- virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info)
- OVERRIDE;
- virtual int ExportKeyingMaterial(const base::StringPiece& label,
- bool has_context,
- const base::StringPiece& context,
- unsigned char* out,
- unsigned int outlen) OVERRIDE;
- virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE;
- virtual NextProtoStatus GetNextProto(std::string* proto,
- std::string* server_protos) OVERRIDE;
- virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE;
+ std::string GetSessionCacheKey() const override;
+ bool InSessionCache() const override;
+ void SetHandshakeCompletionCallback(const base::Closure& cb) override;
+ void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override;
+ int ExportKeyingMaterial(const base::StringPiece& label,
+ bool has_context,
+ const base::StringPiece& context,
+ unsigned char* out,
+ unsigned int outlen) override;
+ int GetTLSUniqueChannelBinding(std::string* out) override;
+ NextProtoStatus GetNextProto(std::string* proto) override;
+ ChannelIDService* GetChannelIDService() const override;
protected:
- virtual ~MockClientSocket();
+ ~MockClientSocket() override;
void RunCallbackAsync(const CompletionCallback& callback, int result);
void RunCallback(const CompletionCallback& callback, int result);
// SSLClientSocket implementation.
- virtual scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain()
- const OVERRIDE;
+ scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain()
+ const override;
// True if Connect completed successfully and Disconnect hasn't been called.
bool connected_;
@@ -720,6 +733,7 @@ class MockClientSocket : public SSLClientSocket {
BoundNetLog net_log_;
+ private:
base::WeakPtrFactory<MockClientSocket> weak_factory_;
DISALLOW_COPY_AND_ASSIGN(MockClientSocket);
@@ -730,32 +744,32 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket {
MockTCPClientSocket(const AddressList& addresses,
net::NetLog* net_log,
SocketDataProvider* socket);
- virtual ~MockTCPClientSocket();
+ ~MockTCPClientSocket() override;
const AddressList& addresses() const { return addresses_; }
// Socket implementation.
- virtual int Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int Write(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
// StreamSocket implementation.
- virtual int Connect(const CompletionCallback& callback) OVERRIDE;
- virtual void Disconnect() OVERRIDE;
- virtual bool IsConnected() const OVERRIDE;
- virtual bool IsConnectedAndIdle() const OVERRIDE;
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
- virtual bool WasEverUsed() const OVERRIDE;
- virtual bool UsingTCPFastOpen() const OVERRIDE;
- virtual bool WasNpnNegotiated() const OVERRIDE;
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+ int Connect(const CompletionCallback& callback) override;
+ void Disconnect() override;
+ bool IsConnected() const override;
+ bool IsConnectedAndIdle() const override;
+ int GetPeerAddress(IPEndPoint* address) const override;
+ bool WasEverUsed() const override;
+ bool UsingTCPFastOpen() const override;
+ bool WasNpnNegotiated() const override;
+ bool GetSSLInfo(SSLInfo* ssl_info) override;
// AsyncSocket:
- virtual void OnReadComplete(const MockRead& data) OVERRIDE;
- virtual void OnConnectComplete(const MockConnect& data) OVERRIDE;
+ void OnReadComplete(const MockRead& data) override;
+ void OnConnectComplete(const MockConnect& data) override;
private:
int CompleteRead();
@@ -836,36 +850,36 @@ class DeterministicMockUDPClientSocket
public:
DeterministicMockUDPClientSocket(net::NetLog* net_log,
DeterministicSocketData* data);
- virtual ~DeterministicMockUDPClientSocket();
+ ~DeterministicMockUDPClientSocket() override;
// DeterministicSocketData::Delegate:
- virtual bool WritePending() const OVERRIDE;
- virtual bool ReadPending() const OVERRIDE;
- virtual void CompleteWrite() OVERRIDE;
- virtual int CompleteRead() OVERRIDE;
+ bool WritePending() const override;
+ bool ReadPending() const override;
+ void CompleteWrite() override;
+ int CompleteRead() override;
// Socket implementation.
- virtual int Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int Write(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE;
- virtual int SetSendBufferSize(int32 size) OVERRIDE;
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int SetReceiveBufferSize(int32 size) override;
+ int SetSendBufferSize(int32 size) override;
// DatagramSocket implementation.
- virtual void Close() OVERRIDE;
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
- virtual const BoundNetLog& NetLog() const OVERRIDE;
+ void Close() override;
+ int GetPeerAddress(IPEndPoint* address) const override;
+ int GetLocalAddress(IPEndPoint* address) const override;
+ const BoundNetLog& NetLog() const override;
// DatagramClientSocket implementation.
- virtual int Connect(const IPEndPoint& address) OVERRIDE;
+ int Connect(const IPEndPoint& address) override;
// AsyncSocket implementation.
- virtual void OnReadComplete(const MockRead& data) OVERRIDE;
- virtual void OnConnectComplete(const MockConnect& data) OVERRIDE;
+ void OnReadComplete(const MockRead& data) override;
+ void OnConnectComplete(const MockConnect& data) override;
void set_source_port(int port) { source_port_ = port; }
@@ -887,35 +901,35 @@ class DeterministicMockTCPClientSocket
public:
DeterministicMockTCPClientSocket(net::NetLog* net_log,
DeterministicSocketData* data);
- virtual ~DeterministicMockTCPClientSocket();
+ ~DeterministicMockTCPClientSocket() override;
// DeterministicSocketData::Delegate:
- virtual bool WritePending() const OVERRIDE;
- virtual bool ReadPending() const OVERRIDE;
- virtual void CompleteWrite() OVERRIDE;
- virtual int CompleteRead() OVERRIDE;
+ bool WritePending() const override;
+ bool ReadPending() const override;
+ void CompleteWrite() override;
+ int CompleteRead() override;
// Socket:
- virtual int Write(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
// StreamSocket:
- virtual int Connect(const CompletionCallback& callback) OVERRIDE;
- virtual void Disconnect() OVERRIDE;
- virtual bool IsConnected() const OVERRIDE;
- virtual bool IsConnectedAndIdle() const OVERRIDE;
- virtual bool WasEverUsed() const OVERRIDE;
- virtual bool UsingTCPFastOpen() const OVERRIDE;
- virtual bool WasNpnNegotiated() const OVERRIDE;
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+ int Connect(const CompletionCallback& callback) override;
+ void Disconnect() override;
+ bool IsConnected() const override;
+ bool IsConnectedAndIdle() const override;
+ bool WasEverUsed() const override;
+ bool UsingTCPFastOpen() const override;
+ bool WasNpnNegotiated() const override;
+ bool GetSSLInfo(SSLInfo* ssl_info) override;
// AsyncSocket:
- virtual void OnReadComplete(const MockRead& data) OVERRIDE;
- virtual void OnConnectComplete(const MockConnect& data) OVERRIDE;
+ void OnReadComplete(const MockRead& data) override;
+ void OnConnectComplete(const MockConnect& data) override;
private:
DeterministicSocketHelper helper_;
@@ -929,85 +943,113 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket {
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
SSLSocketDataProvider* socket);
- virtual ~MockSSLClientSocket();
+ ~MockSSLClientSocket() override;
// Socket implementation.
- virtual int Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int Write(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
// StreamSocket implementation.
- virtual int Connect(const CompletionCallback& callback) OVERRIDE;
- virtual void Disconnect() OVERRIDE;
- virtual bool IsConnected() const OVERRIDE;
- virtual bool WasEverUsed() const OVERRIDE;
- virtual bool UsingTCPFastOpen() const OVERRIDE;
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
- virtual bool WasNpnNegotiated() const OVERRIDE;
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+ int Connect(const CompletionCallback& callback) override;
+ void Disconnect() override;
+ bool IsConnected() const override;
+ bool WasEverUsed() const override;
+ bool UsingTCPFastOpen() const override;
+ int GetPeerAddress(IPEndPoint* address) const override;
+ bool WasNpnNegotiated() const override;
+ bool GetSSLInfo(SSLInfo* ssl_info) override;
// SSLClientSocket implementation.
- virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info)
- OVERRIDE;
- virtual NextProtoStatus GetNextProto(std::string* proto,
- std::string* server_protos) OVERRIDE;
- virtual bool set_was_npn_negotiated(bool negotiated) OVERRIDE;
- virtual void set_protocol_negotiated(NextProto protocol_negotiated) OVERRIDE;
- virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
+ std::string GetSessionCacheKey() const override;
+ bool InSessionCache() const override;
+ void SetHandshakeCompletionCallback(const base::Closure& cb) override;
+ void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override;
+ NextProtoStatus GetNextProto(std::string* proto) override;
+ bool set_was_npn_negotiated(bool negotiated) override;
+ void set_protocol_negotiated(NextProto protocol_negotiated) override;
+ NextProto GetNegotiatedProtocol() const override;
// This MockSocket does not implement the manual async IO feature.
- virtual void OnReadComplete(const MockRead& data) OVERRIDE;
- virtual void OnConnectComplete(const MockConnect& data) OVERRIDE;
+ void OnReadComplete(const MockRead& data) override;
+ void OnConnectComplete(const MockConnect& data) override;
+
+ bool WasChannelIDSent() const override;
+ void set_channel_id_sent(bool channel_id_sent) override;
+ ChannelIDService* GetChannelIDService() const override;
- virtual bool WasChannelIDSent() const OVERRIDE;
- virtual void set_channel_id_sent(bool channel_id_sent) OVERRIDE;
- virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE;
+ bool reached_connect() const { return reached_connect_; }
+
+ // Resumes the connection of a socket that was paused for testing.
+ // |connect_callback_| should be set before invoking this method.
+ void RestartPausedConnect();
private:
- static void ConnectCallback(MockSSLClientSocket* ssl_client_socket,
- const CompletionCallback& callback,
- int rv);
+ enum ConnectState {
+ STATE_NONE,
+ STATE_SSL_CONNECT,
+ STATE_SSL_CONNECT_COMPLETE,
+ };
+
+ void OnIOComplete(int result);
+
+ // Runs the state transistion loop.
+ int DoConnectLoop(int result);
+
+ int DoSSLConnect();
+ int DoSSLConnectComplete(int result);
scoped_ptr<ClientSocketHandle> transport_;
+ HostPortPair host_port_pair_;
SSLSocketDataProvider* data_;
bool is_npn_state_set_;
bool new_npn_value_;
bool is_protocol_negotiated_set_;
NextProto protocol_negotiated_;
+ CompletionCallback connect_callback_;
+ // Indicates what state of Connect the socket should enter.
+ ConnectState next_connect_state_;
+ // True if the Connect method has been called on the socket.
+ bool reached_connect_;
+
+ base::Closure handshake_completion_callback_;
+
+ base::WeakPtrFactory<MockSSLClientSocket> weak_factory_;
+
DISALLOW_COPY_AND_ASSIGN(MockSSLClientSocket);
};
class MockUDPClientSocket : public DatagramClientSocket, public AsyncSocket {
public:
MockUDPClientSocket(SocketDataProvider* data, net::NetLog* net_log);
- virtual ~MockUDPClientSocket();
+ ~MockUDPClientSocket() override;
// Socket implementation.
- virtual int Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int Write(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE;
- virtual int SetSendBufferSize(int32 size) OVERRIDE;
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int SetReceiveBufferSize(int32 size) override;
+ int SetSendBufferSize(int32 size) override;
// DatagramSocket implementation.
- virtual void Close() OVERRIDE;
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
- virtual const BoundNetLog& NetLog() const OVERRIDE;
+ void Close() override;
+ int GetPeerAddress(IPEndPoint* address) const override;
+ int GetLocalAddress(IPEndPoint* address) const override;
+ const BoundNetLog& NetLog() const override;
// DatagramClientSocket implementation.
- virtual int Connect(const IPEndPoint& address) OVERRIDE;
+ int Connect(const IPEndPoint& address) override;
// AsyncSocket implementation.
- virtual void OnReadComplete(const MockRead& data) OVERRIDE;
- virtual void OnConnectComplete(const MockConnect& data) OVERRIDE;
+ void OnReadComplete(const MockRead& data) override;
+ void OnConnectComplete(const MockConnect& data) override;
void set_source_port(int port) { source_port_ = port;}
@@ -1043,7 +1085,7 @@ class TestSocketRequest : public TestCompletionCallbackBase {
public:
TestSocketRequest(std::vector<TestSocketRequest*>* request_order,
size_t* completion_count);
- virtual ~TestSocketRequest();
+ ~TestSocketRequest() override;
ClientSocketHandle* handle() { return &handle_; }
@@ -1163,7 +1205,7 @@ class MockTransportClientSocketPool : public TransportClientSocketPool {
ClientSocketPoolHistograms* histograms,
ClientSocketFactory* socket_factory);
- virtual ~MockTransportClientSocketPool();
+ ~MockTransportClientSocketPool() override;
RequestPriority last_request_priority() const {
return last_request_priority_;
@@ -1172,18 +1214,18 @@ class MockTransportClientSocketPool : public TransportClientSocketPool {
int cancel_count() const { return cancel_count_; }
// TransportClientSocketPool implementation.
- virtual int RequestSocket(const std::string& group_name,
- const void* socket_params,
- RequestPriority priority,
- ClientSocketHandle* handle,
- const CompletionCallback& callback,
- const BoundNetLog& net_log) OVERRIDE;
-
- virtual void CancelRequest(const std::string& group_name,
- ClientSocketHandle* handle) OVERRIDE;
- virtual void ReleaseSocket(const std::string& group_name,
- scoped_ptr<StreamSocket> socket,
- int id) OVERRIDE;
+ int RequestSocket(const std::string& group_name,
+ const void* socket_params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) override;
+
+ void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) override;
+ void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) override;
private:
ClientSocketFactory* client_socket_factory_;
@@ -1198,7 +1240,7 @@ class MockTransportClientSocketPool : public TransportClientSocketPool {
class DeterministicMockClientSocketFactory : public ClientSocketFactory {
public:
DeterministicMockClientSocketFactory();
- virtual ~DeterministicMockClientSocketFactory();
+ ~DeterministicMockClientSocketFactory() override;
void AddSocketDataProvider(DeterministicSocketData* socket);
void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
@@ -1219,21 +1261,21 @@ class DeterministicMockClientSocketFactory : public ClientSocketFactory {
}
// ClientSocketFactory
- virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
+ scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
DatagramSocket::BindType bind_type,
const RandIntCallback& rand_int_cb,
NetLog* net_log,
- const NetLog::Source& source) OVERRIDE;
- virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
+ const NetLog::Source& source) override;
+ scoped_ptr<StreamSocket> CreateTransportClientSocket(
const AddressList& addresses,
NetLog* net_log,
- const NetLog::Source& source) OVERRIDE;
- virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ const NetLog::Source& source) override;
+ scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
scoped_ptr<ClientSocketHandle> transport_socket,
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
- const SSLClientSocketContext& context) OVERRIDE;
- virtual void ClearSSLSessionCache() OVERRIDE;
+ const SSLClientSocketContext& context) override;
+ void ClearSSLSessionCache() override;
private:
SocketDataProviderArray<DeterministicSocketData> mock_data_;
@@ -1254,21 +1296,21 @@ class MockSOCKSClientSocketPool : public SOCKSClientSocketPool {
ClientSocketPoolHistograms* histograms,
TransportClientSocketPool* transport_pool);
- virtual ~MockSOCKSClientSocketPool();
+ ~MockSOCKSClientSocketPool() override;
// SOCKSClientSocketPool implementation.
- virtual int RequestSocket(const std::string& group_name,
- const void* socket_params,
- RequestPriority priority,
- ClientSocketHandle* handle,
- const CompletionCallback& callback,
- const BoundNetLog& net_log) OVERRIDE;
-
- virtual void CancelRequest(const std::string& group_name,
- ClientSocketHandle* handle) OVERRIDE;
- virtual void ReleaseSocket(const std::string& group_name,
- scoped_ptr<StreamSocket> socket,
- int id) OVERRIDE;
+ int RequestSocket(const std::string& group_name,
+ const void* socket_params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) override;
+
+ void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) override;
+ void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) override;
private:
TransportClientSocketPool* const transport_pool_;
diff --git a/chromium/net/socket/socks5_client_socket.h b/chromium/net/socket/socks5_client_socket.h
index 8da0b4da5ce..a405212b56b 100644
--- a/chromium/net/socket/socks5_client_socket.h
+++ b/chromium/net/socket/socks5_client_socket.h
@@ -38,37 +38,37 @@ class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket {
const HostResolver::RequestInfo& req_info);
// On destruction Disconnect() is called.
- virtual ~SOCKS5ClientSocket();
+ ~SOCKS5ClientSocket() override;
// StreamSocket implementation.
// Does the SOCKS handshake and completes the protocol.
- virtual int Connect(const CompletionCallback& callback) OVERRIDE;
- virtual void Disconnect() OVERRIDE;
- virtual bool IsConnected() const OVERRIDE;
- virtual bool IsConnectedAndIdle() const OVERRIDE;
- virtual const BoundNetLog& NetLog() const OVERRIDE;
- virtual void SetSubresourceSpeculation() OVERRIDE;
- virtual void SetOmniboxSpeculation() OVERRIDE;
- virtual bool WasEverUsed() const OVERRIDE;
- virtual bool UsingTCPFastOpen() const OVERRIDE;
- virtual bool WasNpnNegotiated() const OVERRIDE;
- virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+ int Connect(const CompletionCallback& callback) override;
+ void Disconnect() override;
+ bool IsConnected() const override;
+ bool IsConnectedAndIdle() const override;
+ const BoundNetLog& NetLog() const override;
+ void SetSubresourceSpeculation() override;
+ void SetOmniboxSpeculation() override;
+ bool WasEverUsed() const override;
+ bool UsingTCPFastOpen() const override;
+ bool WasNpnNegotiated() const override;
+ NextProto GetNegotiatedProtocol() const override;
+ bool GetSSLInfo(SSLInfo* ssl_info) override;
// Socket implementation.
- virtual int Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int Write(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
-
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE;
- virtual int SetSendBufferSize(int32 size) OVERRIDE;
-
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+
+ int SetReceiveBufferSize(int32 size) override;
+ int SetSendBufferSize(int32 size) override;
+
+ int GetPeerAddress(IPEndPoint* address) const override;
+ int GetLocalAddress(IPEndPoint* address) const override;
private:
enum State {
diff --git a/chromium/net/socket/socks5_client_socket_unittest.cc b/chromium/net/socket/socks5_client_socket_unittest.cc
index 78f2ac433c3..c474a0b4198 100644
--- a/chromium/net/socket/socks5_client_socket_unittest.cc
+++ b/chromium/net/socket/socks5_client_socket_unittest.cc
@@ -40,7 +40,7 @@ class SOCKS5ClientSocketTest : public PlatformTest {
int port,
NetLog* net_log);
- virtual void SetUp();
+ void SetUp() override;
protected:
const uint16 kNwPort;
diff --git a/chromium/net/socket/socks_client_socket.h b/chromium/net/socket/socks_client_socket.h
index 26da332b3ea..e792881cc7f 100644
--- a/chromium/net/socket/socks_client_socket.h
+++ b/chromium/net/socket/socks_client_socket.h
@@ -35,37 +35,37 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket {
HostResolver* host_resolver);
// On destruction Disconnect() is called.
- virtual ~SOCKSClientSocket();
+ ~SOCKSClientSocket() override;
// StreamSocket implementation.
// Does the SOCKS handshake and completes the protocol.
- virtual int Connect(const CompletionCallback& callback) OVERRIDE;
- virtual void Disconnect() OVERRIDE;
- virtual bool IsConnected() const OVERRIDE;
- virtual bool IsConnectedAndIdle() const OVERRIDE;
- virtual const BoundNetLog& NetLog() const OVERRIDE;
- virtual void SetSubresourceSpeculation() OVERRIDE;
- virtual void SetOmniboxSpeculation() OVERRIDE;
- virtual bool WasEverUsed() const OVERRIDE;
- virtual bool UsingTCPFastOpen() const OVERRIDE;
- virtual bool WasNpnNegotiated() const OVERRIDE;
- virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+ int Connect(const CompletionCallback& callback) override;
+ void Disconnect() override;
+ bool IsConnected() const override;
+ bool IsConnectedAndIdle() const override;
+ const BoundNetLog& NetLog() const override;
+ void SetSubresourceSpeculation() override;
+ void SetOmniboxSpeculation() override;
+ bool WasEverUsed() const override;
+ bool UsingTCPFastOpen() const override;
+ bool WasNpnNegotiated() const override;
+ NextProto GetNegotiatedProtocol() const override;
+ bool GetSSLInfo(SSLInfo* ssl_info) override;
// Socket implementation.
- virtual int Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int Write(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
-
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE;
- virtual int SetSendBufferSize(int32 size) OVERRIDE;
-
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+
+ int SetReceiveBufferSize(int32 size) override;
+ int SetSendBufferSize(int32 size) override;
+
+ int GetPeerAddress(IPEndPoint* address) const override;
+ int GetLocalAddress(IPEndPoint* address) const override;
private:
FRIEND_TEST_ALL_PREFIXES(SOCKSClientSocketTest, CompleteHandshake);
diff --git a/chromium/net/socket/socks_client_socket_pool.h b/chromium/net/socket/socks_client_socket_pool.h
index c6d5c8d0883..35f7146f967 100644
--- a/chromium/net/socket/socks_client_socket_pool.h
+++ b/chromium/net/socket/socks_client_socket_pool.h
@@ -63,10 +63,10 @@ class SOCKSConnectJob : public ConnectJob {
HostResolver* host_resolver,
Delegate* delegate,
NetLog* net_log);
- virtual ~SOCKSConnectJob();
+ ~SOCKSConnectJob() override;
// ConnectJob methods.
- virtual LoadState GetLoadState() const OVERRIDE;
+ LoadState GetLoadState() const override;
private:
enum State {
@@ -90,7 +90,7 @@ class SOCKSConnectJob : public ConnectJob {
// Begins the transport connection and the SOCKS handshake. Returns OK on
// success and ERR_IO_PENDING if it cannot immediately service the request.
// Otherwise, it returns a net error code.
- virtual int ConnectInternal() OVERRIDE;
+ int ConnectInternal() override;
scoped_refptr<SOCKSSocketParams> socks_params_;
TransportClientSocketPool* const transport_pool_;
@@ -117,59 +117,57 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool
TransportClientSocketPool* transport_pool,
NetLog* net_log);
- virtual ~SOCKSClientSocketPool();
+ ~SOCKSClientSocketPool() override;
// ClientSocketPool implementation.
- virtual int RequestSocket(const std::string& group_name,
- const void* connect_params,
- RequestPriority priority,
- ClientSocketHandle* handle,
- const CompletionCallback& callback,
- const BoundNetLog& net_log) OVERRIDE;
+ int RequestSocket(const std::string& group_name,
+ const void* connect_params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) override;
- virtual void RequestSockets(const std::string& group_name,
- const void* params,
- int num_sockets,
- const BoundNetLog& net_log) OVERRIDE;
+ void RequestSockets(const std::string& group_name,
+ const void* params,
+ int num_sockets,
+ const BoundNetLog& net_log) override;
- virtual void CancelRequest(const std::string& group_name,
- ClientSocketHandle* handle) OVERRIDE;
+ void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) override;
- virtual void ReleaseSocket(const std::string& group_name,
- scoped_ptr<StreamSocket> socket,
- int id) OVERRIDE;
+ void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) override;
- virtual void FlushWithError(int error) OVERRIDE;
+ void FlushWithError(int error) override;
- virtual void CloseIdleSockets() OVERRIDE;
+ void CloseIdleSockets() override;
- virtual int IdleSocketCount() const OVERRIDE;
+ int IdleSocketCount() const override;
- virtual int IdleSocketCountInGroup(
- const std::string& group_name) const OVERRIDE;
+ int IdleSocketCountInGroup(const std::string& group_name) const override;
- virtual LoadState GetLoadState(
- const std::string& group_name,
- const ClientSocketHandle* handle) const OVERRIDE;
+ LoadState GetLoadState(const std::string& group_name,
+ const ClientSocketHandle* handle) const override;
- virtual base::DictionaryValue* GetInfoAsValue(
+ base::DictionaryValue* GetInfoAsValue(
const std::string& name,
const std::string& type,
- bool include_nested_pools) const OVERRIDE;
+ bool include_nested_pools) const override;
- virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
+ base::TimeDelta ConnectionTimeout() const override;
- virtual ClientSocketPoolHistograms* histograms() const OVERRIDE;
+ ClientSocketPoolHistograms* histograms() const override;
// LowerLayeredPool implementation.
- virtual bool IsStalled() const OVERRIDE;
+ bool IsStalled() const override;
- virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
+ void AddHigherLayeredPool(HigherLayeredPool* higher_pool) override;
- virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
+ void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) override;
// HigherLayeredPool implementation.
- virtual bool CloseOneIdleConnection() OVERRIDE;
+ bool CloseOneIdleConnection() override;
private:
typedef ClientSocketPoolBase<SOCKSSocketParams> PoolBase;
@@ -183,15 +181,15 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool
host_resolver_(host_resolver),
net_log_(net_log) {}
- virtual ~SOCKSConnectJobFactory() {}
+ ~SOCKSConnectJobFactory() override {}
// ClientSocketPoolBase::ConnectJobFactory methods.
- virtual scoped_ptr<ConnectJob> NewConnectJob(
+ scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
- ConnectJob::Delegate* delegate) const OVERRIDE;
+ ConnectJob::Delegate* delegate) const override;
- virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
+ base::TimeDelta ConnectionTimeout() const override;
private:
TransportClientSocketPool* const transport_pool_;
diff --git a/chromium/net/socket/socks_client_socket_pool_unittest.cc b/chromium/net/socket/socks_client_socket_pool_unittest.cc
index b2b8655ee22..391d31beddb 100644
--- a/chromium/net/socket/socks_client_socket_pool_unittest.cc
+++ b/chromium/net/socket/socks_client_socket_pool_unittest.cc
@@ -44,8 +44,8 @@ void TestLoadTimingInfo(const ClientSocketHandle& handle) {
scoped_refptr<TransportSocketParams> CreateProxyHostParams() {
return new TransportSocketParams(
- HostPortPair("proxy", 80), false, false,
- OnHostResolutionCallback());
+ HostPortPair("proxy", 80), false, false, OnHostResolutionCallback(),
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT);
}
scoped_refptr<SOCKSSocketParams> CreateSOCKSv4Params() {
@@ -103,7 +103,7 @@ class SOCKSClientSocketPoolTest : public testing::Test {
NULL) {
}
- virtual ~SOCKSClientSocketPoolTest() {}
+ ~SOCKSClientSocketPoolTest() override {}
int StartRequestV5(const std::string& group_name, RequestPriority priority) {
return test_base_.StartRequestUsingPool(
diff --git a/chromium/net/socket/socks_client_socket_unittest.cc b/chromium/net/socket/socks_client_socket_unittest.cc
index f361244feff..fbb84f8f50a 100644
--- a/chromium/net/socket/socks_client_socket_unittest.cc
+++ b/chromium/net/socket/socks_client_socket_unittest.cc
@@ -10,6 +10,7 @@
#include "net/base/net_log_unittest.h"
#include "net/base/test_completion_callback.h"
#include "net/base/winsock_init.h"
+#include "net/dns/host_resolver.h"
#include "net/dns/mock_host_resolver.h"
#include "net/socket/client_socket_factory.h"
#include "net/socket/socket_test_util.h"
@@ -34,7 +35,7 @@ class SOCKSClientSocketTest : public PlatformTest {
HostResolver* host_resolver,
const std::string& hostname, int port,
NetLog* net_log);
- virtual void SetUp();
+ void SetUp() override;
protected:
scoped_ptr<SOCKSClientSocket> user_sock_;
@@ -95,12 +96,12 @@ class HangingHostResolverWithCancel : public HostResolver {
public:
HangingHostResolverWithCancel() : outstanding_request_(NULL) {}
- virtual int Resolve(const RequestInfo& info,
- RequestPriority priority,
- AddressList* addresses,
- const CompletionCallback& callback,
- RequestHandle* out_req,
- const BoundNetLog& net_log) OVERRIDE {
+ int Resolve(const RequestInfo& info,
+ RequestPriority priority,
+ AddressList* addresses,
+ const CompletionCallback& callback,
+ RequestHandle* out_req,
+ const BoundNetLog& net_log) override {
DCHECK(addresses);
DCHECK_EQ(false, callback.is_null());
EXPECT_FALSE(HasOutstandingRequest());
@@ -109,14 +110,14 @@ class HangingHostResolverWithCancel : public HostResolver {
return ERR_IO_PENDING;
}
- virtual int ResolveFromCache(const RequestInfo& info,
- AddressList* addresses,
- const BoundNetLog& net_log) OVERRIDE {
+ int ResolveFromCache(const RequestInfo& info,
+ AddressList* addresses,
+ const BoundNetLog& net_log) override {
NOTIMPLEMENTED();
return ERR_UNEXPECTED;
}
- virtual void CancelRequest(RequestHandle req) OVERRIDE {
+ void CancelRequest(RequestHandle req) override {
EXPECT_TRUE(HasOutstandingRequest());
EXPECT_EQ(outstanding_request_, req);
outstanding_request_ = NULL;
@@ -213,7 +214,7 @@ TEST_F(SOCKSClientSocketTest, HandshakeFailures) {
//---------------------------------------
- for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) {
+ for (size_t i = 0; i < arraysize(tests); ++i) {
MockWrite data_writes[] = {
MockWrite(SYNCHRONOUS, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
MockRead data_reads[] = {
@@ -414,4 +415,36 @@ TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) {
EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
}
+// Tries to connect to an IPv6 IP. Should fail, as SOCKS4 does not support
+// IPv6.
+TEST_F(SOCKSClientSocketTest, NoIPv6) {
+ const char kHostName[] = "::1";
+
+ user_sock_ = BuildMockSocket(NULL, 0,
+ NULL, 0,
+ host_resolver_.get(),
+ kHostName, 80,
+ NULL);
+
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED,
+ callback_.GetResult(user_sock_->Connect(callback_.callback())));
+}
+
+// Same as above, but with a real resolver, to protect against regressions.
+TEST_F(SOCKSClientSocketTest, NoIPv6RealResolver) {
+ const char kHostName[] = "::1";
+
+ scoped_ptr<HostResolver> host_resolver(
+ HostResolver::CreateSystemResolver(HostResolver::Options(), NULL));
+
+ user_sock_ = BuildMockSocket(NULL, 0,
+ NULL, 0,
+ host_resolver.get(),
+ kHostName, 80,
+ NULL);
+
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED,
+ callback_.GetResult(user_sock_->Connect(callback_.callback())));
+}
+
} // namespace net
diff --git a/chromium/net/socket/ssl_client_socket.cc b/chromium/net/socket/ssl_client_socket.cc
index 1b2fe144304..3184e04e3f7 100644
--- a/chromium/net/socket/ssl_client_socket.cc
+++ b/chromium/net/socket/ssl_client_socket.cc
@@ -5,10 +5,14 @@
#include "net/socket/ssl_client_socket.h"
#include "base/metrics/histogram.h"
+#include "base/metrics/sparse_histogram.h"
#include "base/strings/string_util.h"
#include "crypto/ec_private_key.h"
-#include "net/ssl/server_bound_cert_service.h"
+#include "net/base/connection_type_histograms.h"
+#include "net/base/host_port_pair.h"
+#include "net/ssl/channel_id_service.h"
#include "net/ssl/ssl_config_service.h"
+#include "net/ssl/ssl_connection_status_flags.h"
namespace net {
@@ -18,7 +22,8 @@ SSLClientSocket::SSLClientSocket()
protocol_negotiated_(kProtoUnknown),
channel_id_sent_(false),
signed_cert_timestamps_received_(false),
- stapled_ocsp_response_received_(false) {
+ stapled_ocsp_response_received_(false),
+ negotiation_extension_(kExtensionUnknown) {
}
// static
@@ -32,8 +37,8 @@ NextProto SSLClientSocket::NextProtoFromString(
return kProtoSPDY3;
} else if (proto_string == "spdy/3.1") {
return kProtoSPDY31;
- } else if (proto_string == "h2-12") {
- // This is the HTTP/2 draft 12 identifier. For internal
+ } else if (proto_string == "h2-14") {
+ // This is the HTTP/2 draft 14 identifier. For internal
// consistency, HTTP/2 is named SPDY4 within Chromium.
return kProtoSPDY4;
} else if (proto_string == "quic/1+spdy/3") {
@@ -55,9 +60,9 @@ const char* SSLClientSocket::NextProtoToString(NextProto next_proto) {
case kProtoSPDY31:
return "spdy/3.1";
case kProtoSPDY4:
- // This is the HTTP/2 draft 12 identifier. For internal
+ // This is the HTTP/2 draft 14 identifier. For internal
// consistency, HTTP/2 is named SPDY4 within Chromium.
- return "h2-12";
+ return "h2-14";
case kProtoQUIC1SPDY3:
return "quic/1+spdy/3";
case kProtoUnknown:
@@ -80,21 +85,6 @@ const char* SSLClientSocket::NextProtoStatusToString(
return NULL;
}
-// static
-std::string SSLClientSocket::ServerProtosToString(
- const std::string& server_protos) {
- const char* protos = server_protos.c_str();
- size_t protos_len = server_protos.length();
- std::vector<std::string> server_protos_with_commas;
- for (size_t i = 0; i < protos_len; ) {
- const size_t len = protos[i];
- std::string proto_str(&protos[i + 1], len);
- server_protos_with_commas.push_back(proto_str);
- i += len + 1;
- }
- return JoinString(server_protos_with_commas, ',');
-}
-
bool SSLClientSocket::WasNpnNegotiated() const {
return was_npn_negotiated_;
}
@@ -138,6 +128,11 @@ void SSLClientSocket::set_protocol_negotiated(NextProto protocol_negotiated) {
protocol_negotiated_ = protocol_negotiated;
}
+void SSLClientSocket::set_negotiation_extension(
+ SSLNegotiationExtension negotiation_extension) {
+ negotiation_extension_ = negotiation_extension;
+}
+
bool SSLClientSocket::WasChannelIDSent() const {
return channel_id_sent_;
}
@@ -158,7 +153,7 @@ void SSLClientSocket::set_stapled_ocsp_response_received(
// static
void SSLClientSocket::RecordChannelIDSupport(
- ServerBoundCertService* server_bound_cert_service,
+ ChannelIDService* channel_id_service,
bool negotiated_channel_id,
bool channel_id_enabled,
bool supports_ecc) {
@@ -169,40 +164,62 @@ void SSLClientSocket::RecordChannelIDSupport(
CLIENT_AND_SERVER = 2,
CLIENT_NO_ECC = 3,
CLIENT_BAD_SYSTEM_TIME = 4,
- CLIENT_NO_SERVER_BOUND_CERT_SERVICE = 5,
- DOMAIN_BOUND_CERT_USAGE_MAX
+ CLIENT_NO_CHANNEL_ID_SERVICE = 5,
+ CHANNEL_ID_USAGE_MAX
} supported = DISABLED;
if (negotiated_channel_id) {
supported = CLIENT_AND_SERVER;
} else if (channel_id_enabled) {
- if (!server_bound_cert_service)
- supported = CLIENT_NO_SERVER_BOUND_CERT_SERVICE;
+ if (!channel_id_service)
+ supported = CLIENT_NO_CHANNEL_ID_SERVICE;
else if (!supports_ecc)
supported = CLIENT_NO_ECC;
- else if (!server_bound_cert_service->IsSystemTimeValid())
+ else if (!channel_id_service->IsSystemTimeValid())
supported = CLIENT_BAD_SYSTEM_TIME;
else
supported = CLIENT_ONLY;
}
UMA_HISTOGRAM_ENUMERATION("DomainBoundCerts.Support", supported,
- DOMAIN_BOUND_CERT_USAGE_MAX);
+ CHANNEL_ID_USAGE_MAX);
+}
+
+// static
+void SSLClientSocket::RecordConnectionTypeMetrics(int ssl_version) {
+ UpdateConnectionTypeHistograms(CONNECTION_SSL);
+ switch (ssl_version) {
+ case SSL_CONNECTION_VERSION_SSL2:
+ UpdateConnectionTypeHistograms(CONNECTION_SSL_SSL2);
+ break;
+ case SSL_CONNECTION_VERSION_SSL3:
+ UpdateConnectionTypeHistograms(CONNECTION_SSL_SSL3);
+ break;
+ case SSL_CONNECTION_VERSION_TLS1:
+ UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1);
+ break;
+ case SSL_CONNECTION_VERSION_TLS1_1:
+ UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1_1);
+ break;
+ case SSL_CONNECTION_VERSION_TLS1_2:
+ UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1_2);
+ break;
+ }
}
// static
bool SSLClientSocket::IsChannelIDEnabled(
const SSLConfig& ssl_config,
- ServerBoundCertService* server_bound_cert_service) {
+ ChannelIDService* channel_id_service) {
if (!ssl_config.channel_id_enabled)
return false;
- if (!server_bound_cert_service) {
- DVLOG(1) << "NULL server_bound_cert_service_, not enabling channel ID.";
+ if (!channel_id_service) {
+ DVLOG(1) << "NULL channel_id_service_, not enabling channel ID.";
return false;
}
if (!crypto::ECPrivateKey::IsSupported()) {
DVLOG(1) << "Elliptic Curve not supported, not enabling channel ID.";
return false;
}
- if (!server_bound_cert_service->IsSystemTimeValid()) {
+ if (!channel_id_service->IsSystemTimeValid()) {
DVLOG(1) << "System time is not within the supported range for certificate "
"generation, not enabling channel ID.";
return false;
@@ -210,4 +227,66 @@ bool SSLClientSocket::IsChannelIDEnabled(
return true;
}
+// static
+std::vector<uint8_t> SSLClientSocket::SerializeNextProtos(
+ const std::vector<std::string>& next_protos) {
+ // Do a first pass to determine the total length.
+ size_t wire_length = 0;
+ for (std::vector<std::string>::const_iterator i = next_protos.begin();
+ i != next_protos.end(); ++i) {
+ if (i->size() > 255) {
+ LOG(WARNING) << "Ignoring overlong NPN/ALPN protocol: " << *i;
+ continue;
+ }
+ if (i->size() == 0) {
+ LOG(WARNING) << "Ignoring empty NPN/ALPN protocol";
+ continue;
+ }
+ wire_length += i->size();
+ wire_length++;
+ }
+
+ // Allocate memory for the result and fill it in.
+ std::vector<uint8_t> wire_protos;
+ wire_protos.reserve(wire_length);
+ for (std::vector<std::string>::const_iterator i = next_protos.begin();
+ i != next_protos.end(); i++) {
+ if (i->size() == 0 || i->size() > 255)
+ continue;
+ wire_protos.push_back(i->size());
+ wire_protos.resize(wire_protos.size() + i->size());
+ memcpy(&wire_protos[wire_protos.size() - i->size()],
+ i->data(), i->size());
+ }
+ DCHECK_EQ(wire_protos.size(), wire_length);
+
+ return wire_protos;
+}
+
+void SSLClientSocket::RecordNegotiationExtension() {
+ if (negotiation_extension_ == kExtensionUnknown)
+ return;
+ std::string proto;
+ SSLClientSocket::NextProtoStatus status = GetNextProto(&proto);
+ if (status == kNextProtoUnsupported)
+ return;
+ // Convert protocol into numerical value for histogram.
+ NextProto protocol_negotiated = SSLClientSocket::NextProtoFromString(proto);
+ base::HistogramBase::Sample sample =
+ static_cast<base::HistogramBase::Sample>(protocol_negotiated);
+ // In addition to the protocol negotiated, we want to record which TLS
+ // extension was used, and in case of NPN, whether there was overlap between
+ // server and client list of supported protocols.
+ if (negotiation_extension_ == kExtensionNPN) {
+ if (status == kNextProtoNoOverlap) {
+ sample += 1000;
+ } else {
+ sample += 500;
+ }
+ } else {
+ DCHECK_EQ(kExtensionALPN, negotiation_extension_);
+ }
+ UMA_HISTOGRAM_SPARSE_SLOWLY("Net.SSLProtocolNegotiation", sample);
+}
+
} // namespace net
diff --git a/chromium/net/socket/ssl_client_socket.h b/chromium/net/socket/ssl_client_socket.h
index a43e58cc26b..7adfa8c626a 100644
--- a/chromium/net/socket/ssl_client_socket.h
+++ b/chromium/net/socket/ssl_client_socket.h
@@ -17,7 +17,9 @@
namespace net {
class CertVerifier;
+class ChannelIDService;
class CTVerifier;
+class HostPortPair;
class ServerBoundCertService;
class SSLCertRequestInfo;
struct SSLConfig;
@@ -30,23 +32,23 @@ class X509Certificate;
struct SSLClientSocketContext {
SSLClientSocketContext()
: cert_verifier(NULL),
- server_bound_cert_service(NULL),
+ channel_id_service(NULL),
transport_security_state(NULL),
cert_transparency_verifier(NULL) {}
SSLClientSocketContext(CertVerifier* cert_verifier_arg,
- ServerBoundCertService* server_bound_cert_service_arg,
+ ChannelIDService* channel_id_service_arg,
TransportSecurityState* transport_security_state_arg,
CTVerifier* cert_transparency_verifier_arg,
const std::string& ssl_session_cache_shard_arg)
: cert_verifier(cert_verifier_arg),
- server_bound_cert_service(server_bound_cert_service_arg),
+ channel_id_service(channel_id_service_arg),
transport_security_state(transport_security_state_arg),
cert_transparency_verifier(cert_transparency_verifier_arg),
ssl_session_cache_shard(ssl_session_cache_shard_arg) {}
CertVerifier* cert_verifier;
- ServerBoundCertService* server_bound_cert_service;
+ ChannelIDService* channel_id_service;
TransportSecurityState* transport_security_state;
CTVerifier* cert_transparency_verifier;
// ssl_session_cache_shard is an opaque string that identifies a shard of the
@@ -77,9 +79,49 @@ class NET_EXPORT SSLClientSocket : public SSLSocket {
// the first protocol in our list.
};
+ // TLS extension used to negotiate protocol.
+ enum SSLNegotiationExtension {
+ kExtensionUnknown,
+ kExtensionALPN,
+ kExtensionNPN,
+ };
+
// StreamSocket:
- virtual bool WasNpnNegotiated() const OVERRIDE;
- virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
+ bool WasNpnNegotiated() const override;
+ NextProto GetNegotiatedProtocol() const override;
+
+ // Computes a unique key string for the SSL session cache.
+ virtual std::string GetSessionCacheKey() const = 0;
+
+ // Returns true if there is a cache entry in the SSL session cache
+ // for the cache key of the SSL socket.
+ //
+ // The cache key consists of a host and port concatenated with a session
+ // cache shard. These two strings are passed to the constructor of most
+ // subclasses of SSLClientSocket.
+ virtual bool InSessionCache() const = 0;
+
+ // Sets |callback| to be run when the handshake has fully completed.
+ // For example, in the case of False Start, Connect() will return
+ // early, before the peer's TLS Finished message has been verified,
+ // in order to allow the caller to call Write() and send application
+ // data with the client's Finished message.
+ // In such situations, |callback| will be invoked sometime after
+ // Connect() - either during a Write() or Read() call, and before
+ // invoking the Read() or Write() callback.
+ // Otherwise, during a traditional TLS connection (i.e. no False
+ // Start), this will be called right before the Connect() callback
+ // is called.
+ //
+ // Note that it's not valid to mutate this socket during such
+ // callbacks, including deleting the socket.
+ //
+ // TODO(mshelley): Provide additional details about whether or not
+ // the handshake actually succeeded or not. This can be inferred
+ // from the result to Connect()/Read()/Write(), but may be useful
+ // to inform here as well.
+ virtual void SetHandshakeCompletionCallback(
+ const base::Closure& callback) = 0;
// Gets the SSL CertificateRequest info of the socket after Connect failed
// with ERR_SSL_CLIENT_AUTH_CERT_NEEDED.
@@ -93,9 +135,7 @@ class NET_EXPORT SSLClientSocket : public SSLSocket {
// kNextProtoNegotiated: *proto is set to the negotiated protocol.
// kNextProtoNoOverlap: *proto is set to the first protocol in the
// supported list.
- // *server_protos is set to the server advertised protocols.
- virtual NextProtoStatus GetNextProto(std::string* proto,
- std::string* server_protos) = 0;
+ virtual NextProtoStatus GetNextProto(std::string* proto) = 0;
static NextProto NextProtoFromString(const std::string& proto_string);
@@ -103,10 +143,6 @@ class NET_EXPORT SSLClientSocket : public SSLSocket {
static const char* NextProtoStatusToString(const NextProtoStatus status);
- // Can be used with the second argument(|server_protos|) of |GetNextProto| to
- // construct a comma separated string of server advertised protocols.
- static std::string ServerProtosToString(const std::string& server_protos);
-
static bool IgnoreCertError(int error, int load_flags);
// ClearSessionCache clears the SSL session cache, used to resume SSL
@@ -121,9 +157,11 @@ class NET_EXPORT SSLClientSocket : public SSLSocket {
virtual void set_protocol_negotiated(NextProto protocol_negotiated);
- // Returns the ServerBoundCertService used by this socket, or NULL if
- // server bound certificates are not supported.
- virtual ServerBoundCertService* GetServerBoundCertService() const = 0;
+ void set_negotiation_extension(SSLNegotiationExtension negotiation_extension);
+
+ // Returns the ChannelIDService used by this socket, or NULL if
+ // channel ids are not supported.
+ virtual ChannelIDService* GetChannelIDService() const = 0;
// Returns true if a channel ID was sent on this connection.
// This may be useful for protocols, like SPDY, which allow the same
@@ -133,6 +171,10 @@ class NET_EXPORT SSLClientSocket : public SSLSocket {
// Public for ssl_client_socket_openssl_unittest.cc.
virtual bool WasChannelIDSent() const;
+ // Record which TLS extension was used to negotiate protocol and protocol
+ // chosen in a UMA histogram.
+ void RecordNegotiationExtension();
+
protected:
virtual void set_channel_id_sent(bool channel_id_sent);
@@ -145,15 +187,23 @@ class NET_EXPORT SSLClientSocket : public SSLSocket {
// Records histograms for channel id support during full handshakes - resumed
// handshakes are ignored.
static void RecordChannelIDSupport(
- ServerBoundCertService* server_bound_cert_service,
+ ChannelIDService* channel_id_service,
bool negotiated_channel_id,
bool channel_id_enabled,
bool supports_ecc);
+ // Records ConnectionType histograms for a successful SSL connection.
+ static void RecordConnectionTypeMetrics(int ssl_version);
+
// Returns whether TLS channel ID is enabled.
static bool IsChannelIDEnabled(
const SSLConfig& ssl_config,
- ServerBoundCertService* server_bound_cert_service);
+ ChannelIDService* channel_id_service);
+
+ // Serializes |next_protos| in the wire format for ALPN: protocols are listed
+ // in order, each prefixed by a one-byte length.
+ static std::vector<uint8_t> SerializeNextProtos(
+ const std::vector<std::string>& next_protos);
// For unit testing only.
// Returns the unverified certificate chain as presented by server.
@@ -185,6 +235,8 @@ class NET_EXPORT SSLClientSocket : public SSLSocket {
bool signed_cert_timestamps_received_;
// True if a stapled OCSP response was received.
bool stapled_ocsp_response_received_;
+ // Protocol negotiation extension used.
+ SSLNegotiationExtension negotiation_extension_;
};
} // namespace net
diff --git a/chromium/net/socket/ssl_client_socket_nss.cc b/chromium/net/socket/ssl_client_socket_nss.cc
index 9f40a78371e..08cf2c55f51 100644
--- a/chromium/net/socket/ssl_client_socket_nss.cc
+++ b/chromium/net/socket/ssl_client_socket_nss.cc
@@ -71,6 +71,7 @@
#include "base/logging.h"
#include "base/memory/singleton.h"
#include "base/metrics/histogram.h"
+#include "base/profiler/scoped_tracker.h"
#include "base/single_thread_task_runner.h"
#include "base/stl_util.h"
#include "base/strings/string_number_conversions.h"
@@ -85,7 +86,6 @@
#include "crypto/rsa_private_key.h"
#include "crypto/scoped_nss_types.h"
#include "net/base/address_list.h"
-#include "net/base/connection_type_histograms.h"
#include "net/base/dns_util.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
@@ -93,7 +93,7 @@
#include "net/cert/asn1_util.h"
#include "net/cert/cert_status_flags.h"
#include "net/cert/cert_verifier.h"
-#include "net/cert/ct_objects_extractor.h"
+#include "net/cert/ct_ev_whitelist.h"
#include "net/cert/ct_verifier.h"
#include "net/cert/ct_verify_result.h"
#include "net/cert/scoped_nss_types.h"
@@ -105,7 +105,6 @@
#include "net/ocsp/nss_ocsp.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/nss_ssl_util.h"
-#include "net/socket/ssl_error_params.h"
#include "net/ssl/ssl_cert_request_info.h"
#include "net/ssl/ssl_connection_status_flags.h"
#include "net/ssl/ssl_info.h"
@@ -408,7 +407,7 @@ struct HandshakeState {
void Reset() {
next_proto_status = SSLClientSocket::kNextProtoUnsupported;
next_proto.clear();
- server_protos.clear();
+ negotiation_extension_ = SSLClientSocket::kExtensionUnknown;
channel_id_sent = false;
server_cert_chain.Reset(NULL);
server_cert = NULL;
@@ -422,8 +421,9 @@ struct HandshakeState {
// negotiated protocol stored in |next_proto|.
SSLClientSocket::NextProtoStatus next_proto_status;
std::string next_proto;
- // If the server supports NPN, the protocols supported by the server.
- std::string server_protos;
+
+ // TLS extension used for protocol negotiation.
+ SSLClientSocket::SSLNegotiationExtension negotiation_extension_;
// True if a channel ID was sent.
bool channel_id_sent;
@@ -474,18 +474,6 @@ int MapNSSClientError(PRErrorCode err) {
}
}
-// Map NSS error code from the first SSL handshake to network error code.
-int MapNSSClientHandshakeError(PRErrorCode err) {
- switch (err) {
- // If the server closed on us, it is a protocol error.
- // Some TLS-intolerant servers do this when we request TLS.
- case PR_END_OF_FILE_ERROR:
- return ERR_SSL_PROTOCOL_ERROR;
- default:
- return MapNSSClientError(err);
- }
-}
-
} // namespace
// SSLClientSocketNSS::Core provides a thread-safe, ref-counted core that is
@@ -524,7 +512,7 @@ int MapNSSClientHandshakeError(PRErrorCode err) {
// 2) NSS Task Runner: Prepare data to go from NSS to an IO function:
// (BufferRecv, BufferSend)
// 3) Network Task Runner: Perform IO on that data (DoBufferRecv,
-// DoBufferSend, DoGetDomainBoundCert, OnGetDomainBoundCertComplete)
+// DoBufferSend, DoGetChannelID, OnGetChannelIDComplete)
// 4) Both Task Runners: Callback for asynchronous completion or to marshal
// data from the network task runner back to NSS (BufferRecvComplete,
// BufferSendComplete, OnHandshakeIOComplete)
@@ -592,7 +580,7 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> {
// that their lifetimes match that of the owning SSLClientSocketNSS.
//
// The caller retains ownership of |transport|, |net_log|, and
- // |server_bound_cert_service|, and they will not be accessed once Detach()
+ // |channel_id_service|, and they will not be accessed once Detach()
// has been called.
Core(base::SequencedTaskRunner* network_task_runner,
base::SequencedTaskRunner* nss_task_runner,
@@ -600,7 +588,7 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> {
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
BoundNetLog* net_log,
- ServerBoundCertService* server_bound_cert_service);
+ ChannelIDService* channel_id_service);
// Called on the network task runner.
// Transfers ownership of |socket|, an NSS SSL socket, and |buffers|, the
@@ -720,7 +708,7 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> {
// Handles an NSS error generated while handshaking or performing IO.
// Returns a network error code mapped from the original NSS error.
- int HandleNSSError(PRErrorCode error, bool handshake_error);
+ int HandleNSSError(PRErrorCode error);
int DoHandshakeLoop(int last_io_result);
int DoReadLoop(int result);
@@ -754,7 +742,7 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> {
// key into a SECKEYPublicKey and SECKEYPrivateKey. Returns OK upon success
// and an error code otherwise.
// Requires |domain_bound_private_key_| and |domain_bound_cert_| to have been
- // set by a call to ServerBoundCertService->GetDomainBoundCert. The caller
+ // set by a call to ChannelIDService->GetChannelID. The caller
// takes ownership of the |*cert| and |*key|.
int ImportChannelIDKeys(SECKEYPublicKey** public_key, SECKEYPrivateKey** key);
@@ -775,15 +763,17 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> {
// UpdateNextProto gets any application-layer protocol that may have been
// negotiated by the TLS connection.
void UpdateNextProto();
+ // Record TLS extension used for protocol negotiation (NPN or ALPN).
+ void UpdateExtensionUsed();
////////////////////////////////////////////////////////////////////////////
// Methods that are ONLY called on the network task runner:
////////////////////////////////////////////////////////////////////////////
int DoBufferRecv(IOBuffer* buffer, int len);
int DoBufferSend(IOBuffer* buffer, int len);
- int DoGetDomainBoundCert(const std::string& host);
+ int DoGetChannelID(const std::string& host);
- void OnGetDomainBoundCertComplete(int result);
+ void OnGetChannelIDComplete(int result);
void OnHandshakeStateUpdated(const HandshakeState& state);
void OnNSSBufferUpdated(int amount_in_read_buffer);
void DidNSSRead(int result);
@@ -832,8 +822,8 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> {
HandshakeState network_handshake_state_;
// The service for retrieving Channel ID keys. May be NULL.
- ServerBoundCertService* server_bound_cert_service_;
- ServerBoundCertService::RequestHandle domain_bound_cert_request_handle_;
+ ChannelIDService* channel_id_service_;
+ ChannelIDService::RequestHandle domain_bound_cert_request_handle_;
// The information about NSS task runner.
int unhandled_buffer_size_;
@@ -914,7 +904,7 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> {
// for the network task runner from the NSS task runner.
base::WeakPtr<BoundNetLog> weak_net_log_;
- // Written on the network task runner by the |server_bound_cert_service_|,
+ // Written on the network task runner by the |channel_id_service_|,
// prior to invoking OnHandshakeIOComplete.
// Read on the NSS task runner when once OnHandshakeIOComplete is invoked
// on the NSS task runner.
@@ -931,11 +921,11 @@ SSLClientSocketNSS::Core::Core(
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
BoundNetLog* net_log,
- ServerBoundCertService* server_bound_cert_service)
+ ChannelIDService* channel_id_service)
: detached_(false),
transport_(transport),
weak_net_log_factory_(net_log),
- server_bound_cert_service_(server_bound_cert_service),
+ channel_id_service_(channel_id_service),
unhandled_buffer_size_(0),
nss_waiting_read_(false),
nss_waiting_write_(false),
@@ -969,6 +959,7 @@ SSLClientSocketNSS::Core::~Core() {
PR_Close(nss_fd_);
nss_fd_ = NULL;
}
+ nss_bufs_ = NULL;
}
bool SSLClientSocketNSS::Core::Init(PRFileDesc* socket,
@@ -983,30 +974,11 @@ bool SSLClientSocketNSS::Core::Init(PRFileDesc* socket,
SECStatus rv = SECSuccess;
if (!ssl_config_.next_protos.empty()) {
- size_t wire_length = 0;
- for (std::vector<std::string>::const_iterator
- i = ssl_config_.next_protos.begin();
- i != ssl_config_.next_protos.end(); ++i) {
- if (i->size() > 255) {
- LOG(WARNING) << "Ignoring overlong NPN/ALPN protocol: " << *i;
- continue;
- }
- wire_length += i->size();
- wire_length++;
- }
- scoped_ptr<uint8[]> wire_protos(new uint8[wire_length]);
- uint8* dst = wire_protos.get();
- for (std::vector<std::string>::const_iterator
- i = ssl_config_.next_protos.begin();
- i != ssl_config_.next_protos.end(); i++) {
- if (i->size() > 255)
- continue;
- *dst++ = i->size();
- memcpy(dst, i->data(), i->size());
- dst += i->size();
- }
- DCHECK_EQ(dst, wire_protos.get() + wire_length);
- rv = SSL_SetNextProtoNego(nss_fd_, wire_protos.get(), wire_length);
+ std::vector<uint8_t> wire_protos =
+ SerializeNextProtos(ssl_config_.next_protos);
+ rv = SSL_SetNextProtoNego(
+ nss_fd_, wire_protos.empty() ? NULL : &wire_protos[0],
+ wire_protos.size());
if (rv != SECSuccess)
LogFailedNSSFunction(*weak_net_log_, "SSL_SetNextProtoNego", "");
rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_ALPN, PR_TRUE);
@@ -1037,7 +1009,7 @@ bool SSLClientSocketNSS::Core::Init(PRFileDesc* socket,
return false;
}
- if (IsChannelIDEnabled(ssl_config_, server_bound_cert_service_)) {
+ if (IsChannelIDEnabled(ssl_config_, channel_id_service_)) {
rv = SSL_SetClientChannelIDCallback(
nss_fd_, SSLClientSocketNSS::Core::ClientChannelIDHandler, this);
if (rv != SECSuccess) {
@@ -1658,6 +1630,11 @@ void SSLClientSocketNSS::Core::HandshakeCallback(
}
void SSLClientSocketNSS::Core::HandshakeSucceeded() {
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "424386 SSLClientSocketNSS::Core::HandshakeSucceeded"));
+
DCHECK(OnNSSTaskRunner());
PRBool last_handshake_resumed;
@@ -1674,6 +1651,7 @@ void SSLClientSocketNSS::Core::HandshakeSucceeded() {
UpdateStapledOCSPResponse();
UpdateConnectionStatus();
UpdateNextProto();
+ UpdateExtensionUsed();
// Update the network task runners view of the handshake state whenever
// a handshake has completed.
@@ -1682,12 +1660,15 @@ void SSLClientSocketNSS::Core::HandshakeSucceeded() {
nss_handshake_state_));
}
-int SSLClientSocketNSS::Core::HandleNSSError(PRErrorCode nss_error,
- bool handshake_error) {
+int SSLClientSocketNSS::Core::HandleNSSError(PRErrorCode nss_error) {
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "424386 SSLClientSocketNSS::Core::HandleNSSError"));
+
DCHECK(OnNSSTaskRunner());
- int net_error = handshake_error ? MapNSSClientHandshakeError(nss_error) :
- MapNSSClientError(nss_error);
+ int net_error = MapNSSClientError(nss_error);
#if defined(OS_WIN)
// On Windows, a handle to the HCRYPTPROV is cached in the X509Certificate
@@ -1717,6 +1698,11 @@ int SSLClientSocketNSS::Core::HandleNSSError(PRErrorCode nss_error,
}
int SSLClientSocketNSS::Core::DoHandshakeLoop(int last_io_result) {
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "424386 SSLClientSocketNSS::Core::DoHandshakeLoop"));
+
DCHECK(OnNSSTaskRunner());
int rv = last_io_result;
@@ -1753,6 +1739,11 @@ int SSLClientSocketNSS::Core::DoHandshakeLoop(int last_io_result) {
}
int SSLClientSocketNSS::Core::DoReadLoop(int result) {
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "424386 SSLClientSocketNSS::Core::DoReadLoop"));
+
DCHECK(OnNSSTaskRunner());
DCHECK(false_started_ || handshake_callback_called_);
DCHECK_EQ(STATE_NONE, next_handshake_state_);
@@ -1812,11 +1803,21 @@ int SSLClientSocketNSS::Core::DoWriteLoop(int result) {
}
int SSLClientSocketNSS::Core::DoHandshake() {
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "424386 SSLClientSocketNSS::Core::DoHandshake"));
+
DCHECK(OnNSSTaskRunner());
- int net_error = net::OK;
+ int net_error = OK;
SECStatus rv = SSL_ForceHandshake(nss_fd_);
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed.
+ tracked_objects::ScopedTracker tracking_profile1(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "424386 SSLClientSocketNSS::Core::DoHandshake 1"));
+
// Note: this function may be called multiple times during the handshake, so
// even though channel id and client auth are separate else cases, they can
// both be used during a single SSL handshake.
@@ -1845,24 +1846,7 @@ int SSLClientSocketNSS::Core::DoHandshake() {
}
} else {
PRErrorCode prerr = PR_GetError();
- net_error = HandleNSSError(prerr, true);
-
- // Some network devices that inspect application-layer packets seem to
- // inject TCP reset packets to break the connections when they see
- // TLS 1.1 in ClientHello or ServerHello. See http://crbug.com/130293.
- //
- // Only allow ERR_CONNECTION_RESET to trigger a fallback from TLS 1.1 or
- // 1.2. We don't lose much in this fallback because the explicit IV for CBC
- // mode in TLS 1.1 is approximated by record splitting in TLS 1.0. The
- // fallback will be more painful for TLS 1.2 when we have GCM support.
- //
- // ERR_CONNECTION_RESET is a common network error, so we don't want it
- // to trigger a version fallback in general, especially the TLS 1.0 ->
- // SSL 3.0 fallback, which would drop TLS extensions.
- if (prerr == PR_CONNECT_RESET_ERROR &&
- ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1_1) {
- net_error = ERR_SSL_PROTOCOL_ERROR;
- }
+ net_error = HandleNSSError(prerr);
// If not done, stay in this state
if (net_error == ERR_IO_PENDING) {
@@ -1880,6 +1864,11 @@ int SSLClientSocketNSS::Core::DoHandshake() {
}
int SSLClientSocketNSS::Core::DoGetDBCertComplete(int result) {
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "424386 SSLClientSocketNSS::Core::DoGetDBCertComplete"));
+
SECStatus rv;
PostOrRunCallback(
FROM_HERE,
@@ -1989,7 +1978,7 @@ int SSLClientSocketNSS::Core::DoPayloadRead() {
// If *next_result == 0, then that indicates EOF, and no special error
// handling is needed.
pending_read_nss_error_ = PR_GetError();
- *next_result = HandleNSSError(pending_read_nss_error_, false);
+ *next_result = HandleNSSError(pending_read_nss_error_);
if (rv > 0 && *next_result == ERR_IO_PENDING) {
// If at least some data was read from PR_Read(), do not treat
// insufficient data as an error to return in the next call to
@@ -2051,7 +2040,7 @@ int SSLClientSocketNSS::Core::DoPayloadWrite() {
if (prerr == PR_WOULD_BLOCK_ERROR)
return ERR_IO_PENDING;
- rv = HandleNSSError(prerr, false);
+ rv = HandleNSSError(prerr);
PostOrRunCallback(
FROM_HERE,
base::Bind(&AddLogEventWithCallback, weak_net_log_,
@@ -2064,6 +2053,11 @@ int SSLClientSocketNSS::Core::DoPayloadWrite() {
// transport socket. Return true if some I/O performed, false
// otherwise (error or ERR_IO_PENDING).
bool SSLClientSocketNSS::Core::DoTransportIO() {
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "424386 SSLClientSocketNSS::Core::DoTransportIO"));
+
DCHECK(OnNSSTaskRunner());
bool network_moved = false;
@@ -2240,6 +2234,11 @@ void SSLClientSocketNSS::Core::OnSendComplete(int result) {
// callback. For Read() and Write(), that's what we want. But for Connect(),
// the caller expects OK (i.e. 0) for success.
void SSLClientSocketNSS::Core::DoConnectCallback(int rv) {
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "424386 SSLClientSocketNSS::Core::DoConnectCallback"));
+
DCHECK(OnNSSTaskRunner());
DCHECK_NE(rv, ERR_IO_PENDING);
DCHECK(!user_connect_callback_.is_null());
@@ -2251,6 +2250,11 @@ void SSLClientSocketNSS::Core::DoConnectCallback(int rv) {
}
void SSLClientSocketNSS::Core::DoReadCallback(int rv) {
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "424386 SSLClientSocketNSS::Core::DoReadCallback"));
+
DCHECK(OnNSSTaskRunner());
DCHECK_NE(ERR_IO_PENDING, rv);
DCHECK(!user_read_callback_.is_null());
@@ -2266,6 +2270,10 @@ void SSLClientSocketNSS::Core::DoReadCallback(int rv) {
PostOrRunCallback(
FROM_HERE,
base::Bind(&Core::DidNSSRead, this, rv));
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed.
+ tracked_objects::ScopedTracker tracking_profile1(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "SSLClientSocketNSS::Core::DoReadCallback"));
PostOrRunCallback(
FROM_HERE,
base::Bind(base::ResetAndReturn(&user_read_callback_), rv));
@@ -2314,12 +2322,12 @@ SECStatus SSLClientSocketNSS::Core::ClientChannelIDHandler(
std::string host = core->host_and_port_.host();
int error = ERR_UNEXPECTED;
if (core->OnNetworkTaskRunner()) {
- error = core->DoGetDomainBoundCert(host);
+ error = core->DoGetChannelID(host);
} else {
bool posted = core->network_task_runner_->PostTask(
FROM_HERE,
base::Bind(
- IgnoreResult(&Core::DoGetDomainBoundCert),
+ IgnoreResult(&Core::DoGetChannelID),
core, host));
error = posted ? ERR_IO_PENDING : ERR_ABORTED;
}
@@ -2367,7 +2375,7 @@ int SSLClientSocketNSS::Core::ImportChannelIDKeys(SECKEYPublicKey** public_key,
// Set the private key.
if (!crypto::ECPrivateKey::ImportFromEncryptedPrivateKeyInfo(
slot.get(),
- ServerBoundCertService::kEPKIPassword,
+ ChannelIDService::kEPKIPassword,
reinterpret_cast<const unsigned char*>(
domain_bound_private_key_.data()),
domain_bound_private_key_.size(),
@@ -2459,6 +2467,10 @@ void SSLClientSocketNSS::Core::UpdateStapledOCSPResponse() {
}
void SSLClientSocketNSS::Core::UpdateConnectionStatus() {
+ // Note: This function may be called multiple times for a single connection
+ // if renegotiations occur.
+ nss_handshake_state_.ssl_connection_status = 0;
+
SSLChannelInfo channel_info;
SECStatus ok = SSL_GetChannelInfo(nss_fd_,
&channel_info, sizeof(channel_info));
@@ -2475,8 +2487,6 @@ void SSLClientSocketNSS::Core::UpdateConnectionStatus() {
SSL_CONNECTION_COMPRESSION_MASK) <<
SSL_CONNECTION_COMPRESSION_SHIFT;
- // NSS 3.14.x doesn't have a version macro for TLS 1.2 (because NSS didn't
- // support it yet), so use 0x0303 directly.
int version = SSL_CONNECTION_VERSION_UNKNOWN;
if (channel_info.protocolVersion < SSL_LIBRARY_VERSION_3_0) {
// All versions less than SSL_LIBRARY_VERSION_3_0 are treated as SSL
@@ -2484,11 +2494,11 @@ void SSLClientSocketNSS::Core::UpdateConnectionStatus() {
version = SSL_CONNECTION_VERSION_SSL2;
} else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_3_0) {
version = SSL_CONNECTION_VERSION_SSL3;
- } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_3_1_TLS) {
+ } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_TLS_1_0) {
version = SSL_CONNECTION_VERSION_TLS1;
} else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_TLS_1_1) {
version = SSL_CONNECTION_VERSION_TLS1_1;
- } else if (channel_info.protocolVersion == 0x0303) {
+ } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_TLS_1_2) {
version = SSL_CONNECTION_VERSION_TLS1_2;
}
nss_handshake_state_.ssl_connection_status |=
@@ -2508,26 +2518,6 @@ void SSLClientSocketNSS::Core::UpdateConnectionStatus() {
VLOG(1) << "The server " << host_and_port_.ToString()
<< " does not support the TLS renegotiation_info extension.";
}
- UMA_HISTOGRAM_ENUMERATION("Net.RenegotiationExtensionSupported",
- peer_supports_renego_ext, 2);
-
- // We would like to eliminate fallback to SSLv3 for non-buggy servers
- // because of security concerns. For example, Google offers forward
- // secrecy with ECDHE but that requires TLS 1.0. An attacker can block
- // TLSv1 connections and force us to downgrade to SSLv3 and remove forward
- // secrecy.
- //
- // Yngve from Opera has suggested using the renegotiation extension as an
- // indicator that SSLv3 fallback was mistaken:
- // tools.ietf.org/html/draft-pettersen-tls-version-rollback-removal-00 .
- //
- // As a first step, measure how often clients perform version fallback
- // while the server advertises support secure renegotiation.
- if (ssl_config_.version_fallback &&
- channel_info.protocolVersion == SSL_LIBRARY_VERSION_3_0) {
- UMA_HISTOGRAM_BOOLEAN("Net.SSLv3FallbackToRenegoPatchedServer",
- peer_supports_renego_ext == PR_TRUE);
- }
}
if (ssl_config_.version_fallback) {
@@ -2564,6 +2554,23 @@ void SSLClientSocketNSS::Core::UpdateNextProto() {
}
}
+void SSLClientSocketNSS::Core::UpdateExtensionUsed() {
+ PRBool negotiated_extension;
+ SECStatus rv = SSL_HandshakeNegotiatedExtension(nss_fd_,
+ ssl_app_layer_protocol_xtn,
+ &negotiated_extension);
+ if (rv == SECSuccess && negotiated_extension) {
+ nss_handshake_state_.negotiation_extension_ = kExtensionALPN;
+ } else {
+ rv = SSL_HandshakeNegotiatedExtension(nss_fd_,
+ ssl_next_proto_nego_xtn,
+ &negotiated_extension);
+ if (rv == SECSuccess && negotiated_extension) {
+ nss_handshake_state_.negotiation_extension_ = kExtensionNPN;
+ }
+ }
+}
+
void SSLClientSocketNSS::Core::RecordChannelIDSupportOnNSSTaskRunner() {
DCHECK(OnNSSTaskRunner());
if (nss_handshake_state_.resumed_handshake)
@@ -2587,7 +2594,7 @@ void SSLClientSocketNSS::Core::RecordChannelIDSupportOnNetworkTaskRunner(
bool supports_ecc) const {
DCHECK(OnNetworkTaskRunner());
- RecordChannelIDSupport(server_bound_cert_service_,
+ RecordChannelIDSupport(channel_id_service_,
negotiated_channel_id,
channel_id_enabled,
supports_ecc);
@@ -2637,19 +2644,19 @@ int SSLClientSocketNSS::Core::DoBufferSend(IOBuffer* send_buffer, int len) {
return rv;
}
-int SSLClientSocketNSS::Core::DoGetDomainBoundCert(const std::string& host) {
+int SSLClientSocketNSS::Core::DoGetChannelID(const std::string& host) {
DCHECK(OnNetworkTaskRunner());
if (detached_)
- return ERR_FAILED;
+ return ERR_ABORTED;
weak_net_log_->BeginEvent(NetLog::TYPE_SSL_GET_DOMAIN_BOUND_CERT);
- int rv = server_bound_cert_service_->GetOrCreateDomainBoundCert(
+ int rv = channel_id_service_->GetOrCreateChannelID(
host,
&domain_bound_private_key_,
&domain_bound_cert_,
- base::Bind(&Core::OnGetDomainBoundCertComplete, base::Unretained(this)),
+ base::Bind(&Core::OnGetChannelIDComplete, base::Unretained(this)),
&domain_bound_cert_request_handle_);
if (rv != ERR_IO_PENDING && !OnNSSTaskRunner()) {
@@ -2696,6 +2703,11 @@ void SSLClientSocketNSS::Core::DidNSSWrite(int result) {
}
void SSLClientSocketNSS::Core::BufferSendComplete(int result) {
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "418183 DidCompleteReadWrite => Core::BufferSendComplete"));
+
if (!OnNSSTaskRunner()) {
if (detached_)
return;
@@ -2729,7 +2741,7 @@ void SSLClientSocketNSS::Core::OnHandshakeIOComplete(int result) {
DoConnectCallback(rv);
}
-void SSLClientSocketNSS::Core::OnGetDomainBoundCertComplete(int result) {
+void SSLClientSocketNSS::Core::OnGetChannelIDComplete(int result) {
DVLOG(1) << __FUNCTION__ << " " << result;
DCHECK(OnNetworkTaskRunner());
@@ -2739,6 +2751,11 @@ void SSLClientSocketNSS::Core::OnGetDomainBoundCertComplete(int result) {
void SSLClientSocketNSS::Core::BufferRecvComplete(
IOBuffer* read_buffer,
int result) {
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "418183 DidCompleteReadWrite => SSLClientSocketNSS::Core::..."));
+
DCHECK(read_buffer);
if (!OnNSSTaskRunner()) {
@@ -2814,7 +2831,7 @@ SSLClientSocketNSS::SSLClientSocketNSS(
ssl_config_(ssl_config),
cert_verifier_(context.cert_verifier),
cert_transparency_verifier_(context.cert_transparency_verifier),
- server_bound_cert_service_(context.server_bound_cert_service),
+ channel_id_service_(context.channel_id_service),
ssl_session_cache_shard_(context.ssl_session_cache_shard),
completed_handshake_(false),
next_handshake_state_(STATE_NONE),
@@ -2886,6 +2903,21 @@ bool SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) {
return true;
}
+std::string SSLClientSocketNSS::GetSessionCacheKey() const {
+ NOTIMPLEMENTED();
+ return std::string();
+}
+
+bool SSLClientSocketNSS::InSessionCache() const {
+ // For now, always return true so that SSLConnectJobs are never held back.
+ return true;
+}
+
+void SSLClientSocketNSS::SetHandshakeCompletionCallback(
+ const base::Closure& callback) {
+ NOTIMPLEMENTED();
+}
+
void SSLClientSocketNSS::GetSSLCertRequestInfo(
SSLCertRequestInfo* cert_request_info) {
EnterFunction("");
@@ -2932,10 +2964,8 @@ int SSLClientSocketNSS::GetTLSUniqueChannelBinding(std::string* out) {
}
SSLClientSocket::NextProtoStatus
-SSLClientSocketNSS::GetNextProto(std::string* proto,
- std::string* server_protos) {
+SSLClientSocketNSS::GetNextProto(std::string* proto) {
*proto = core_->state().next_proto;
- *server_protos = core_->state().server_protos;
return core_->state().next_proto_status;
}
@@ -3125,7 +3155,7 @@ void SSLClientSocketNSS::InitCore() {
host_and_port_,
ssl_config_,
&net_log_,
- server_bound_cert_service_);
+ channel_id_service_);
}
int SSLClientSocketNSS::InitializeSSLOptions() {
@@ -3374,6 +3404,11 @@ int SSLClientSocketNSS::DoHandshakeComplete(int result) {
EnterFunction(result);
if (result == OK) {
+ if (ssl_config_.version_fallback &&
+ ssl_config_.version_max < ssl_config_.version_fallback_min) {
+ return ERR_SSL_FALLBACK_BEYOND_MINIMUM_VERSION;
+ }
+
// SSL handshake is completed. Let's verify the certificate.
GotoState(STATE_VERIFY_CERT);
// Done!
@@ -3383,6 +3418,7 @@ int SSLClientSocketNSS::DoHandshakeComplete(int result) {
!core_->state().sct_list_from_tls_extension.empty());
set_stapled_ocsp_response_received(
!core_->state().stapled_ocsp_response.empty());
+ set_negotiation_extension(core_->state().negotiation_extension_);
LeaveFunction(result);
return result;
@@ -3468,56 +3504,38 @@ int SSLClientSocketNSS::DoVerifyCertComplete(int result) {
// TODO(hclam): Skip logging if server cert was expected to be bad because
// |server_cert_verify_result_| doesn't contain all the information about
// the cert.
- if (result == OK)
- LogConnectionTypeMetrics();
-
-#if defined(OFFICIAL_BUILD) && !defined(OS_ANDROID) && !defined(OS_IOS)
- // Take care of any mandates for public key pinning.
- //
- // Pinning is only enabled for official builds to make sure that others don't
- // end up with pins that cannot be easily updated.
- //
- // TODO(agl): We might have an issue here where a request for foo.example.com
- // merges into a SPDY connection to www.example.com, and gets a different
- // certificate.
+ if (result == OK) {
+ int ssl_version =
+ SSLConnectionStatusToVersion(core_->state().ssl_connection_status);
+ RecordConnectionTypeMetrics(ssl_version);
+ }
- // Perform pin validation if, and only if, all these conditions obtain:
- //
- // * a TransportSecurityState object is available;
- // * the server's certificate chain is valid (or suffers from only a minor
- // error);
- // * the server's certificate chain chains up to a known root (i.e. not a
- // user-installed trust anchor); and
- // * the build is recent (very old builds should fail open so that users
- // have some chance to recover).
- //
const CertStatus cert_status = server_cert_verify_result_.cert_status;
if (transport_security_state_ &&
(result == OK ||
(IsCertificateError(result) && IsCertStatusMinorError(cert_status))) &&
- server_cert_verify_result_.is_issued_by_known_root &&
- TransportSecurityState::IsBuildTimely()) {
- bool sni_available =
- ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1 ||
- ssl_config_.version_fallback;
- const std::string& host = host_and_port_.host();
-
- if (transport_security_state_->HasPublicKeyPins(host, sni_available)) {
- if (!transport_security_state_->CheckPublicKeyPins(
- host,
- sni_available,
- server_cert_verify_result_.public_key_hashes,
- &pinning_failure_log_)) {
- LOG(ERROR) << pinning_failure_log_;
- result = ERR_SSL_PINNED_KEY_NOT_IN_CERT_CHAIN;
- UMA_HISTOGRAM_BOOLEAN("Net.PublicKeyPinSuccess", false);
- TransportSecurityState::ReportUMAOnPinFailure(host);
- } else {
- UMA_HISTOGRAM_BOOLEAN("Net.PublicKeyPinSuccess", true);
- }
+ !transport_security_state_->CheckPublicKeyPins(
+ host_and_port_.host(),
+ server_cert_verify_result_.is_issued_by_known_root,
+ server_cert_verify_result_.public_key_hashes,
+ &pinning_failure_log_)) {
+ result = ERR_SSL_PINNED_KEY_NOT_IN_CERT_CHAIN;
+ }
+
+ scoped_refptr<ct::EVCertsWhitelist> ev_whitelist =
+ SSLConfigService::GetEVCertsWhitelist();
+ if (server_cert_verify_result_.cert_status & CERT_STATUS_IS_EV) {
+ if (ev_whitelist.get() && ev_whitelist->IsValid()) {
+ const SHA256HashValue fingerprint(
+ X509Certificate::CalculateFingerprint256(
+ server_cert_verify_result_.verified_cert->os_cert_handle()));
+
+ UMA_HISTOGRAM_BOOLEAN(
+ "Net.SSL_EVCertificateInWhitelist",
+ ev_whitelist->ContainsCertificateHash(
+ std::string(reinterpret_cast<const char*>(fingerprint.data), 8)));
}
}
-#endif
if (result == OK) {
// Only check Certificate Transparency if there were no other errors with
@@ -3543,7 +3561,7 @@ void SSLClientSocketNSS::VerifyCT() {
// gets all the data it needs for SCT verification and does not do any
// external communication.
int result = cert_transparency_verifier_->Verify(
- server_cert_verify_result_.verified_cert,
+ server_cert_verify_result_.verified_cert.get(),
core_->state().stapled_ocsp_response,
core_->state().sct_list_from_tls_extension,
&ct_verify_result_,
@@ -3558,29 +3576,6 @@ void SSLClientSocketNSS::VerifyCT() {
<< ct_verify_result_.unknown_logs_scts.size();
}
-void SSLClientSocketNSS::LogConnectionTypeMetrics() const {
- UpdateConnectionTypeHistograms(CONNECTION_SSL);
- int ssl_version = SSLConnectionStatusToVersion(
- core_->state().ssl_connection_status);
- switch (ssl_version) {
- case SSL_CONNECTION_VERSION_SSL2:
- UpdateConnectionTypeHistograms(CONNECTION_SSL_SSL2);
- break;
- case SSL_CONNECTION_VERSION_SSL3:
- UpdateConnectionTypeHistograms(CONNECTION_SSL_SSL3);
- break;
- case SSL_CONNECTION_VERSION_TLS1:
- UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1);
- break;
- case SSL_CONNECTION_VERSION_TLS1_1:
- UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1_1);
- break;
- case SSL_CONNECTION_VERSION_TLS1_2:
- UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1_2);
- break;
- };
-}
-
void SSLClientSocketNSS::EnsureThreadIdAssigned() const {
base::AutoLock auto_lock(lock_);
if (valid_thread_id_ != base::kInvalidThreadId)
@@ -3621,8 +3616,8 @@ SSLClientSocketNSS::GetUnverifiedServerCertificateChain() const {
return core_->state().server_cert.get();
}
-ServerBoundCertService* SSLClientSocketNSS::GetServerBoundCertService() const {
- return server_bound_cert_service_;
+ChannelIDService* SSLClientSocketNSS::GetChannelIDService() const {
+ return channel_id_service_;
}
} // namespace net
diff --git a/chromium/net/socket/ssl_client_socket_nss.h b/chromium/net/socket/ssl_client_socket_nss.h
index e8cce574b64..71f09c0b82b 100644
--- a/chromium/net/socket/ssl_client_socket_nss.h
+++ b/chromium/net/socket/ssl_client_socket_nss.h
@@ -17,7 +17,6 @@
#include "base/synchronization/lock.h"
#include "base/threading/platform_thread.h"
#include "base/time/time.h"
-#include "base/timer/timer.h"
#include "net/base/completion_callback.h"
#include "net/base/host_port_pair.h"
#include "net/base/net_export.h"
@@ -27,7 +26,7 @@
#include "net/cert/ct_verify_result.h"
#include "net/cert/x509_certificate.h"
#include "net/socket/ssl_client_socket.h"
-#include "net/ssl/server_bound_cert_service.h"
+#include "net/ssl/channel_id_service.h"
#include "net/ssl/ssl_config_service.h"
namespace base {
@@ -38,9 +37,9 @@ namespace net {
class BoundNetLog;
class CertVerifier;
+class ChannelIDService;
class CTVerifier;
class ClientSocketHandle;
-class ServerBoundCertService;
class SingleRequestCertVerifier;
class TransportSecurityState;
class X509Certificate;
@@ -65,51 +64,52 @@ class SSLClientSocketNSS : public SSLClientSocket {
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context);
- virtual ~SSLClientSocketNSS();
+ ~SSLClientSocketNSS() override;
// SSLClientSocket implementation.
- virtual void GetSSLCertRequestInfo(
- SSLCertRequestInfo* cert_request_info) OVERRIDE;
- virtual NextProtoStatus GetNextProto(std::string* proto,
- std::string* server_protos) OVERRIDE;
+ std::string GetSessionCacheKey() const override;
+ bool InSessionCache() const override;
+ void SetHandshakeCompletionCallback(const base::Closure& callback) override;
+ void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override;
+ NextProtoStatus GetNextProto(std::string* proto) override;
// SSLSocket implementation.
- virtual int ExportKeyingMaterial(const base::StringPiece& label,
- bool has_context,
- const base::StringPiece& context,
- unsigned char* out,
- unsigned int outlen) OVERRIDE;
- virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE;
+ int ExportKeyingMaterial(const base::StringPiece& label,
+ bool has_context,
+ const base::StringPiece& context,
+ unsigned char* out,
+ unsigned int outlen) override;
+ int GetTLSUniqueChannelBinding(std::string* out) override;
// StreamSocket implementation.
- virtual int Connect(const CompletionCallback& callback) OVERRIDE;
- virtual void Disconnect() OVERRIDE;
- virtual bool IsConnected() const OVERRIDE;
- virtual bool IsConnectedAndIdle() const OVERRIDE;
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
- virtual const BoundNetLog& NetLog() const OVERRIDE;
- virtual void SetSubresourceSpeculation() OVERRIDE;
- virtual void SetOmniboxSpeculation() OVERRIDE;
- virtual bool WasEverUsed() const OVERRIDE;
- virtual bool UsingTCPFastOpen() const OVERRIDE;
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+ int Connect(const CompletionCallback& callback) override;
+ void Disconnect() override;
+ bool IsConnected() const override;
+ bool IsConnectedAndIdle() const override;
+ int GetPeerAddress(IPEndPoint* address) const override;
+ int GetLocalAddress(IPEndPoint* address) const override;
+ const BoundNetLog& NetLog() const override;
+ void SetSubresourceSpeculation() override;
+ void SetOmniboxSpeculation() override;
+ bool WasEverUsed() const override;
+ bool UsingTCPFastOpen() const override;
+ bool GetSSLInfo(SSLInfo* ssl_info) override;
// Socket implementation.
- virtual int Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int Write(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE;
- virtual int SetSendBufferSize(int32 size) OVERRIDE;
- virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE;
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int SetReceiveBufferSize(int32 size) override;
+ int SetSendBufferSize(int32 size) override;
+ ChannelIDService* GetChannelIDService() const override;
protected:
// SSLClientSocket implementation.
- virtual scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain()
- const OVERRIDE;
+ scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain()
+ const override;
private:
// Helper class to handle marshalling any NSS interaction to and from the
@@ -144,8 +144,6 @@ class SSLClientSocketNSS : public SSLClientSocket {
void VerifyCT();
- void LogConnectionTypeMetrics() const;
-
// The following methods are for debugging bug 65948. Will remove this code
// after fixing bug 65948.
void EnsureThreadIdAssigned() const;
@@ -178,7 +176,7 @@ class SSLClientSocketNSS : public SSLClientSocket {
CTVerifier* cert_transparency_verifier_;
// The service for retrieving Channel ID keys. May be NULL.
- ServerBoundCertService* server_bound_cert_service_;
+ ChannelIDService* channel_id_service_;
// ssl_session_cache_shard_ is an opaque string that partitions the SSL
// session cache. i.e. sessions created with one value will not attempt to
@@ -202,7 +200,7 @@ class SSLClientSocketNSS : public SSLClientSocket {
TransportSecurityState* transport_security_state_;
// pinning_failure_log contains a message produced by
- // TransportSecurityState::DomainState::CheckPublicKeyPins in the event of a
+ // TransportSecurityState::CheckPublicKeyPins in the event of a
// pinning failure. It is a (somewhat) human-readable string.
std::string pinning_failure_log_;
diff --git a/chromium/net/socket/ssl_client_socket_openssl.cc b/chromium/net/socket/ssl_client_socket_openssl.cc
index 4ff8d438e96..9fdfe38ccdd 100644
--- a/chromium/net/socket/ssl_client_socket_openssl.cc
+++ b/chromium/net/socket/ssl_client_socket_openssl.cc
@@ -7,29 +7,44 @@
#include "net/socket/ssl_client_socket_openssl.h"
+#include <errno.h>
+#include <openssl/bio.h>
#include <openssl/err.h>
-#include <openssl/opensslv.h>
#include <openssl/ssl.h>
#include "base/bind.h"
#include "base/callback_helpers.h"
+#include "base/environment.h"
#include "base/memory/singleton.h"
#include "base/metrics/histogram.h"
+#include "base/strings/string_piece.h"
#include "base/synchronization/lock.h"
#include "crypto/ec_private_key.h"
#include "crypto/openssl_util.h"
+#include "crypto/scoped_openssl_types.h"
#include "net/base/net_errors.h"
#include "net/cert/cert_verifier.h"
+#include "net/cert/ct_ev_whitelist.h"
+#include "net/cert/ct_verifier.h"
#include "net/cert/single_request_cert_verifier.h"
#include "net/cert/x509_certificate_net_log_param.h"
-#include "net/socket/openssl_ssl_util.h"
-#include "net/socket/ssl_error_params.h"
+#include "net/cert/x509_util_openssl.h"
+#include "net/http/transport_security_state.h"
#include "net/socket/ssl_session_cache_openssl.h"
-#include "net/ssl/openssl_client_key_store.h"
#include "net/ssl/ssl_cert_request_info.h"
#include "net/ssl/ssl_connection_status_flags.h"
#include "net/ssl/ssl_info.h"
+#if defined(OS_WIN)
+#include "base/win/windows_version.h"
+#endif
+
+#if defined(USE_OPENSSL_CERTS)
+#include "net/ssl/openssl_client_key_store.h"
+#else
+#include "net/ssl/openssl_platform_key.h"
+#endif
+
namespace net {
namespace {
@@ -51,6 +66,14 @@ const int kNoPendingReadResult = 1;
// the server supports NPN, choosing "http/1.1" is the best answer.
const char kDefaultSupportedNPNProtocol[] = "http/1.1";
+void FreeX509Stack(STACK_OF(X509)* ptr) {
+ sk_X509_pop_free(ptr, X509_free);
+}
+
+typedef crypto::ScopedOpenSSL<X509, X509_free>::Type ScopedX509;
+typedef crypto::ScopedOpenSSL<STACK_OF(X509), FreeX509Stack>::Type
+ ScopedX509Stack;
+
#if OPENSSL_VERSION_NUMBER < 0x1000103fL
// This method doesn't seem to have made it into the OpenSSL headers.
unsigned long SSL_CIPHER_get_id(const SSL_CIPHER* cipher) { return cipher->id; }
@@ -78,22 +101,43 @@ int GetNetSSLVersion(SSL* ssl) {
return SSL_CONNECTION_VERSION_SSL3;
case TLS1_VERSION:
return SSL_CONNECTION_VERSION_TLS1;
- case 0x0302:
+ case TLS1_1_VERSION:
return SSL_CONNECTION_VERSION_TLS1_1;
- case 0x0303:
+ case TLS1_2_VERSION:
return SSL_CONNECTION_VERSION_TLS1_2;
default:
return SSL_CONNECTION_VERSION_UNKNOWN;
}
}
-// Compute a unique key string for the SSL session cache. |socket| is an
-// input socket object. Return a string.
-std::string GetSocketSessionCacheKey(const SSLClientSocketOpenSSL& socket) {
- std::string result = socket.host_and_port().ToString();
- result.append("/");
- result.append(socket.ssl_session_cache_shard());
- return result;
+ScopedX509 OSCertHandleToOpenSSL(
+ X509Certificate::OSCertHandle os_handle) {
+#if defined(USE_OPENSSL_CERTS)
+ return ScopedX509(X509Certificate::DupOSCertHandle(os_handle));
+#else // !defined(USE_OPENSSL_CERTS)
+ std::string der_encoded;
+ if (!X509Certificate::GetDEREncoded(os_handle, &der_encoded))
+ return ScopedX509();
+ const uint8_t* bytes = reinterpret_cast<const uint8_t*>(der_encoded.data());
+ return ScopedX509(d2i_X509(NULL, &bytes, der_encoded.size()));
+#endif // defined(USE_OPENSSL_CERTS)
+}
+
+ScopedX509Stack OSCertHandlesToOpenSSL(
+ const X509Certificate::OSCertHandles& os_handles) {
+ ScopedX509Stack stack(sk_X509_new_null());
+ for (size_t i = 0; i < os_handles.size(); i++) {
+ ScopedX509 x509 = OSCertHandleToOpenSSL(os_handles[i]);
+ if (!x509)
+ return ScopedX509Stack();
+ sk_X509_push(stack.get(), x509.release());
+ }
+ return stack.Pass();
+}
+
+int LogErrorCallback(const char* str, size_t len, void* context) {
+ LOG(ERROR) << base::StringPiece(str, len);
+ return 1;
}
} // namespace
@@ -126,34 +170,42 @@ class SSLClientSocketOpenSSL::SSLContext {
ssl_ctx_.reset(SSL_CTX_new(SSLv23_client_method()));
session_cache_.Reset(ssl_ctx_.get(), kDefaultSessionCacheConfig);
SSL_CTX_set_cert_verify_callback(ssl_ctx_.get(), CertVerifyCallback, NULL);
- SSL_CTX_set_client_cert_cb(ssl_ctx_.get(), ClientCertCallback);
- SSL_CTX_set_channel_id_cb(ssl_ctx_.get(), ChannelIDCallback);
+ SSL_CTX_set_cert_cb(ssl_ctx_.get(), ClientCertRequestCallback, NULL);
SSL_CTX_set_verify(ssl_ctx_.get(), SSL_VERIFY_PEER, NULL);
// TODO(kristianm): Only select this if ssl_config_.next_proto is not empty.
// It would be better if the callback were not a global setting,
// but that is an OpenSSL issue.
SSL_CTX_set_next_proto_select_cb(ssl_ctx_.get(), SelectNextProtoCallback,
NULL);
+ ssl_ctx_->tlsext_channel_id_enabled_new = 1;
+
+ scoped_ptr<base::Environment> env(base::Environment::Create());
+ std::string ssl_keylog_file;
+ if (env->GetVar("SSLKEYLOGFILE", &ssl_keylog_file) &&
+ !ssl_keylog_file.empty()) {
+ crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
+ BIO* bio = BIO_new_file(ssl_keylog_file.c_str(), "a");
+ if (!bio) {
+ LOG(ERROR) << "Failed to open " << ssl_keylog_file;
+ ERR_print_errors_cb(&LogErrorCallback, NULL);
+ } else {
+ SSL_CTX_set_keylog_bio(ssl_ctx_.get(), bio);
+ }
+ }
}
static std::string GetSessionCacheKey(const SSL* ssl) {
SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl);
DCHECK(socket);
- return GetSocketSessionCacheKey(*socket);
+ return socket->GetSessionCacheKey();
}
static SSLSessionCacheOpenSSL::Config kDefaultSessionCacheConfig;
- static int ClientCertCallback(SSL* ssl, X509** x509, EVP_PKEY** pkey) {
- SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl);
- CHECK(socket);
- return socket->ClientCertRequestCallback(ssl, x509, pkey);
- }
-
- static void ChannelIDCallback(SSL* ssl, EVP_PKEY** pkey) {
+ static int ClientCertRequestCallback(SSL* ssl, void* arg) {
SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl);
- CHECK(socket);
- socket->ChannelIDRequestCallback(ssl, pkey);
+ DCHECK(socket);
+ return socket->ClientCertRequestCallback(ssl);
}
static int CertVerifyCallback(X509_STORE_CTX *store_ctx, void *arg) {
@@ -177,7 +229,7 @@ class SSLClientSocketOpenSSL::SSLContext {
// SSLClientSocketOpenSSL object from an SSL instance.
int ssl_socket_data_index_;
- crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free> ssl_ctx_;
+ crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free>::Type ssl_ctx_;
// |session_cache_| must be destroyed before |ssl_ctx_|.
SSLSessionCacheOpenSSL session_cache_;
};
@@ -200,7 +252,7 @@ class SSLClientSocketOpenSSL::PeerCertificateChain {
void Reset(STACK_OF(X509)* chain);
// Note that when USE_OPENSSL is defined, OSCertHandle is X509*
- const scoped_refptr<X509Certificate>& AsOSChain() const { return os_chain_; }
+ scoped_refptr<X509Certificate> AsOSChain() const;
size_t size() const {
if (!openssl_chain_.get())
@@ -208,23 +260,17 @@ class SSLClientSocketOpenSSL::PeerCertificateChain {
return sk_X509_num(openssl_chain_.get());
}
- X509* operator[](size_t index) const {
+ bool empty() const {
+ return size() == 0;
+ }
+
+ X509* Get(size_t index) const {
DCHECK_LT(index, size());
return sk_X509_value(openssl_chain_.get(), index);
}
- bool IsValid() { return os_chain_.get() && openssl_chain_.get(); }
-
private:
- static void FreeX509Stack(STACK_OF(X509)* cert_chain) {
- sk_X509_pop_free(cert_chain, X509_free);
- }
-
- friend class crypto::ScopedOpenSSL<STACK_OF(X509), FreeX509Stack>;
-
- crypto::ScopedOpenSSL<STACK_OF(X509), FreeX509Stack> openssl_chain_;
-
- scoped_refptr<X509Certificate> os_chain_;
+ ScopedX509Stack openssl_chain_;
};
SSLClientSocketOpenSSL::PeerCertificateChain&
@@ -233,85 +279,41 @@ SSLClientSocketOpenSSL::PeerCertificateChain::operator=(
if (this == &other)
return *this;
- // os_chain_ is reference counted by scoped_refptr;
- os_chain_ = other.os_chain_;
-
- // Must increase the reference count manually for sk_X509_dup
- openssl_chain_.reset(sk_X509_dup(other.openssl_chain_.get()));
- for (int i = 0; i < sk_X509_num(openssl_chain_.get()); ++i) {
- X509* x = sk_X509_value(openssl_chain_.get(), i);
- CRYPTO_add(&x->references, 1, CRYPTO_LOCK_X509);
- }
+ openssl_chain_.reset(X509_chain_up_ref(other.openssl_chain_.get()));
return *this;
}
-#if defined(USE_OPENSSL_CERTS)
-// When OSCertHandle is typedef'ed to X509, this implementation does a short cut
-// to avoid converting back and forth between der and X509 struct.
void SSLClientSocketOpenSSL::PeerCertificateChain::Reset(
STACK_OF(X509)* chain) {
- openssl_chain_.reset(NULL);
- os_chain_ = NULL;
-
- if (!chain)
- return;
+ openssl_chain_.reset(chain ? X509_chain_up_ref(chain) : NULL);
+}
+scoped_refptr<X509Certificate>
+SSLClientSocketOpenSSL::PeerCertificateChain::AsOSChain() const {
+#if defined(USE_OPENSSL_CERTS)
+ // When OSCertHandle is typedef'ed to X509, this implementation does a short
+ // cut to avoid converting back and forth between DER and the X509 struct.
X509Certificate::OSCertHandles intermediates;
- for (int i = 1; i < sk_X509_num(chain); ++i)
- intermediates.push_back(sk_X509_value(chain, i));
-
- os_chain_ =
- X509Certificate::CreateFromHandle(sk_X509_value(chain, 0), intermediates);
-
- // sk_X509_dup does not increase reference count on the certs in the stack.
- openssl_chain_.reset(sk_X509_dup(chain));
-
- std::vector<base::StringPiece> der_chain;
- for (int i = 0; i < sk_X509_num(openssl_chain_.get()); ++i) {
- X509* x = sk_X509_value(openssl_chain_.get(), i);
- // Increase the reference count for the certs in openssl_chain_.
- CRYPTO_add(&x->references, 1, CRYPTO_LOCK_X509);
+ for (size_t i = 1; i < sk_X509_num(openssl_chain_.get()); ++i) {
+ intermediates.push_back(sk_X509_value(openssl_chain_.get(), i));
}
-}
-#else // !defined(USE_OPENSSL_CERTS)
-void SSLClientSocketOpenSSL::PeerCertificateChain::Reset(
- STACK_OF(X509)* chain) {
- openssl_chain_.reset(NULL);
- os_chain_ = NULL;
-
- if (!chain)
- return;
-
- // sk_X509_dup does not increase reference count on the certs in the stack.
- openssl_chain_.reset(sk_X509_dup(chain));
+ return make_scoped_refptr(X509Certificate::CreateFromHandle(
+ sk_X509_value(openssl_chain_.get(), 0), intermediates));
+#else
+ // DER-encode the chain and convert to a platform certificate handle.
std::vector<base::StringPiece> der_chain;
- for (int i = 0; i < sk_X509_num(openssl_chain_.get()); ++i) {
+ for (size_t i = 0; i < sk_X509_num(openssl_chain_.get()); ++i) {
X509* x = sk_X509_value(openssl_chain_.get(), i);
-
- // Increase the reference count for the certs in openssl_chain_.
- CRYPTO_add(&x->references, 1, CRYPTO_LOCK_X509);
-
- unsigned char* cert_data = NULL;
- int cert_data_length = i2d_X509(x, &cert_data);
- if (cert_data_length && cert_data)
- der_chain.push_back(base::StringPiece(reinterpret_cast<char*>(cert_data),
- cert_data_length));
+ base::StringPiece der;
+ if (!x509_util::GetDER(x, &der))
+ return NULL;
+ der_chain.push_back(der);
}
- os_chain_ = X509Certificate::CreateFromDERCertChain(der_chain);
-
- for (size_t i = 0; i < der_chain.size(); ++i) {
- OPENSSL_free(const_cast<char*>(der_chain[i].data()));
- }
-
- if (der_chain.size() !=
- static_cast<size_t>(sk_X509_num(openssl_chain_.get()))) {
- openssl_chain_.reset(NULL);
- os_chain_ = NULL;
- }
+ return make_scoped_refptr(X509Certificate::CreateFromDERCertChain(der_chain));
+#endif
}
-#endif // defined(USE_OPENSSL_CERTS)
// static
SSLSessionCacheOpenSSL::Config
@@ -327,9 +329,6 @@ void SSLClientSocket::ClearSessionCache() {
SSLClientSocketOpenSSL::SSLContext* context =
SSLClientSocketOpenSSL::SSLContext::GetInstance();
context->session_cache()->Flush();
-#if defined(USE_OPENSSL_CERTS)
- OpenSSLClientKeyStore::GetInstance()->Flush();
-#endif
}
SSLClientSocketOpenSSL::SSLClientSocketOpenSSL(
@@ -339,16 +338,17 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL(
const SSLClientSocketContext& context)
: transport_send_busy_(false),
transport_recv_busy_(false),
- transport_recv_eof_(false),
- weak_factory_(this),
pending_read_error_(kNoPendingReadResult),
+ pending_read_ssl_error_(SSL_ERROR_NONE),
+ transport_read_error_(OK),
transport_write_error_(OK),
server_cert_chain_(new PeerCertificateChain(NULL)),
- completed_handshake_(false),
+ completed_connect_(false),
was_ever_used_(false),
client_auth_cert_needed_(false),
cert_verifier_(context.cert_verifier),
- server_bound_cert_service_(context.server_bound_cert_service),
+ cert_transparency_verifier_(context.cert_transparency_verifier),
+ channel_id_service_(context.channel_id_service),
ssl_(NULL),
transport_bio_(NULL),
transport_(transport_socket.Pass()),
@@ -358,14 +358,36 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL(
trying_cached_session_(false),
next_handshake_state_(STATE_NONE),
npn_status_(kNextProtoUnsupported),
- channel_id_request_return_value_(ERR_UNEXPECTED),
channel_id_xtn_negotiated_(false),
- net_log_(transport_->socket()->NetLog()) {}
+ handshake_succeeded_(false),
+ marked_session_as_good_(false),
+ transport_security_state_(context.transport_security_state),
+ net_log_(transport_->socket()->NetLog()),
+ weak_factory_(this) {
+}
SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() {
Disconnect();
}
+std::string SSLClientSocketOpenSSL::GetSessionCacheKey() const {
+ std::string result = host_and_port_.ToString();
+ result.append("/");
+ result.append(ssl_session_cache_shard_);
+ return result;
+}
+
+bool SSLClientSocketOpenSSL::InSessionCache() const {
+ SSLContext* context = SSLContext::GetInstance();
+ std::string cache_key = GetSessionCacheKey();
+ return context->session_cache()->SSLSessionIsInCache(cache_key);
+}
+
+void SSLClientSocketOpenSSL::SetHandshakeCompletionCallback(
+ const base::Closure& callback) {
+ handshake_completion_callback_ = callback;
+}
+
void SSLClientSocketOpenSSL::GetSSLCertRequestInfo(
SSLCertRequestInfo* cert_request_info) {
cert_request_info->host_and_port = host_and_port_;
@@ -374,15 +396,14 @@ void SSLClientSocketOpenSSL::GetSSLCertRequestInfo(
}
SSLClientSocket::NextProtoStatus SSLClientSocketOpenSSL::GetNextProto(
- std::string* proto, std::string* server_protos) {
+ std::string* proto) {
*proto = npn_proto_;
- *server_protos = server_protos_;
return npn_status_;
}
-ServerBoundCertService*
-SSLClientSocketOpenSSL::GetServerBoundCertService() const {
- return server_bound_cert_service_;
+ChannelIDService*
+SSLClientSocketOpenSSL::GetChannelIDService() const {
+ return channel_id_service_;
}
int SSLClientSocketOpenSSL::ExportKeyingMaterial(
@@ -412,6 +433,10 @@ int SSLClientSocketOpenSSL::GetTLSUniqueChannelBinding(std::string* out) {
}
int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) {
+ // It is an error to create an SSLClientSocket whose context has no
+ // TransportSecurityState.
+ DCHECK(transport_security_state_);
+
net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT);
// Set up new ssl object.
@@ -430,12 +455,18 @@ int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) {
user_connect_callback_ = callback;
} else {
net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv);
+ if (rv < OK)
+ OnHandshakeCompletion();
}
return rv > OK ? OK : rv;
}
void SSLClientSocketOpenSSL::Disconnect() {
+ // If a handshake was pending (Connect() had been called), notify interested
+ // parties that it's been aborted now. If the handshake had already
+ // completed, this is a no-op.
+ OnHandshakeCompletion();
if (ssl_) {
// Calling SSL_shutdown prevents the session from being marked as
// unresumable.
@@ -456,7 +487,6 @@ void SSLClientSocketOpenSSL::Disconnect() {
transport_send_busy_ = false;
send_buffer_ = NULL;
transport_recv_busy_ = false;
- transport_recv_eof_ = false;
recv_buffer_ = NULL;
user_connect_callback_.Reset();
@@ -468,19 +498,31 @@ void SSLClientSocketOpenSSL::Disconnect() {
user_write_buf_len_ = 0;
pending_read_error_ = kNoPendingReadResult;
+ pending_read_ssl_error_ = SSL_ERROR_NONE;
+ pending_read_error_info_ = OpenSSLErrorInfo();
+
+ transport_read_error_ = OK;
transport_write_error_ = OK;
server_cert_verify_result_.Reset();
- completed_handshake_ = false;
+ completed_connect_ = false;
cert_authorities_.clear();
cert_key_types_.clear();
client_auth_cert_needed_ = false;
+
+ start_cert_verification_time_ = base::TimeTicks();
+
+ npn_status_ = kNextProtoUnsupported;
+ npn_proto_.clear();
+
+ channel_id_xtn_negotiated_ = false;
+ channel_id_request_handle_.Cancel();
}
bool SSLClientSocketOpenSSL::IsConnected() const {
// If the handshake has not yet completed.
- if (!completed_handshake_)
+ if (!completed_connect_)
return false;
// If an asynchronous operation is still pending.
if (user_read_buf_.get() || user_write_buf_.get())
@@ -491,15 +533,15 @@ bool SSLClientSocketOpenSSL::IsConnected() const {
bool SSLClientSocketOpenSSL::IsConnectedAndIdle() const {
// If the handshake has not yet completed.
- if (!completed_handshake_)
+ if (!completed_connect_)
return false;
// If an asynchronous operation is still pending.
if (user_read_buf_.get() || user_write_buf_.get())
return false;
// If there is data waiting to be sent, or data read from the network that
// has not yet been consumed.
- if (BIO_ctrl_pending(transport_bio_) > 0 ||
- BIO_ctrl_wpending(transport_bio_) > 0) {
+ if (BIO_pending(transport_bio_) > 0 ||
+ BIO_wpending(transport_bio_) > 0) {
return false;
}
@@ -548,7 +590,7 @@ bool SSLClientSocketOpenSSL::UsingTCPFastOpen() const {
bool SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) {
ssl_info->Reset();
- if (!server_cert_.get())
+ if (server_cert_chain_->empty())
return false;
ssl_info->cert = server_cert_verify_result_.verified_cert;
@@ -560,27 +602,20 @@ bool SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) {
ssl_info->client_cert_sent =
ssl_config_.send_client_cert && ssl_config_.client_cert.get();
ssl_info->channel_id_sent = WasChannelIDSent();
+ ssl_info->pinning_failure_log = pinning_failure_log_;
- RecordChannelIDSupport(server_bound_cert_service_,
- channel_id_xtn_negotiated_,
- ssl_config_.channel_id_enabled,
- crypto::ECPrivateKey::IsSupported());
+ AddSCTInfoToSSLInfo(ssl_info);
const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl_);
CHECK(cipher);
ssl_info->security_bits = SSL_CIPHER_get_bits(cipher, NULL);
- const COMP_METHOD* compression = SSL_get_current_compression(ssl_);
ssl_info->connection_status = EncodeSSLConnectionStatus(
- SSL_CIPHER_get_id(cipher),
- compression ? compression->type : 0,
+ SSL_CIPHER_get_id(cipher), 0 /* no compression */,
GetNetSSLVersion(ssl_));
- bool peer_supports_renego_ext = !!SSL_get_secure_renegotiation_support(ssl_);
- if (!peer_supports_renego_ext)
+ if (!SSL_get_secure_renegotiation_support(ssl_))
ssl_info->connection_status |= SSL_CONNECTION_NO_RENEGOTIATION_EXTENSION;
- UMA_HISTOGRAM_ENUMERATION("Net.RenegotiationExtensionSupported",
- implicit_cast<int>(peer_supports_renego_ext), 2);
if (ssl_config_.version_fallback)
ssl_info->connection_status |= SSL_CONNECTION_VERSION_FALLBACK;
@@ -601,7 +636,7 @@ int SSLClientSocketOpenSSL::Read(IOBuffer* buf,
user_read_buf_ = buf;
user_read_buf_len_ = buf_len;
- int rv = DoReadLoop(OK);
+ int rv = DoReadLoop();
if (rv == ERR_IO_PENDING) {
user_read_callback_ = callback;
@@ -610,6 +645,11 @@ int SSLClientSocketOpenSSL::Read(IOBuffer* buf,
was_ever_used_ = true;
user_read_buf_ = NULL;
user_read_buf_len_ = 0;
+ if (rv <= 0) {
+ // Failure of a read attempt may indicate a failed false start
+ // connection.
+ OnHandshakeCompletion();
+ }
}
return rv;
@@ -621,7 +661,7 @@ int SSLClientSocketOpenSSL::Write(IOBuffer* buf,
user_write_buf_ = buf;
user_write_buf_len_ = buf_len;
- int rv = DoWriteLoop(OK);
+ int rv = DoWriteLoop();
if (rv == ERR_IO_PENDING) {
user_write_callback_ = callback;
@@ -630,6 +670,11 @@ int SSLClientSocketOpenSSL::Write(IOBuffer* buf,
was_ever_used_ = true;
user_write_buf_ = NULL;
user_write_buf_len_ = 0;
+ if (rv < 0) {
+ // Failure of a write attempt may indicate a failed false start
+ // connection.
+ OnHandshakeCompletion();
+ }
}
return rv;
@@ -657,8 +702,11 @@ int SSLClientSocketOpenSSL::Init() {
if (!SSL_set_tlsext_host_name(ssl_, host_and_port_.host().c_str()))
return ERR_UNEXPECTED;
+ // Set an OpenSSL callback to monitor this SSL*'s connection.
+ SSL_set_info_callback(ssl_, &InfoCallback);
+
trying_cached_session_ = context->session_cache()->SetSSLSessionWithKey(
- ssl_, GetSocketSessionCacheKey(*this));
+ ssl_, GetSessionCacheKey());
BIO* ssl_bio = NULL;
// 0 => use default buffer sizes.
@@ -667,6 +715,10 @@ int SSLClientSocketOpenSSL::Init() {
DCHECK(ssl_bio);
DCHECK(transport_bio_);
+ // Install a callback on OpenSSL's end to plumb transport errors through.
+ BIO_set_callback(ssl_bio, BIOCallback);
+ BIO_set_callback_arg(ssl_bio, reinterpret_cast<char*>(this));
+
SSL_set_bio(ssl_, ssl_bio, ssl_bio);
// OpenSSL defaults some options to on, others to off. To avoid ambiguity,
@@ -699,6 +751,7 @@ int SSLClientSocketOpenSSL::Init() {
SslSetClearMask mode;
mode.ConfigureFlag(SSL_MODE_RELEASE_BUFFERS, true);
+ mode.ConfigureFlag(SSL_MODE_CBC_RECORD_SPLITTING, true);
mode.ConfigureFlag(SSL_MODE_HANDSHAKE_CUTTHROUGH,
ssl_config_.false_start_enabled);
@@ -719,7 +772,7 @@ int SSLClientSocketOpenSSL::Init() {
"!aECDH:!AESGCM+AES256");
// Walk through all the installed ciphers, seeing if any need to be
// appended to the cipher removal |command|.
- for (int i = 0; i < sk_SSL_CIPHER_num(ciphers); ++i) {
+ for (size_t i = 0; i < sk_SSL_CIPHER_num(ciphers); ++i) {
const SSL_CIPHER* cipher = sk_SSL_CIPHER_value(ciphers, i);
const uint16 id = SSL_CIPHER_get_id(cipher);
// Remove any ciphers with a strength of less than 80 bits. Note the NSS
@@ -740,6 +793,15 @@ int SSLClientSocketOpenSSL::Init() {
command.append(name);
}
}
+
+ // Disable ECDSA cipher suites on platforms that do not support ECDSA
+ // signed certificates, as servers may use the presence of such
+ // ciphersuites as a hint to send an ECDSA certificate.
+#if defined(OS_WIN)
+ if (base::win::GetVersion() < base::win::VERSION_VISTA)
+ command.append(":!ECDSA");
+#endif
+
int rv = SSL_set_cipher_list(ssl_, command.c_str());
// If this fails (rv = 0) it means there are no ciphers enabled on this SSL.
// This will almost certainly result in the socket failing to complete the
@@ -747,11 +809,29 @@ int SSLClientSocketOpenSSL::Init() {
LOG_IF(WARNING, rv != 1) << "SSL_set_cipher_list('" << command << "') "
"returned " << rv;
+ if (ssl_config_.version_fallback)
+ SSL_enable_fallback_scsv(ssl_);
+
// TLS channel ids.
- if (IsChannelIDEnabled(ssl_config_, server_bound_cert_service_)) {
+ if (IsChannelIDEnabled(ssl_config_, channel_id_service_)) {
SSL_enable_tls_channel_id(ssl_);
}
+ if (!ssl_config_.next_protos.empty()) {
+ std::vector<uint8_t> wire_protos =
+ SerializeNextProtos(ssl_config_.next_protos);
+ SSL_set_alpn_protos(ssl_, wire_protos.empty() ? NULL : &wire_protos[0],
+ wire_protos.size());
+ }
+
+ if (ssl_config_.signed_cert_timestamps_enabled) {
+ SSL_enable_signed_cert_timestamps(ssl_);
+ SSL_enable_ocsp_stapling(ssl_);
+ }
+
+ // TODO(davidben): Enable OCSP stapling on platforms which support it and pass
+ // into the certificate verifier. https://crbug.com/398677
+
return OK;
}
@@ -762,6 +842,11 @@ void SSLClientSocketOpenSSL::DoReadCallback(int rv) {
was_ever_used_ = true;
user_read_buf_ = NULL;
user_read_buf_len_ = 0;
+ if (rv <= 0) {
+ // Failure of a read attempt may indicate a failed false start
+ // connection.
+ OnHandshakeCompletion();
+ }
base::ResetAndReturn(&user_read_callback_).Run(rv);
}
@@ -772,9 +857,19 @@ void SSLClientSocketOpenSSL::DoWriteCallback(int rv) {
was_ever_used_ = true;
user_write_buf_ = NULL;
user_write_buf_len_ = 0;
+ if (rv < 0) {
+ // Failure of a write attempt may indicate a failed false start
+ // connection.
+ OnHandshakeCompletion();
+ }
base::ResetAndReturn(&user_write_callback_).Run(rv);
}
+void SSLClientSocketOpenSSL::OnHandshakeCompletion() {
+ if (!handshake_completion_callback_.is_null())
+ base::ResetAndReturn(&handshake_completion_callback_).Run();
+}
+
bool SSLClientSocketOpenSSL::DoTransportIO() {
bool network_moved = false;
int rv;
@@ -785,7 +880,7 @@ bool SSLClientSocketOpenSSL::DoTransportIO() {
if (rv != ERR_IO_PENDING && rv != 0)
network_moved = true;
} while (rv > 0);
- if (!transport_recv_eof_ && BufferRecv() != ERR_IO_PENDING)
+ if (transport_read_error_ == OK && BufferRecv() != ERR_IO_PENDING)
network_moved = true;
return network_moved;
}
@@ -815,25 +910,56 @@ int SSLClientSocketOpenSSL::DoHandshake() {
DVLOG(2) << "Result of session reuse for " << host_and_port_.ToString()
<< " is: " << (SSL_session_reused(ssl_) ? "Success" : "Fail");
}
- // SSL handshake is completed. Let's verify the certificate.
- const bool got_cert = !!UpdateServerCert();
- DCHECK(got_cert);
- net_log_.AddEvent(
- NetLog::TYPE_SSL_CERTIFICATES_RECEIVED,
- base::Bind(&NetLogX509CertificateCallback,
- base::Unretained(server_cert_.get())));
+
+ if (ssl_config_.version_fallback &&
+ ssl_config_.version_max < ssl_config_.version_fallback_min) {
+ return ERR_SSL_FALLBACK_BEYOND_MINIMUM_VERSION;
+ }
+
+ // SSL handshake is completed. If NPN wasn't negotiated, see if ALPN was.
+ if (npn_status_ == kNextProtoUnsupported) {
+ const uint8_t* alpn_proto = NULL;
+ unsigned alpn_len = 0;
+ SSL_get0_alpn_selected(ssl_, &alpn_proto, &alpn_len);
+ if (alpn_len > 0) {
+ npn_proto_.assign(reinterpret_cast<const char*>(alpn_proto), alpn_len);
+ npn_status_ = kNextProtoNegotiated;
+ set_negotiation_extension(kExtensionALPN);
+ }
+ }
+
+ RecordChannelIDSupport(channel_id_service_,
+ channel_id_xtn_negotiated_,
+ ssl_config_.channel_id_enabled,
+ crypto::ECPrivateKey::IsSupported());
+
+ uint8_t* ocsp_response;
+ size_t ocsp_response_len;
+ SSL_get0_ocsp_response(ssl_, &ocsp_response, &ocsp_response_len);
+ set_stapled_ocsp_response_received(ocsp_response_len != 0);
+
+ uint8_t* sct_list;
+ size_t sct_list_len;
+ SSL_get0_signed_cert_timestamp_list(ssl_, &sct_list, &sct_list_len);
+ set_signed_cert_timestamps_received(sct_list_len != 0);
+
+ // Verify the certificate.
+ UpdateServerCert();
GotoState(STATE_VERIFY_CERT);
} else {
int ssl_error = SSL_get_error(ssl_, rv);
if (ssl_error == SSL_ERROR_WANT_CHANNEL_ID_LOOKUP) {
- // The server supports TLS channel id and the lookup is asynchronous.
- // Retrieve the error from the call to |server_bound_cert_service_|.
- net_error = channel_id_request_return_value_;
- } else {
- net_error = MapOpenSSLError(ssl_error, err_tracer);
+ // The server supports channel ID. Stop to look one up before returning to
+ // the handshake.
+ channel_id_xtn_negotiated_ = true;
+ GotoState(STATE_CHANNEL_ID_LOOKUP);
+ return OK;
}
+ OpenSSLErrorInfo error_info;
+ net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer, &error_info);
+
// If not done, stay in this state
if (net_error == ERR_IO_PENDING) {
GotoState(STATE_HANDSHAKE);
@@ -843,18 +969,78 @@ int SSLClientSocketOpenSSL::DoHandshake() {
<< ", net_error " << net_error;
net_log_.AddEvent(
NetLog::TYPE_SSL_HANDSHAKE_ERROR,
- CreateNetLogSSLErrorCallback(net_error, ssl_error));
+ CreateNetLogOpenSSLErrorCallback(net_error, ssl_error, error_info));
}
}
return net_error;
}
+int SSLClientSocketOpenSSL::DoChannelIDLookup() {
+ GotoState(STATE_CHANNEL_ID_LOOKUP_COMPLETE);
+ return channel_id_service_->GetOrCreateChannelID(
+ host_and_port_.host(),
+ &channel_id_private_key_,
+ &channel_id_cert_,
+ base::Bind(&SSLClientSocketOpenSSL::OnHandshakeIOComplete,
+ base::Unretained(this)),
+ &channel_id_request_handle_);
+}
+
+int SSLClientSocketOpenSSL::DoChannelIDLookupComplete(int result) {
+ if (result < 0)
+ return result;
+
+ DCHECK_LT(0u, channel_id_private_key_.size());
+ // Decode key.
+ std::vector<uint8> encrypted_private_key_info;
+ std::vector<uint8> subject_public_key_info;
+ encrypted_private_key_info.assign(
+ channel_id_private_key_.data(),
+ channel_id_private_key_.data() + channel_id_private_key_.size());
+ subject_public_key_info.assign(
+ channel_id_cert_.data(),
+ channel_id_cert_.data() + channel_id_cert_.size());
+ scoped_ptr<crypto::ECPrivateKey> ec_private_key(
+ crypto::ECPrivateKey::CreateFromEncryptedPrivateKeyInfo(
+ ChannelIDService::kEPKIPassword,
+ encrypted_private_key_info,
+ subject_public_key_info));
+ if (!ec_private_key) {
+ LOG(ERROR) << "Failed to import Channel ID.";
+ return ERR_CHANNEL_ID_IMPORT_FAILED;
+ }
+
+ // Hand the key to OpenSSL. Check for error in case OpenSSL rejects the key
+ // type.
+ crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
+ int rv = SSL_set1_tls_channel_id(ssl_, ec_private_key->key());
+ if (!rv) {
+ LOG(ERROR) << "Failed to set Channel ID.";
+ int err = SSL_get_error(ssl_, rv);
+ return MapOpenSSLError(err, err_tracer);
+ }
+
+ // Return to the handshake.
+ set_channel_id_sent(true);
+ GotoState(STATE_HANDSHAKE);
+ return OK;
+}
+
int SSLClientSocketOpenSSL::DoVerifyCert(int result) {
- DCHECK(server_cert_.get());
+ DCHECK(!server_cert_chain_->empty());
+ DCHECK(start_cert_verification_time_.is_null());
+
GotoState(STATE_VERIFY_CERT_COMPLETE);
+ // If the certificate is bad and has been previously accepted, use
+ // the previous status and bypass the error.
+ base::StringPiece der_cert;
+ if (!x509_util::GetDER(server_cert_chain_->Get(0), &der_cert)) {
+ NOTREACHED();
+ return ERR_CERT_INVALID;
+ }
CertStatus cert_status;
- if (ssl_config_.IsAllowedBadCert(server_cert_.get(), &cert_status)) {
+ if (ssl_config_.IsAllowedBadCert(der_cert, &cert_status)) {
VLOG(1) << "Received an expected bad cert with status: " << cert_status;
server_cert_verify_result_.Reset();
server_cert_verify_result_.cert_status = cert_status;
@@ -862,6 +1048,17 @@ int SSLClientSocketOpenSSL::DoVerifyCert(int result) {
return OK;
}
+ // When running in a sandbox, it may not be possible to create an
+ // X509Certificate*, as that may depend on OS functionality blocked
+ // in the sandbox.
+ if (!server_cert_.get()) {
+ server_cert_verify_result_.Reset();
+ server_cert_verify_result_.cert_status = CERT_STATUS_INVALID;
+ return ERR_CERT_INVALID;
+ }
+
+ start_cert_verification_time_ = base::TimeTicks::Now();
+
int flags = 0;
if (ssl_config_.rev_checking_enabled)
flags |= CertVerifier::VERIFY_REV_CHECKING_ENABLED;
@@ -876,7 +1073,9 @@ int SSLClientSocketOpenSSL::DoVerifyCert(int result) {
server_cert_.get(),
host_and_port_.host(),
flags,
- NULL /* no CRL set */,
+ // TODO(davidben): Route the CRLSet through SSLConfig so
+ // SSLClientSocket doesn't depend on SSLConfigService.
+ SSLConfigService::GetCRLSet().get(),
&server_cert_verify_result_,
base::Bind(&SSLClientSocketOpenSSL::OnHandshakeIOComplete,
base::Unretained(this)),
@@ -886,22 +1085,71 @@ int SSLClientSocketOpenSSL::DoVerifyCert(int result) {
int SSLClientSocketOpenSSL::DoVerifyCertComplete(int result) {
verifier_.reset();
+ if (!start_cert_verification_time_.is_null()) {
+ base::TimeDelta verify_time =
+ base::TimeTicks::Now() - start_cert_verification_time_;
+ if (result == OK) {
+ UMA_HISTOGRAM_TIMES("Net.SSLCertVerificationTime", verify_time);
+ } else {
+ UMA_HISTOGRAM_TIMES("Net.SSLCertVerificationTimeError", verify_time);
+ }
+ }
+
+ if (result == OK)
+ RecordConnectionTypeMetrics(GetNetSSLVersion(ssl_));
+
+ const CertStatus cert_status = server_cert_verify_result_.cert_status;
+ if (transport_security_state_ &&
+ (result == OK ||
+ (IsCertificateError(result) && IsCertStatusMinorError(cert_status))) &&
+ !transport_security_state_->CheckPublicKeyPins(
+ host_and_port_.host(),
+ server_cert_verify_result_.is_issued_by_known_root,
+ server_cert_verify_result_.public_key_hashes,
+ &pinning_failure_log_)) {
+ result = ERR_SSL_PINNED_KEY_NOT_IN_CERT_CHAIN;
+ }
+
+ scoped_refptr<ct::EVCertsWhitelist> ev_whitelist =
+ SSLConfigService::GetEVCertsWhitelist();
+ if (server_cert_verify_result_.cert_status & CERT_STATUS_IS_EV) {
+ if (ev_whitelist.get() && ev_whitelist->IsValid()) {
+ const SHA256HashValue fingerprint(
+ X509Certificate::CalculateFingerprint256(
+ server_cert_verify_result_.verified_cert->os_cert_handle()));
+
+ UMA_HISTOGRAM_BOOLEAN(
+ "Net.SSL_EVCertificateInWhitelist",
+ ev_whitelist->ContainsCertificateHash(
+ std::string(reinterpret_cast<const char*>(fingerprint.data), 8)));
+ }
+ }
+
if (result == OK) {
+ // Only check Certificate Transparency if there were no other errors with
+ // the connection.
+ VerifyCT();
+
// TODO(joth): Work out if we need to remember the intermediate CA certs
// when the server sends them to us, and do so here.
SSLContext::GetInstance()->session_cache()->MarkSSLSessionAsGood(ssl_);
+ marked_session_as_good_ = true;
+ CheckIfHandshakeFinished();
} else {
DVLOG(1) << "DoVerifyCertComplete error " << ErrorToString(result)
<< " (" << result << ")";
}
- completed_handshake_ = true;
+ completed_connect_ = true;
+
// Exit DoHandshakeLoop and return the result to the caller to Connect.
DCHECK_EQ(STATE_NONE, next_handshake_state_);
return result;
}
void SSLClientSocketOpenSSL::DoConnectCallback(int rv) {
+ if (rv < OK)
+ OnHandshakeCompletion();
if (!user_connect_callback_.is_null()) {
CompletionCallback c = user_connect_callback_;
user_connect_callback_.Reset();
@@ -909,14 +1157,50 @@ void SSLClientSocketOpenSSL::DoConnectCallback(int rv) {
}
}
-X509Certificate* SSLClientSocketOpenSSL::UpdateServerCert() {
+void SSLClientSocketOpenSSL::UpdateServerCert() {
server_cert_chain_->Reset(SSL_get_peer_cert_chain(ssl_));
server_cert_ = server_cert_chain_->AsOSChain();
- if (!server_cert_chain_->IsValid())
- DVLOG(1) << "UpdateServerCert received invalid certificate chain from peer";
+ if (server_cert_.get()) {
+ net_log_.AddEvent(
+ NetLog::TYPE_SSL_CERTIFICATES_RECEIVED,
+ base::Bind(&NetLogX509CertificateCallback,
+ base::Unretained(server_cert_.get())));
+ }
+}
+
+void SSLClientSocketOpenSSL::VerifyCT() {
+ if (!cert_transparency_verifier_)
+ return;
+
+ uint8_t* ocsp_response_raw;
+ size_t ocsp_response_len;
+ SSL_get0_ocsp_response(ssl_, &ocsp_response_raw, &ocsp_response_len);
+ std::string ocsp_response;
+ if (ocsp_response_len > 0) {
+ ocsp_response.assign(reinterpret_cast<const char*>(ocsp_response_raw),
+ ocsp_response_len);
+ }
- return server_cert_.get();
+ uint8_t* sct_list_raw;
+ size_t sct_list_len;
+ SSL_get0_signed_cert_timestamp_list(ssl_, &sct_list_raw, &sct_list_len);
+ std::string sct_list;
+ if (sct_list_len > 0)
+ sct_list.assign(reinterpret_cast<const char*>(sct_list_raw), sct_list_len);
+
+ // Note that this is a completely synchronous operation: The CT Log Verifier
+ // gets all the data it needs for SCT verification and does not do any
+ // external communication.
+ int result = cert_transparency_verifier_->Verify(
+ server_cert_verify_result_.verified_cert.get(),
+ ocsp_response, sct_list, &ct_verify_result_, net_log_);
+
+ VLOG(1) << "CT Verification complete: result " << result
+ << " Invalid scts: " << ct_verify_result_.invalid_scts.size()
+ << " Verified scts: " << ct_verify_result_.verified_scts.size()
+ << " scts from unknown logs: "
+ << ct_verify_result_.unknown_logs_scts.size();
}
void SSLClientSocketOpenSSL::OnHandshakeIOComplete(int result) {
@@ -974,7 +1258,7 @@ void SSLClientSocketOpenSSL::OnRecvComplete(int result) {
if (!user_read_buf_.get())
return;
- int rv = DoReadLoop(result);
+ int rv = DoReadLoop();
if (rv != ERR_IO_PENDING)
DoReadCallback(rv);
}
@@ -993,8 +1277,15 @@ int SSLClientSocketOpenSSL::DoHandshakeLoop(int last_io_result) {
case STATE_HANDSHAKE:
rv = DoHandshake();
break;
+ case STATE_CHANNEL_ID_LOOKUP:
+ DCHECK_EQ(OK, rv);
+ rv = DoChannelIDLookup();
+ break;
+ case STATE_CHANNEL_ID_LOOKUP_COMPLETE:
+ rv = DoChannelIDLookupComplete(rv);
+ break;
case STATE_VERIFY_CERT:
- DCHECK(rv == OK);
+ DCHECK_EQ(OK, rv);
rv = DoVerifyCert(rv);
break;
case STATE_VERIFY_CERT_COMPLETE:
@@ -1015,13 +1306,11 @@ int SSLClientSocketOpenSSL::DoHandshakeLoop(int last_io_result) {
rv = OK; // This causes us to stay in the loop.
}
} while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE);
+
return rv;
}
-int SSLClientSocketOpenSSL::DoReadLoop(int result) {
- if (result < 0)
- return result;
-
+int SSLClientSocketOpenSSL::DoReadLoop() {
bool network_moved;
int rv;
do {
@@ -1032,10 +1321,7 @@ int SSLClientSocketOpenSSL::DoReadLoop(int result) {
return rv;
}
-int SSLClientSocketOpenSSL::DoWriteLoop(int result) {
- if (result < 0)
- return result;
-
+int SSLClientSocketOpenSSL::DoWriteLoop() {
bool network_moved;
int rv;
do {
@@ -1056,7 +1342,14 @@ int SSLClientSocketOpenSSL::DoPayloadRead() {
if (rv == 0) {
net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED,
rv, user_read_buf_->data());
+ } else {
+ net_log_.AddEvent(
+ NetLog::TYPE_SSL_READ_ERROR,
+ CreateNetLogOpenSSLErrorCallback(rv, pending_read_ssl_error_,
+ pending_read_error_info_));
}
+ pending_read_ssl_error_ = SSL_ERROR_NONE;
+ pending_read_error_info_ = OpenSSLErrorInfo();
return rv;
}
@@ -1091,8 +1384,19 @@ int SSLClientSocketOpenSSL::DoPayloadRead() {
if (client_auth_cert_needed_) {
*next_result = ERR_SSL_CLIENT_AUTH_CERT_NEEDED;
} else if (*next_result < 0) {
- int err = SSL_get_error(ssl_, *next_result);
- *next_result = MapOpenSSLError(err, err_tracer);
+ pending_read_ssl_error_ = SSL_get_error(ssl_, *next_result);
+ *next_result = MapOpenSSLErrorWithDetails(pending_read_ssl_error_,
+ err_tracer,
+ &pending_read_error_info_);
+
+ // Many servers do not reliably send a close_notify alert when shutting
+ // down a connection, and instead terminate the TCP connection. This is
+ // reported as ERR_CONNECTION_CLOSED. Because of this, map the unclean
+ // shutdown to a graceful EOF, instead of treating it as an error as it
+ // should be.
+ if (*next_result == ERR_CONNECTION_CLOSED)
+ *next_result = 0;
+
if (rv > 0 && *next_result == ERR_IO_PENDING) {
// If at least some data was read from SSL_read(), do not treat
// insufficient data as an error to return in the next call to
@@ -1109,6 +1413,13 @@ int SSLClientSocketOpenSSL::DoPayloadRead() {
if (rv >= 0) {
net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, rv,
user_read_buf_->data());
+ } else if (rv != ERR_IO_PENDING) {
+ net_log_.AddEvent(
+ NetLog::TYPE_SSL_READ_ERROR,
+ CreateNetLogOpenSSLErrorCallback(rv, pending_read_ssl_error_,
+ pending_read_error_info_));
+ pending_read_ssl_error_ = SSL_ERROR_NONE;
+ pending_read_error_info_ = OpenSSLErrorInfo();
}
return rv;
}
@@ -1116,15 +1427,23 @@ int SSLClientSocketOpenSSL::DoPayloadRead() {
int SSLClientSocketOpenSSL::DoPayloadWrite() {
crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
int rv = SSL_write(ssl_, user_write_buf_->data(), user_write_buf_len_);
-
if (rv >= 0) {
net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_SENT, rv,
user_write_buf_->data());
return rv;
}
- int err = SSL_get_error(ssl_, rv);
- return MapOpenSSLError(err, err_tracer);
+ int ssl_error = SSL_get_error(ssl_, rv);
+ OpenSSLErrorInfo error_info;
+ int net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer,
+ &error_info);
+
+ if (net_error != ERR_IO_PENDING) {
+ net_log_.AddEvent(
+ NetLog::TYPE_SSL_WRITE_ERROR,
+ CreateNetLogOpenSSLErrorCallback(net_error, ssl_error, error_info));
+ }
+ return net_error;
}
int SSLClientSocketOpenSSL::BufferSend(void) {
@@ -1133,7 +1452,7 @@ int SSLClientSocketOpenSSL::BufferSend(void) {
if (!send_buffer_.get()) {
// Get a fresh send buffer out of the send BIO.
- size_t max_read = BIO_ctrl_pending(transport_bio_);
+ size_t max_read = BIO_pending(transport_bio_);
if (!max_read)
return 0; // Nothing pending in the OpenSSL write BIO.
send_buffer_ = new DrainableIOBuffer(new IOBuffer(max_read), max_read);
@@ -1209,20 +1528,9 @@ void SSLClientSocketOpenSSL::BufferRecvComplete(int result) {
void SSLClientSocketOpenSSL::TransportWriteComplete(int result) {
DCHECK(ERR_IO_PENDING != result);
if (result < 0) {
- // Got a socket write error; close the BIO to indicate this upward.
- //
- // TODO(davidben): The value of |result| gets lost. Feed the error back into
- // the BIO so it gets (re-)detected in OnSendComplete. Perhaps with
- // BIO_set_callback.
- DVLOG(1) << "TransportWriteComplete error " << result;
- (void)BIO_shutdown_wr(SSL_get_wbio(ssl_));
-
- // Match the fix for http://crbug.com/249848 in NSS by erroring future reads
- // from the socket after a write error.
- //
- // TODO(davidben): Avoid having read and write ends interact this way.
+ // Record the error. Save it to be reported in a future read or write on
+ // transport_bio_'s peer.
transport_write_error_ = result;
- (void)BIO_shutdown_wr(transport_bio_);
send_buffer_ = NULL;
} else {
DCHECK(send_buffer_.get());
@@ -1235,19 +1543,15 @@ void SSLClientSocketOpenSSL::TransportWriteComplete(int result) {
int SSLClientSocketOpenSSL::TransportReadComplete(int result) {
DCHECK(ERR_IO_PENDING != result);
- if (result <= 0) {
+ // If an EOF, canonicalize to ERR_CONNECTION_CLOSED here so MapOpenSSLError
+ // does not report success.
+ if (result == 0)
+ result = ERR_CONNECTION_CLOSED;
+ if (result < 0) {
DVLOG(1) << "TransportReadComplete result " << result;
- // Received 0 (end of file) or an error. Either way, bubble it up to the
- // SSL layer via the BIO. TODO(joth): consider stashing the error code, to
- // relay up to the SSL socket client (i.e. via DoReadCallback).
- if (result == 0)
- transport_recv_eof_ = true;
- (void)BIO_shutdown_wr(transport_bio_);
- } else if (transport_write_error_ < 0) {
- // Mirror transport write errors as read failures; transport_bio_ has been
- // shut down by TransportWriteComplete, so the BIO_write will fail, failing
- // the CHECK. http://crbug.com/335557.
- result = transport_write_error_;
+ // Received an error. Save it to be reported in a future read on
+ // transport_bio_'s peer.
+ transport_read_error_ = result;
} else {
DCHECK(recv_buffer_.get());
int ret = BIO_write(transport_bio_, recv_buffer_->data(), result);
@@ -1259,19 +1563,23 @@ int SSLClientSocketOpenSSL::TransportReadComplete(int result) {
return result;
}
-int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl,
- X509** x509,
- EVP_PKEY** pkey) {
+int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl) {
DVLOG(3) << "OpenSSL ClientCertRequestCallback called";
DCHECK(ssl == ssl_);
- DCHECK(*x509 == NULL);
- DCHECK(*pkey == NULL);
+
+ // Clear any currently configured certificates.
+ SSL_certs_clear(ssl_);
+
+#if defined(OS_IOS)
+ // TODO(droger): Support client auth on iOS. See http://crbug.com/145954).
+ LOG(WARNING) << "Client auth is not supported";
+#else // !defined(OS_IOS)
if (!ssl_config_.send_client_cert) {
// First pass: we know that a client certificate is needed, but we do not
// have one at hand.
client_auth_cert_needed_ = true;
STACK_OF(X509_NAME) *authorities = SSL_get_client_CA_list(ssl);
- for (int i = 0; i < sk_X509_NAME_num(authorities); i++) {
+ for (size_t i = 0; i < sk_X509_NAME_num(authorities); i++) {
X509_NAME *ca_name = (X509_NAME *)sk_X509_NAME_value(authorities, i);
unsigned char* str = NULL;
int length = i2d_X509_NAME(ca_name, &str);
@@ -1282,9 +1590,8 @@ int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl,
}
const unsigned char* client_cert_types;
- size_t num_client_cert_types;
- SSL_get_client_certificate_types(ssl, &client_cert_types,
- &num_client_cert_types);
+ size_t num_client_cert_types =
+ SSL_get0_certificate_types(ssl, &client_cert_types);
for (size_t i = 0; i < num_client_cert_types; i++) {
cert_key_types_.push_back(
static_cast<SSLClientCertType>(client_cert_types[i]));
@@ -1295,90 +1602,81 @@ int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl,
// Second pass: a client certificate should have been selected.
if (ssl_config_.client_cert.get()) {
-#if defined(USE_OPENSSL_CERTS)
- // A note about ownership: FetchClientCertPrivateKey() increments
- // the reference count of the EVP_PKEY. Ownership of this reference
- // is passed directly to OpenSSL, which will release the reference
- // using EVP_PKEY_free() when the SSL object is destroyed.
- OpenSSLClientKeyStore::ScopedEVP_PKEY privkey;
- if (OpenSSLClientKeyStore::GetInstance()->FetchClientCertPrivateKey(
- ssl_config_.client_cert.get(), &privkey)) {
- // TODO(joth): (copied from NSS) We should wait for server certificate
- // verification before sending our credentials. See http://crbug.com/13934
- *x509 = X509Certificate::DupOSCertHandle(
- ssl_config_.client_cert->os_cert_handle());
- *pkey = privkey.release();
- return 1;
+ ScopedX509 leaf_x509 =
+ OSCertHandleToOpenSSL(ssl_config_.client_cert->os_cert_handle());
+ if (!leaf_x509) {
+ LOG(WARNING) << "Failed to import certificate";
+ OpenSSLPutNetError(FROM_HERE, ERR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT);
+ return -1;
}
- LOG(WARNING) << "Client cert found without private key";
-#else // !defined(USE_OPENSSL_CERTS)
- // OS handling of client certificates is not yet implemented.
- NOTIMPLEMENTED();
-#endif // defined(USE_OPENSSL_CERTS)
- }
- // Send no client certificate.
- return 0;
-}
+ ScopedX509Stack chain = OSCertHandlesToOpenSSL(
+ ssl_config_.client_cert->GetIntermediateCertificates());
+ if (!chain) {
+ LOG(WARNING) << "Failed to import intermediate certificates";
+ OpenSSLPutNetError(FROM_HERE, ERR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT);
+ return -1;
+ }
-void SSLClientSocketOpenSSL::ChannelIDRequestCallback(SSL* ssl,
- EVP_PKEY** pkey) {
- DVLOG(3) << "OpenSSL ChannelIDRequestCallback called";
- DCHECK_EQ(ssl, ssl_);
- DCHECK(!*pkey);
+ // TODO(davidben): With Linux client auth support, this should be
+ // conditioned on OS_ANDROID and then, with https://crbug.com/394131,
+ // removed altogether. OpenSSLClientKeyStore is mostly an artifact of the
+ // net/ client auth API lacking a private key handle.
+#if defined(USE_OPENSSL_CERTS)
+ crypto::ScopedEVP_PKEY privkey =
+ OpenSSLClientKeyStore::GetInstance()->FetchClientCertPrivateKey(
+ ssl_config_.client_cert.get());
+#else // !defined(USE_OPENSSL_CERTS)
+ crypto::ScopedEVP_PKEY privkey =
+ FetchClientCertPrivateKey(ssl_config_.client_cert.get());
+#endif // defined(USE_OPENSSL_CERTS)
+ if (!privkey) {
+ // Could not find the private key. Fail the handshake and surface an
+ // appropriate error to the caller.
+ LOG(WARNING) << "Client cert found without private key";
+ OpenSSLPutNetError(FROM_HERE, ERR_SSL_CLIENT_AUTH_CERT_NO_PRIVATE_KEY);
+ return -1;
+ }
- channel_id_xtn_negotiated_ = true;
- if (!channel_id_private_key_.size()) {
- channel_id_request_return_value_ =
- server_bound_cert_service_->GetOrCreateDomainBoundCert(
- host_and_port_.host(),
- &channel_id_private_key_,
- &channel_id_cert_,
- base::Bind(&SSLClientSocketOpenSSL::OnHandshakeIOComplete,
- base::Unretained(this)),
- &channel_id_request_handle_);
- if (channel_id_request_return_value_ != OK)
- return;
+ if (!SSL_use_certificate(ssl_, leaf_x509.get()) ||
+ !SSL_use_PrivateKey(ssl_, privkey.get()) ||
+ !SSL_set1_chain(ssl_, chain.get())) {
+ LOG(WARNING) << "Failed to set client certificate";
+ return -1;
+ }
+ return 1;
}
+#endif // defined(OS_IOS)
- // Decode key.
- std::vector<uint8> encrypted_private_key_info;
- std::vector<uint8> subject_public_key_info;
- encrypted_private_key_info.assign(
- channel_id_private_key_.data(),
- channel_id_private_key_.data() + channel_id_private_key_.size());
- subject_public_key_info.assign(
- channel_id_cert_.data(),
- channel_id_cert_.data() + channel_id_cert_.size());
- scoped_ptr<crypto::ECPrivateKey> ec_private_key(
- crypto::ECPrivateKey::CreateFromEncryptedPrivateKeyInfo(
- ServerBoundCertService::kEPKIPassword,
- encrypted_private_key_info,
- subject_public_key_info));
- if (!ec_private_key)
- return;
- set_channel_id_sent(true);
- *pkey = EVP_PKEY_dup(ec_private_key->key());
+ // Send no client certificate.
+ return 1;
}
int SSLClientSocketOpenSSL::CertVerifyCallback(X509_STORE_CTX* store_ctx) {
- if (!completed_handshake_) {
+ if (!completed_connect_) {
// If the first handshake hasn't completed then we accept any certificates
// because we verify after the handshake.
return 1;
}
- CHECK(server_cert_.get());
-
- PeerCertificateChain chain(store_ctx->untrusted);
- if (chain.IsValid() && server_cert_->Equals(chain.AsOSChain()))
- return 1;
-
- if (!chain.IsValid())
+ // Disallow the server certificate to change in a renegotiation.
+ if (server_cert_chain_->empty()) {
LOG(ERROR) << "Received invalid certificate chain between handshakes";
- else
+ return 0;
+ }
+ base::StringPiece old_der, new_der;
+ if (store_ctx->cert == NULL ||
+ !x509_util::GetDER(server_cert_chain_->Get(0), &old_der) ||
+ !x509_util::GetDER(store_ctx->cert, &new_der)) {
+ LOG(ERROR) << "Failed to encode certificates";
+ return 0;
+ }
+ if (old_der != new_der) {
LOG(ERROR) << "Server certificate changed between handshakes";
- return 0;
+ return 0;
+ }
+
+ return 1;
}
// SelectNextProtoCallback is called by OpenSSL during the handshake. If the
@@ -1426,11 +1724,104 @@ int SSLClientSocketOpenSSL::SelectNextProtoCallback(unsigned char** out,
}
npn_proto_.assign(reinterpret_cast<const char*>(*out), *outlen);
- server_protos_.assign(reinterpret_cast<const char*>(in), inlen);
DVLOG(2) << "next protocol: '" << npn_proto_ << "' status: " << npn_status_;
+ set_negotiation_extension(kExtensionNPN);
return SSL_TLSEXT_ERR_OK;
}
+long SSLClientSocketOpenSSL::MaybeReplayTransportError(
+ BIO *bio,
+ int cmd,
+ const char *argp, int argi, long argl,
+ long retvalue) {
+ if (cmd == (BIO_CB_READ|BIO_CB_RETURN) && retvalue <= 0) {
+ // If there is no more data in the buffer, report any pending errors that
+ // were observed. Note that both the readbuf and the writebuf are checked
+ // for errors, since the application may have encountered a socket error
+ // while writing that would otherwise not be reported until the application
+ // attempted to write again - which it may never do. See
+ // https://crbug.com/249848.
+ if (transport_read_error_ != OK) {
+ OpenSSLPutNetError(FROM_HERE, transport_read_error_);
+ return -1;
+ }
+ if (transport_write_error_ != OK) {
+ OpenSSLPutNetError(FROM_HERE, transport_write_error_);
+ return -1;
+ }
+ } else if (cmd == BIO_CB_WRITE) {
+ // Because of the write buffer, this reports a failure from the previous
+ // write payload. If the current payload fails to write, the error will be
+ // reported in a future write or read to |bio|.
+ if (transport_write_error_ != OK) {
+ OpenSSLPutNetError(FROM_HERE, transport_write_error_);
+ return -1;
+ }
+ }
+ return retvalue;
+}
+
+// static
+long SSLClientSocketOpenSSL::BIOCallback(
+ BIO *bio,
+ int cmd,
+ const char *argp, int argi, long argl,
+ long retvalue) {
+ SSLClientSocketOpenSSL* socket = reinterpret_cast<SSLClientSocketOpenSSL*>(
+ BIO_get_callback_arg(bio));
+ CHECK(socket);
+ return socket->MaybeReplayTransportError(
+ bio, cmd, argp, argi, argl, retvalue);
+}
+
+// static
+void SSLClientSocketOpenSSL::InfoCallback(const SSL* ssl,
+ int type,
+ int /*val*/) {
+ if (type == SSL_CB_HANDSHAKE_DONE) {
+ SSLClientSocketOpenSSL* ssl_socket =
+ SSLContext::GetInstance()->GetClientSocketFromSSL(ssl);
+ ssl_socket->handshake_succeeded_ = true;
+ ssl_socket->CheckIfHandshakeFinished();
+ }
+}
+
+// Determines if both the handshake and certificate verification have completed
+// successfully, and calls the handshake completion callback if that is the
+// case.
+//
+// CheckIfHandshakeFinished is called twice per connection: once after
+// MarkSSLSessionAsGood, when the certificate has been verified, and
+// once via an OpenSSL callback when the handshake has completed. On the
+// second call, when the certificate has been verified and the handshake
+// has completed, the connection's handshake completion callback is run.
+void SSLClientSocketOpenSSL::CheckIfHandshakeFinished() {
+ if (handshake_succeeded_ && marked_session_as_good_)
+ OnHandshakeCompletion();
+}
+
+void SSLClientSocketOpenSSL::AddSCTInfoToSSLInfo(SSLInfo* ssl_info) const {
+ for (ct::SCTList::const_iterator iter =
+ ct_verify_result_.verified_scts.begin();
+ iter != ct_verify_result_.verified_scts.end(); ++iter) {
+ ssl_info->signed_certificate_timestamps.push_back(
+ SignedCertificateTimestampAndStatus(*iter, ct::SCT_STATUS_OK));
+ }
+ for (ct::SCTList::const_iterator iter =
+ ct_verify_result_.invalid_scts.begin();
+ iter != ct_verify_result_.invalid_scts.end(); ++iter) {
+ ssl_info->signed_certificate_timestamps.push_back(
+ SignedCertificateTimestampAndStatus(*iter, ct::SCT_STATUS_INVALID));
+ }
+ for (ct::SCTList::const_iterator iter =
+ ct_verify_result_.unknown_logs_scts.begin();
+ iter != ct_verify_result_.unknown_logs_scts.end(); ++iter) {
+ ssl_info->signed_certificate_timestamps.push_back(
+ SignedCertificateTimestampAndStatus(*iter,
+ ct::SCT_STATUS_LOG_UNKNOWN));
+ }
+}
+
scoped_refptr<X509Certificate>
SSLClientSocketOpenSSL::GetUnverifiedServerCertificateChain() const {
return server_cert_;
diff --git a/chromium/net/socket/ssl_client_socket_openssl.h b/chromium/net/socket/ssl_client_socket_openssl.h
index 5d70c0523fa..53d33c4c8c7 100644
--- a/chromium/net/socket/ssl_client_socket_openssl.h
+++ b/chromium/net/socket/ssl_client_socket_openssl.h
@@ -13,9 +13,11 @@
#include "net/base/completion_callback.h"
#include "net/base/io_buffer.h"
#include "net/cert/cert_verify_result.h"
+#include "net/cert/ct_verify_result.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/ssl_client_socket.h"
-#include "net/ssl/server_bound_cert_service.h"
+#include "net/ssl/channel_id_service.h"
+#include "net/ssl/openssl_ssl_util.h"
#include "net/ssl/ssl_client_cert_type.h"
#include "net/ssl/ssl_config_service.h"
@@ -34,6 +36,7 @@ typedef struct x509_store_ctx_st X509_STORE_CTX;
namespace net {
class CertVerifier;
+class CTVerifier;
class SingleRequestCertVerifier;
class SSLCertRequestInfo;
class SSLInfo;
@@ -49,7 +52,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket {
const HostPortPair& host_and_port,
const SSLConfig& ssl_config,
const SSLClientSocketContext& context);
- virtual ~SSLClientSocketOpenSSL();
+ ~SSLClientSocketOpenSSL() override;
const HostPortPair& host_and_port() const { return host_and_port_; }
const std::string& ssl_session_cache_shard() const {
@@ -57,46 +60,49 @@ class SSLClientSocketOpenSSL : public SSLClientSocket {
}
// SSLClientSocket implementation.
- virtual void GetSSLCertRequestInfo(
- SSLCertRequestInfo* cert_request_info) OVERRIDE;
- virtual NextProtoStatus GetNextProto(std::string* proto,
- std::string* server_protos) OVERRIDE;
- virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE;
+ std::string GetSessionCacheKey() const override;
+ bool InSessionCache() const override;
+ void SetHandshakeCompletionCallback(const base::Closure& callback) override;
+ void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override;
+ NextProtoStatus GetNextProto(std::string* proto) override;
+ ChannelIDService* GetChannelIDService() const override;
// SSLSocket implementation.
- virtual int ExportKeyingMaterial(const base::StringPiece& label,
- bool has_context,
- const base::StringPiece& context,
- unsigned char* out,
- unsigned int outlen) OVERRIDE;
- virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE;
+ int ExportKeyingMaterial(const base::StringPiece& label,
+ bool has_context,
+ const base::StringPiece& context,
+ unsigned char* out,
+ unsigned int outlen) override;
+ int GetTLSUniqueChannelBinding(std::string* out) override;
// StreamSocket implementation.
- virtual int Connect(const CompletionCallback& callback) OVERRIDE;
- virtual void Disconnect() OVERRIDE;
- virtual bool IsConnected() const OVERRIDE;
- virtual bool IsConnectedAndIdle() const OVERRIDE;
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
- virtual const BoundNetLog& NetLog() const OVERRIDE;
- virtual void SetSubresourceSpeculation() OVERRIDE;
- virtual void SetOmniboxSpeculation() OVERRIDE;
- virtual bool WasEverUsed() const OVERRIDE;
- virtual bool UsingTCPFastOpen() const OVERRIDE;
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+ int Connect(const CompletionCallback& callback) override;
+ void Disconnect() override;
+ bool IsConnected() const override;
+ bool IsConnectedAndIdle() const override;
+ int GetPeerAddress(IPEndPoint* address) const override;
+ int GetLocalAddress(IPEndPoint* address) const override;
+ const BoundNetLog& NetLog() const override;
+ void SetSubresourceSpeculation() override;
+ void SetOmniboxSpeculation() override;
+ bool WasEverUsed() const override;
+ bool UsingTCPFastOpen() const override;
+ bool GetSSLInfo(SSLInfo* ssl_info) override;
// Socket implementation.
- virtual int Read(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int Write(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE;
- virtual int SetSendBufferSize(int32 size) OVERRIDE;
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int SetReceiveBufferSize(int32 size) override;
+ int SetSendBufferSize(int32 size) override;
protected:
// SSLClientSocket implementation.
- virtual scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain()
- const OVERRIDE;
+ scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain()
+ const override;
private:
class PeerCertificateChain;
@@ -108,20 +114,25 @@ class SSLClientSocketOpenSSL : public SSLClientSocket {
void DoReadCallback(int result);
void DoWriteCallback(int result);
+ void OnHandshakeCompletion();
+
bool DoTransportIO();
int DoHandshake();
+ int DoChannelIDLookup();
+ int DoChannelIDLookupComplete(int result);
int DoVerifyCert(int result);
int DoVerifyCertComplete(int result);
void DoConnectCallback(int result);
- X509Certificate* UpdateServerCert();
+ void UpdateServerCert();
+ void VerifyCT();
void OnHandshakeIOComplete(int result);
void OnSendComplete(int result);
void OnRecvComplete(int result);
int DoHandshakeLoop(int last_io_result);
- int DoReadLoop(int result);
- int DoWriteLoop(int result);
+ int DoReadLoop();
+ int DoWriteLoop();
int DoPayloadRead();
int DoPayloadWrite();
@@ -134,11 +145,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket {
// Callback from the SSL layer that indicates the remote server is requesting
// a certificate for this client.
- int ClientCertRequestCallback(SSL* ssl, X509** x509, EVP_PKEY** pkey);
-
- // Callback from the SSL layer that indicates the remote server supports TLS
- // Channel IDs.
- void ChannelIDRequestCallback(SSL* ssl, EVP_PKEY** pkey);
+ int ClientCertRequestCallback(SSL* ssl);
// CertVerifyCallback is called to verify the server's certificates. We do
// verification after the handshake so this function only enforces that the
@@ -149,9 +156,36 @@ class SSLClientSocketOpenSSL : public SSLClientSocket {
int SelectNextProtoCallback(unsigned char** out, unsigned char* outlen,
const unsigned char* in, unsigned int inlen);
+ // Called during an operation on |transport_bio_|'s peer. Checks saved
+ // transport error state and, if appropriate, returns an error through
+ // OpenSSL's error system.
+ long MaybeReplayTransportError(BIO *bio,
+ int cmd,
+ const char *argp, int argi, long argl,
+ long retvalue);
+
+ // Callback from the SSL layer when an operation is performed on
+ // |transport_bio_|'s peer.
+ static long BIOCallback(BIO *bio,
+ int cmd,
+ const char *argp, int argi, long argl,
+ long retvalue);
+
+ // Callback that is used to obtain information about the state of the SSL
+ // handshake.
+ static void InfoCallback(const SSL* ssl, int type, int val);
+
+ void CheckIfHandshakeFinished();
+
+ // Adds the SignedCertificateTimestamps from ct_verify_result_ to |ssl_info|.
+ // SCTs are held in three separate vectors in ct_verify_result, each
+ // vetor representing a particular verification state, this method associates
+ // each of the SCTs with the corresponding SCTVerifyStatus as it adds it to
+ // the |ssl_info|.signed_certificate_timestamps list.
+ void AddSCTInfoToSSLInfo(SSLInfo* ssl_info) const;
+
bool transport_send_busy_;
bool transport_recv_busy_;
- bool transport_recv_eof_;
scoped_refptr<DrainableIOBuffer> send_buffer_;
scoped_refptr<IOBuffer> recv_buffer_;
@@ -160,8 +194,6 @@ class SSLClientSocketOpenSSL : public SSLClientSocket {
CompletionCallback user_read_callback_;
CompletionCallback user_write_callback_;
- base::WeakPtrFactory<SSLClientSocketOpenSSL> weak_factory_;
-
// Used by Read function.
scoped_refptr<IOBuffer> user_read_buf_;
int user_read_buf_len_;
@@ -178,15 +210,27 @@ class SSLClientSocketOpenSSL : public SSLClientSocket {
// indicates an error.
int pending_read_error_;
+ // If there is a pending read result, the OpenSSL result code (output of
+ // SSL_get_error) associated with it.
+ int pending_read_ssl_error_;
+
+ // If there is a pending read result, the OpenSSLErrorInfo associated with it.
+ OpenSSLErrorInfo pending_read_error_info_;
+
+ // Used by TransportReadComplete() to signify an error reading from the
+ // transport socket. A value of OK indicates the socket is still
+ // readable. EOFs are mapped to ERR_CONNECTION_CLOSED.
+ int transport_read_error_;
+
// Used by TransportWriteComplete() and TransportReadComplete() to signify an
// error writing to the transport socket. A value of OK indicates no error.
int transport_write_error_;
- // Set when handshake finishes.
+ // Set when Connect finishes.
scoped_ptr<PeerCertificateChain> server_cert_chain_;
scoped_refptr<X509Certificate> server_cert_;
CertVerifyResult server_cert_verify_result_;
- bool completed_handshake_;
+ bool completed_connect_;
// Set when Read() or Write() successfully reads or writes data to or from the
// network.
@@ -204,9 +248,20 @@ class SSLClientSocketOpenSSL : public SSLClientSocket {
CertVerifier* const cert_verifier_;
scoped_ptr<SingleRequestCertVerifier> verifier_;
+ base::TimeTicks start_cert_verification_time_;
+
+ // Certificate Transparency: Verifier and result holder.
+ ct::CTVerifyResult ct_verify_result_;
+ CTVerifier* cert_transparency_verifier_;
// The service for retrieving Channel ID keys. May be NULL.
- ServerBoundCertService* server_bound_cert_service_;
+ ChannelIDService* channel_id_service_;
+
+ // Callback that is invoked when the connection finishes.
+ //
+ // Note: this callback will be run in Disconnect(). It will not alter
+ // any member variables of the SSLClientSocketOpenSSL.
+ base::Closure handshake_completion_callback_;
// OpenSSL stuff
SSL* ssl_;
@@ -226,23 +281,36 @@ class SSLClientSocketOpenSSL : public SSLClientSocket {
enum State {
STATE_NONE,
STATE_HANDSHAKE,
+ STATE_CHANNEL_ID_LOOKUP,
+ STATE_CHANNEL_ID_LOOKUP_COMPLETE,
STATE_VERIFY_CERT,
STATE_VERIFY_CERT_COMPLETE,
};
State next_handshake_state_;
NextProtoStatus npn_status_;
std::string npn_proto_;
- std::string server_protos_;
- // Written by the |server_bound_cert_service_|.
+ // Written by the |channel_id_service_|.
std::string channel_id_private_key_;
std::string channel_id_cert_;
- // The return value of the last call to |server_bound_cert_service_|.
- int channel_id_request_return_value_;
// True if channel ID extension was negotiated.
bool channel_id_xtn_negotiated_;
- // The request handle for |server_bound_cert_service_|.
- ServerBoundCertService::RequestHandle channel_id_request_handle_;
+ // True if InfoCallback has been run with result = SSL_CB_HANDSHAKE_DONE.
+ bool handshake_succeeded_;
+ // True if MarkSSLSessionAsGood has been called for this socket's
+ // SSL session.
+ bool marked_session_as_good_;
+ // The request handle for |channel_id_service_|.
+ ChannelIDService::RequestHandle channel_id_request_handle_;
+
+ TransportSecurityState* transport_security_state_;
+
+ // pinning_failure_log contains a message produced by
+ // TransportSecurityState::CheckPublicKeyPins in the event of a
+ // pinning failure. It is a (somewhat) human-readable string.
+ std::string pinning_failure_log_;
+
BoundNetLog net_log_;
+ base::WeakPtrFactory<SSLClientSocketOpenSSL> weak_factory_;
};
} // namespace net
diff --git a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc
index d4e0685467e..8a6a8828810 100644
--- a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc
+++ b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc
@@ -13,12 +13,13 @@
#include <openssl/pem.h>
#include <openssl/rsa.h>
-#include "base/file_util.h"
#include "base/files/file_path.h"
+#include "base/files/file_util.h"
#include "base/memory/ref_counted.h"
#include "base/message_loop/message_loop_proxy.h"
#include "base/values.h"
#include "crypto/openssl_util.h"
+#include "crypto/scoped_openssl_types.h"
#include "net/base/address_list.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
@@ -48,16 +49,6 @@ namespace {
// These client auth tests are currently dependent on OpenSSL's struct X509.
#if defined(USE_OPENSSL_CERTS)
-typedef OpenSSLClientKeyStore::ScopedEVP_PKEY ScopedEVP_PKEY;
-
-// BIO_free is a macro, it can't be used as a template parameter.
-void BIO_free_func(BIO* bio) {
- BIO_free(bio);
-}
-
-typedef crypto::ScopedOpenSSL<BIO, BIO_free_func> ScopedBIO;
-typedef crypto::ScopedOpenSSL<RSA, RSA_free> ScopedRSA;
-typedef crypto::ScopedOpenSSL<BIGNUM, BN_free> ScopedBIGNUM;
const SSLConfig kDefaultSSLConfig;
@@ -67,17 +58,16 @@ const SSLConfig kDefaultSSLConfig;
// Returns true on success, false on failure.
bool LoadPrivateKeyOpenSSL(
const base::FilePath& filepath,
- OpenSSLClientKeyStore::ScopedEVP_PKEY* pkey) {
+ crypto::ScopedEVP_PKEY* pkey) {
std::string data;
if (!base::ReadFileToString(filepath, &data)) {
LOG(ERROR) << "Could not read private key file: "
<< filepath.value() << ": " << strerror(errno);
return false;
}
- ScopedBIO bio(
- BIO_new_mem_buf(
- const_cast<char*>(reinterpret_cast<const char*>(data.data())),
- static_cast<int>(data.size())));
+ crypto::ScopedBIO bio(BIO_new_mem_buf(
+ const_cast<char*>(reinterpret_cast<const char*>(data.data())),
+ static_cast<int>(data.size())));
if (!bio.get()) {
LOG(ERROR) << "Could not allocate BIO for buffer?";
return false;
@@ -95,13 +85,13 @@ bool LoadPrivateKeyOpenSSL(
class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest {
public:
SSLClientSocketOpenSSLClientAuthTest()
- : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()),
- cert_verifier_(new net::MockCertVerifier),
- transport_security_state_(new net::TransportSecurityState) {
- cert_verifier_->set_default_result(net::OK);
+ : socket_factory_(ClientSocketFactory::GetDefaultFactory()),
+ cert_verifier_(new MockCertVerifier),
+ transport_security_state_(new TransportSecurityState) {
+ cert_verifier_->set_default_result(OK);
context_.cert_verifier = cert_verifier_.get();
context_.transport_security_state = transport_security_state_.get();
- key_store_ = net::OpenSSLClientKeyStore::GetInstance();
+ key_store_ = OpenSSLClientKeyStore::GetInstance();
}
virtual ~SSLClientSocketOpenSSLClientAuthTest() {
@@ -260,7 +250,7 @@ TEST_F(SSLClientSocketOpenSSLClientAuthTest, SendGoodCert) {
// This is required to ensure that signing works with the client
// certificate's private key.
- OpenSSLClientKeyStore::ScopedEVP_PKEY client_private_key;
+ crypto::ScopedEVP_PKEY client_private_key;
ASSERT_TRUE(LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_1.key"),
&client_private_key));
EXPECT_TRUE(RecordPrivateKey(ssl_config, client_private_key.get()));
diff --git a/chromium/net/socket/ssl_client_socket_pool.cc b/chromium/net/socket/ssl_client_socket_pool.cc
index 2c704658820..56df1d85d09 100644
--- a/chromium/net/socket/ssl_client_socket_pool.cc
+++ b/chromium/net/socket/ssl_client_socket_pool.cc
@@ -9,6 +9,7 @@
#include "base/metrics/field_trial.h"
#include "base/metrics/histogram.h"
#include "base/metrics/sparse_histogram.h"
+#include "base/stl_util.h"
#include "base/values.h"
#include "net/base/host_port_pair.h"
#include "net/base/net_errors.h"
@@ -45,15 +46,15 @@ SSLSocketParams::SSLSocketParams(
force_spdy_over_ssl_(force_spdy_over_ssl),
want_spdy_over_npn_(want_spdy_over_npn),
ignore_limits_(false) {
- if (direct_params_) {
- DCHECK(!socks_proxy_params_);
- DCHECK(!http_proxy_params_);
+ if (direct_params_.get()) {
+ DCHECK(!socks_proxy_params_.get());
+ DCHECK(!http_proxy_params_.get());
ignore_limits_ = direct_params_->ignore_limits();
- } else if (socks_proxy_params_) {
- DCHECK(!http_proxy_params_);
+ } else if (socks_proxy_params_.get()) {
+ DCHECK(!http_proxy_params_.get());
ignore_limits_ = socks_proxy_params_->ignore_limits();
} else {
- DCHECK(http_proxy_params_);
+ DCHECK(http_proxy_params_.get());
ignore_limits_ = http_proxy_params_->ignore_limits();
}
}
@@ -61,18 +62,18 @@ SSLSocketParams::SSLSocketParams(
SSLSocketParams::~SSLSocketParams() {}
SSLSocketParams::ConnectionType SSLSocketParams::GetConnectionType() const {
- if (direct_params_) {
- DCHECK(!socks_proxy_params_);
- DCHECK(!http_proxy_params_);
+ if (direct_params_.get()) {
+ DCHECK(!socks_proxy_params_.get());
+ DCHECK(!http_proxy_params_.get());
return DIRECT;
}
- if (socks_proxy_params_) {
- DCHECK(!http_proxy_params_);
+ if (socks_proxy_params_.get()) {
+ DCHECK(!http_proxy_params_.get());
return SOCKS_PROXY;
}
- DCHECK(http_proxy_params_);
+ DCHECK(http_proxy_params_.get());
return HTTP_PROXY;
}
@@ -94,6 +95,77 @@ SSLSocketParams::GetHttpProxyConnectionParams() const {
return http_proxy_params_;
}
+SSLConnectJobMessenger::SocketAndCallback::SocketAndCallback(
+ SSLClientSocket* ssl_socket,
+ const base::Closure& job_resumption_callback)
+ : socket(ssl_socket), callback(job_resumption_callback) {
+}
+
+SSLConnectJobMessenger::SocketAndCallback::~SocketAndCallback() {
+}
+
+SSLConnectJobMessenger::SSLConnectJobMessenger(
+ const base::Closure& messenger_finished_callback)
+ : messenger_finished_callback_(messenger_finished_callback),
+ weak_factory_(this) {
+}
+
+SSLConnectJobMessenger::~SSLConnectJobMessenger() {
+}
+
+void SSLConnectJobMessenger::RemovePendingSocket(SSLClientSocket* ssl_socket) {
+ // Sockets do not need to be removed from connecting_sockets_ because
+ // OnSSLHandshakeCompleted will do this.
+ for (SSLPendingSocketsAndCallbacks::iterator it =
+ pending_sockets_and_callbacks_.begin();
+ it != pending_sockets_and_callbacks_.end();
+ ++it) {
+ if (it->socket == ssl_socket) {
+ pending_sockets_and_callbacks_.erase(it);
+ return;
+ }
+ }
+}
+
+bool SSLConnectJobMessenger::CanProceed(SSLClientSocket* ssl_socket) {
+ // If there are no connecting sockets, allow the connection to proceed.
+ return connecting_sockets_.empty();
+}
+
+void SSLConnectJobMessenger::MonitorConnectionResult(
+ SSLClientSocket* ssl_socket) {
+ connecting_sockets_.push_back(ssl_socket);
+ ssl_socket->SetHandshakeCompletionCallback(
+ base::Bind(&SSLConnectJobMessenger::OnSSLHandshakeCompleted,
+ weak_factory_.GetWeakPtr()));
+}
+
+void SSLConnectJobMessenger::AddPendingSocket(SSLClientSocket* ssl_socket,
+ const base::Closure& callback) {
+ DCHECK(!connecting_sockets_.empty());
+ pending_sockets_and_callbacks_.push_back(
+ SocketAndCallback(ssl_socket, callback));
+}
+
+void SSLConnectJobMessenger::OnSSLHandshakeCompleted() {
+ connecting_sockets_.clear();
+ SSLPendingSocketsAndCallbacks temp_list;
+ temp_list.swap(pending_sockets_and_callbacks_);
+ base::Closure messenger_finished_callback = messenger_finished_callback_;
+ messenger_finished_callback.Run();
+ RunAllCallbacks(temp_list);
+}
+
+void SSLConnectJobMessenger::RunAllCallbacks(
+ const SSLPendingSocketsAndCallbacks& pending_sockets_and_callbacks) {
+ for (std::vector<SocketAndCallback>::const_iterator it =
+ pending_sockets_and_callbacks.begin();
+ it != pending_sockets_and_callbacks.end();
+ ++it) {
+ it->callback.Run();
+ }
+}
+
// Timeout for the SSL handshake portion of the connect.
static const int kSSLHandshakeTimeoutInSeconds = 30;
@@ -107,6 +179,7 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name,
ClientSocketFactory* client_socket_factory,
HostResolver* host_resolver,
const SSLClientSocketContext& context,
+ const GetMessengerCallback& get_messenger_callback,
Delegate* delegate,
NetLog* net_log)
: ConnectJob(group_name,
@@ -121,16 +194,23 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name,
client_socket_factory_(client_socket_factory),
host_resolver_(host_resolver),
context_(context.cert_verifier,
- context.server_bound_cert_service,
+ context.channel_id_service,
context.transport_security_state,
context.cert_transparency_verifier,
(params->privacy_mode() == PRIVACY_MODE_ENABLED
? "pm/" + context.ssl_session_cache_shard
: context.ssl_session_cache_shard)),
- callback_(base::Bind(&SSLConnectJob::OnIOComplete,
- base::Unretained(this))) {}
+ io_callback_(
+ base::Bind(&SSLConnectJob::OnIOComplete, base::Unretained(this))),
+ messenger_(NULL),
+ get_messenger_callback_(get_messenger_callback),
+ weak_factory_(this) {
+}
-SSLConnectJob::~SSLConnectJob() {}
+SSLConnectJob::~SSLConnectJob() {
+ if (ssl_socket_.get() && messenger_)
+ messenger_->RemovePendingSocket(ssl_socket_.get());
+}
LoadState SSLConnectJob::GetLoadState() const {
switch (next_state_) {
@@ -144,6 +224,8 @@ LoadState SSLConnectJob::GetLoadState() const {
case STATE_SOCKS_CONNECT_COMPLETE:
case STATE_TUNNEL_CONNECT:
return transport_socket_handle_->GetLoadState();
+ case STATE_CREATE_SSL_SOCKET:
+ case STATE_CHECK_FOR_RESUME:
case STATE_SSL_CONNECT:
case STATE_SSL_CONNECT_COMPLETE:
return LOAD_STATE_SSL_HANDSHAKE;
@@ -200,6 +282,12 @@ int SSLConnectJob::DoLoop(int result) {
case STATE_TUNNEL_CONNECT_COMPLETE:
rv = DoTunnelConnectComplete(rv);
break;
+ case STATE_CREATE_SSL_SOCKET:
+ rv = DoCreateSSLSocket();
+ break;
+ case STATE_CHECK_FOR_RESUME:
+ rv = DoCheckForResume();
+ break;
case STATE_SSL_CONNECT:
DCHECK_EQ(OK, rv);
rv = DoSSLConnect();
@@ -227,14 +315,14 @@ int SSLConnectJob::DoTransportConnect() {
return transport_socket_handle_->Init(group_name(),
direct_params,
priority(),
- callback_,
+ io_callback_,
transport_pool_,
net_log());
}
int SSLConnectJob::DoTransportConnectComplete(int result) {
if (result == OK)
- next_state_ = STATE_SSL_CONNECT;
+ next_state_ = STATE_CREATE_SSL_SOCKET;
return result;
}
@@ -248,14 +336,14 @@ int SSLConnectJob::DoSOCKSConnect() {
return transport_socket_handle_->Init(group_name(),
socks_proxy_params,
priority(),
- callback_,
+ io_callback_,
socks_pool_,
net_log());
}
int SSLConnectJob::DoSOCKSConnectComplete(int result) {
if (result == OK)
- next_state_ = STATE_SSL_CONNECT;
+ next_state_ = STATE_CREATE_SSL_SOCKET;
return result;
}
@@ -270,7 +358,7 @@ int SSLConnectJob::DoTunnelConnect() {
return transport_socket_handle_->Init(group_name(),
http_proxy_params,
priority(),
- callback_,
+ io_callback_,
http_proxy_pool_,
net_log());
}
@@ -290,13 +378,13 @@ int SSLConnectJob::DoTunnelConnectComplete(int result) {
}
if (result < 0)
return result;
-
- next_state_ = STATE_SSL_CONNECT;
+ next_state_ = STATE_CREATE_SSL_SOCKET;
return result;
}
-int SSLConnectJob::DoSSLConnect() {
- next_state_ = STATE_SSL_CONNECT_COMPLETE;
+int SSLConnectJob::DoCreateSSLSocket() {
+ next_state_ = STATE_CHECK_FOR_RESUME;
+
// Reset the timeout to just the time allowed for the SSL handshake.
ResetTimer(base::TimeDelta::FromSeconds(kSSLHandshakeTimeoutInSeconds));
@@ -314,14 +402,45 @@ int SSLConnectJob::DoSSLConnect() {
connect_timing_.dns_end = socket_connect_timing.dns_end;
}
- connect_timing_.ssl_start = base::TimeTicks::Now();
-
ssl_socket_ = client_socket_factory_->CreateSSLClientSocket(
transport_socket_handle_.Pass(),
params_->host_and_port(),
params_->ssl_config(),
context_);
- return ssl_socket_->Connect(callback_);
+
+ if (!ssl_socket_->InSessionCache())
+ messenger_ = get_messenger_callback_.Run(ssl_socket_->GetSessionCacheKey());
+
+ return OK;
+}
+
+int SSLConnectJob::DoCheckForResume() {
+ next_state_ = STATE_SSL_CONNECT;
+
+ if (!messenger_)
+ return OK;
+
+ if (messenger_->CanProceed(ssl_socket_.get())) {
+ messenger_->MonitorConnectionResult(ssl_socket_.get());
+ // The SSLConnectJob no longer needs access to the messenger after this
+ // point.
+ messenger_ = NULL;
+ return OK;
+ }
+
+ messenger_->AddPendingSocket(ssl_socket_.get(),
+ base::Bind(&SSLConnectJob::ResumeSSLConnection,
+ weak_factory_.GetWeakPtr()));
+
+ return ERR_IO_PENDING;
+}
+
+int SSLConnectJob::DoSSLConnect() {
+ next_state_ = STATE_SSL_CONNECT_COMPLETE;
+
+ connect_timing_.ssl_start = base::TimeTicks::Now();
+
+ return ssl_socket_->Connect(io_callback_);
}
int SSLConnectJob::DoSSLConnectComplete(int result) {
@@ -330,12 +449,13 @@ int SSLConnectJob::DoSSLConnectComplete(int result) {
SSLClientSocket::NextProtoStatus status =
SSLClientSocket::kNextProtoUnsupported;
std::string proto;
- std::string server_protos;
// GetNextProto will fail and and trigger a NOTREACHED if we pass in a socket
// that hasn't had SSL_ImportFD called on it. If we get a certificate error
// here, then we know that we called SSL_ImportFD.
- if (result == OK || IsCertificateError(result))
- status = ssl_socket_->GetNextProto(&proto, &server_protos);
+ if (result == OK || IsCertificateError(result)) {
+ status = ssl_socket_->GetNextProto(&proto);
+ ssl_socket_->RecordNegotiationExtension();
+ }
// If we want spdy over npn, make sure it succeeded.
if (status == SSLClientSocket::kNextProtoNegotiated) {
@@ -370,18 +490,6 @@ int SSLConnectJob::DoSSLConnectComplete(int result) {
base::TimeDelta::FromMinutes(1),
100);
}
-#if defined(SPDY_PROXY_AUTH_ORIGIN)
- bool using_data_reduction_proxy = params_->host_and_port().Equals(
- HostPortPair::FromURL(GURL(SPDY_PROXY_AUTH_ORIGIN)));
- if (using_data_reduction_proxy) {
- UMA_HISTOGRAM_CUSTOM_TIMES(
- "Net.SSL_Connection_Latency_DataReductionProxy",
- connect_duration,
- base::TimeDelta::FromMilliseconds(1),
- base::TimeDelta::FromMinutes(1),
- 100);
- }
-#endif
UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_2",
connect_duration,
@@ -396,6 +504,11 @@ int SSLConnectJob::DoSSLConnectComplete(int result) {
SSLConnectionStatusToCipherSuite(
ssl_info.connection_status));
+ UMA_HISTOGRAM_BOOLEAN(
+ "Net.RenegotiationExtensionSupported",
+ (ssl_info.connection_status &
+ SSL_CONNECTION_NO_RENEGOTIATION_EXTENSION) == 0);
+
if (ssl_info.handshake_type == SSLInfo::HANDSHAKE_RESUME) {
UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_Resume_Handshake",
connect_duration,
@@ -411,9 +524,9 @@ int SSLConnectJob::DoSSLConnectComplete(int result) {
}
const std::string& host = params_->host_and_port().host();
- bool is_google = host == "google.com" ||
- (host.size() > 11 &&
- host.rfind(".google.com") == host.size() - 11);
+ bool is_google =
+ host == "google.com" ||
+ (host.size() > 11 && host.rfind(".google.com") == host.size() - 11);
if (is_google) {
UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_Google2",
connect_duration,
@@ -439,7 +552,7 @@ int SSLConnectJob::DoSSLConnectComplete(int result) {
}
if (result == OK || IsCertificateError(result)) {
- SetSocket(ssl_socket_.PassAs<StreamSocket>());
+ SetSocket(ssl_socket_.Pass());
} else if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) {
error_response_info_.cert_request_info = new SSLCertRequestInfo;
ssl_socket_->GetSSLCertRequestInfo(
@@ -449,6 +562,12 @@ int SSLConnectJob::DoSSLConnectComplete(int result) {
return result;
}
+void SSLConnectJob::ResumeSSLConnection() {
+ DCHECK_EQ(next_state_, STATE_SSL_CONNECT);
+ messenger_ = NULL;
+ OnIOComplete(OK);
+}
+
SSLConnectJob::State SSLConnectJob::GetInitialState(
SSLSocketParams::ConnectionType connection_type) {
switch (connection_type) {
@@ -475,6 +594,7 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory(
ClientSocketFactory* client_socket_factory,
HostResolver* host_resolver,
const SSLClientSocketContext& context,
+ const SSLConnectJob::GetMessengerCallback& get_messenger_callback,
NetLog* net_log)
: transport_pool_(transport_pool),
socks_pool_(socks_pool),
@@ -482,6 +602,7 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory(
client_socket_factory_(client_socket_factory),
host_resolver_(host_resolver),
context_(context),
+ get_messenger_callback_(get_messenger_callback),
net_log_(net_log) {
base::TimeDelta max_transport_timeout = base::TimeDelta();
base::TimeDelta pool_timeout;
@@ -501,13 +622,16 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory(
base::TimeDelta::FromSeconds(kSSLHandshakeTimeoutInSeconds);
}
+SSLClientSocketPool::SSLConnectJobFactory::~SSLConnectJobFactory() {
+}
+
SSLClientSocketPool::SSLClientSocketPool(
int max_sockets,
int max_sockets_per_group,
ClientSocketPoolHistograms* histograms,
HostResolver* host_resolver,
CertVerifier* cert_verifier,
- ServerBoundCertService* server_bound_cert_service,
+ ChannelIDService* channel_id_service,
TransportSecurityState* transport_security_state,
CTVerifier* cert_transparency_verifier,
const std::string& ssl_session_cache_shard,
@@ -516,26 +640,34 @@ SSLClientSocketPool::SSLClientSocketPool(
SOCKSClientSocketPool* socks_pool,
HttpProxyClientSocketPool* http_proxy_pool,
SSLConfigService* ssl_config_service,
+ bool enable_ssl_connect_job_waiting,
NetLog* net_log)
: transport_pool_(transport_pool),
socks_pool_(socks_pool),
http_proxy_pool_(http_proxy_pool),
- base_(this, max_sockets, max_sockets_per_group, histograms,
+ base_(this,
+ max_sockets,
+ max_sockets_per_group,
+ histograms,
ClientSocketPool::unused_idle_socket_timeout(),
ClientSocketPool::used_idle_socket_timeout(),
- new SSLConnectJobFactory(transport_pool,
- socks_pool,
- http_proxy_pool,
- client_socket_factory,
- host_resolver,
- SSLClientSocketContext(
- cert_verifier,
- server_bound_cert_service,
- transport_security_state,
- cert_transparency_verifier,
- ssl_session_cache_shard),
- net_log)),
- ssl_config_service_(ssl_config_service) {
+ new SSLConnectJobFactory(
+ transport_pool,
+ socks_pool,
+ http_proxy_pool,
+ client_socket_factory,
+ host_resolver,
+ SSLClientSocketContext(cert_verifier,
+ channel_id_service,
+ transport_security_state,
+ cert_transparency_verifier,
+ ssl_session_cache_shard),
+ base::Bind(
+ &SSLClientSocketPool::GetOrCreateSSLConnectJobMessenger,
+ base::Unretained(this)),
+ net_log)),
+ ssl_config_service_(ssl_config_service),
+ enable_ssl_connect_job_waiting_(enable_ssl_connect_job_waiting) {
if (ssl_config_service_.get())
ssl_config_service_->AddObserver(this);
if (transport_pool_)
@@ -547,24 +679,33 @@ SSLClientSocketPool::SSLClientSocketPool(
}
SSLClientSocketPool::~SSLClientSocketPool() {
+ STLDeleteContainerPairSecondPointers(messenger_map_.begin(),
+ messenger_map_.end());
if (ssl_config_service_.get())
ssl_config_service_->RemoveObserver(this);
}
-scoped_ptr<ConnectJob>
-SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob(
+scoped_ptr<ConnectJob> SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const {
- return scoped_ptr<ConnectJob>(
- new SSLConnectJob(group_name, request.priority(), request.params(),
- ConnectionTimeout(), transport_pool_, socks_pool_,
- http_proxy_pool_, client_socket_factory_,
- host_resolver_, context_, delegate, net_log_));
-}
-
-base::TimeDelta
-SSLClientSocketPool::SSLConnectJobFactory::ConnectionTimeout() const {
+ return scoped_ptr<ConnectJob>(new SSLConnectJob(group_name,
+ request.priority(),
+ request.params(),
+ ConnectionTimeout(),
+ transport_pool_,
+ socks_pool_,
+ http_proxy_pool_,
+ client_socket_factory_,
+ host_resolver_,
+ context_,
+ get_messenger_callback_,
+ delegate,
+ net_log_));
+}
+
+base::TimeDelta SSLClientSocketPool::SSLConnectJobFactory::ConnectionTimeout()
+ const {
return timeout_;
}
@@ -679,6 +820,32 @@ bool SSLClientSocketPool::CloseOneIdleConnection() {
return base_.CloseOneIdleConnectionInHigherLayeredPool();
}
+SSLConnectJobMessenger* SSLClientSocketPool::GetOrCreateSSLConnectJobMessenger(
+ const std::string& cache_key) {
+ if (!enable_ssl_connect_job_waiting_)
+ return NULL;
+ MessengerMap::const_iterator it = messenger_map_.find(cache_key);
+ if (it == messenger_map_.end()) {
+ std::pair<MessengerMap::iterator, bool> iter =
+ messenger_map_.insert(MessengerMap::value_type(
+ cache_key,
+ new SSLConnectJobMessenger(
+ base::Bind(&SSLClientSocketPool::DeleteSSLConnectJobMessenger,
+ base::Unretained(this),
+ cache_key))));
+ it = iter.first;
+ }
+ return it->second;
+}
+
+void SSLClientSocketPool::DeleteSSLConnectJobMessenger(
+ const std::string& cache_key) {
+ MessengerMap::iterator it = messenger_map_.find(cache_key);
+ CHECK(it != messenger_map_.end());
+ delete it->second;
+ messenger_map_.erase(it);
+}
+
void SSLClientSocketPool::OnSSLConfigChanged() {
FlushWithError(ERR_NETWORK_CHANGED);
}
diff --git a/chromium/net/socket/ssl_client_socket_pool.h b/chromium/net/socket/ssl_client_socket_pool.h
index e03b76ade6a..c7f613e5149 100644
--- a/chromium/net/socket/ssl_client_socket_pool.h
+++ b/chromium/net/socket/ssl_client_socket_pool.h
@@ -5,7 +5,9 @@
#ifndef NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_
#define NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_
+#include <map>
#include <string>
+#include <vector>
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
@@ -94,29 +96,110 @@ class NET_EXPORT_PRIVATE SSLSocketParams
DISALLOW_COPY_AND_ASSIGN(SSLSocketParams);
};
+// SSLConnectJobMessenger handles communication between concurrent
+// SSLConnectJobs that share the same SSL session cache key.
+//
+// SSLConnectJobMessengers tell the session cache when a certain
+// connection should be monitored for success or failure, and
+// tell SSLConnectJobs when to pause or resume their connections.
+class SSLConnectJobMessenger {
+ public:
+ struct SocketAndCallback {
+ SocketAndCallback(SSLClientSocket* ssl_socket,
+ const base::Closure& job_resumption_callback);
+ ~SocketAndCallback();
+
+ SSLClientSocket* socket;
+ base::Closure callback;
+ };
+
+ typedef std::vector<SocketAndCallback> SSLPendingSocketsAndCallbacks;
+
+ // |messenger_finished_callback| is run when a connection monitored by the
+ // SSLConnectJobMessenger has completed and we are finished with the
+ // SSLConnectJobMessenger.
+ explicit SSLConnectJobMessenger(
+ const base::Closure& messenger_finished_callback);
+ ~SSLConnectJobMessenger();
+
+ // Removes |socket| from the set of sockets being monitored. This
+ // guarantees that |job_resumption_callback| will not be called for
+ // the socket.
+ void RemovePendingSocket(SSLClientSocket* ssl_socket);
+
+ // Returns true if |ssl_socket|'s Connect() method should be called.
+ bool CanProceed(SSLClientSocket* ssl_socket);
+
+ // Configures the SSLConnectJobMessenger to begin monitoring |ssl_socket|'s
+ // connection status. After a successful connection, or an error,
+ // the messenger will determine which sockets that have been added
+ // via AddPendingSocket() to allow to proceed.
+ void MonitorConnectionResult(SSLClientSocket* ssl_socket);
+
+ // Adds |socket| to the list of sockets waiting to Connect(). When
+ // the messenger has determined that it's an appropriate time for |socket|
+ // to connect, it will invoke |callback|.
+ //
+ // Note: It is an error to call AddPendingSocket() without having first
+ // called MonitorConnectionResult() and configuring a socket that WILL
+ // have Connect() called on it.
+ void AddPendingSocket(SSLClientSocket* ssl_socket,
+ const base::Closure& callback);
+
+ private:
+ // Processes pending callbacks when a socket completes its SSL handshake --
+ // either successfully or unsuccessfully.
+ void OnSSLHandshakeCompleted();
+
+ // Runs all callbacks stored in |pending_sockets_and_callbacks_|.
+ void RunAllCallbacks(
+ const SSLPendingSocketsAndCallbacks& pending_socket_and_callbacks);
+
+ SSLPendingSocketsAndCallbacks pending_sockets_and_callbacks_;
+ // Note: this field is a vector to allow for future design changes. Currently,
+ // this vector should only ever have one entry.
+ std::vector<SSLClientSocket*> connecting_sockets_;
+
+ base::Closure messenger_finished_callback_;
+
+ base::WeakPtrFactory<SSLConnectJobMessenger> weak_factory_;
+};
+
// SSLConnectJob handles the SSL handshake after setting up the underlying
// connection as specified in the params.
class SSLConnectJob : public ConnectJob {
public:
- SSLConnectJob(
- const std::string& group_name,
- RequestPriority priority,
- const scoped_refptr<SSLSocketParams>& params,
- const base::TimeDelta& timeout_duration,
- TransportClientSocketPool* transport_pool,
- SOCKSClientSocketPool* socks_pool,
- HttpProxyClientSocketPool* http_proxy_pool,
- ClientSocketFactory* client_socket_factory,
- HostResolver* host_resolver,
- const SSLClientSocketContext& context,
- Delegate* delegate,
- NetLog* net_log);
- virtual ~SSLConnectJob();
+ // Callback to allow the SSLConnectJob to obtain an SSLConnectJobMessenger to
+ // coordinate connecting. The SSLConnectJob will supply a unique identifer
+ // (ex: the SSL session cache key), with the expectation that the same
+ // Messenger will be returned for all such ConnectJobs.
+ //
+ // Note: It will only be called for situations where the SSL session cache
+ // does not already have a candidate session to resume.
+ typedef base::Callback<SSLConnectJobMessenger*(const std::string&)>
+ GetMessengerCallback;
+
+ // Note: the SSLConnectJob does not own |messenger| so it must outlive the
+ // job.
+ SSLConnectJob(const std::string& group_name,
+ RequestPriority priority,
+ const scoped_refptr<SSLSocketParams>& params,
+ const base::TimeDelta& timeout_duration,
+ TransportClientSocketPool* transport_pool,
+ SOCKSClientSocketPool* socks_pool,
+ HttpProxyClientSocketPool* http_proxy_pool,
+ ClientSocketFactory* client_socket_factory,
+ HostResolver* host_resolver,
+ const SSLClientSocketContext& context,
+ const GetMessengerCallback& get_messenger_callback,
+ Delegate* delegate,
+ NetLog* net_log);
+ ~SSLConnectJob() override;
// ConnectJob methods.
- virtual LoadState GetLoadState() const OVERRIDE;
+ LoadState GetLoadState() const override;
- virtual void GetAdditionalErrorState(ClientSocketHandle * handle) OVERRIDE;
+ void GetAdditionalErrorState(ClientSocketHandle* handle) override;
private:
enum State {
@@ -126,6 +209,8 @@ class SSLConnectJob : public ConnectJob {
STATE_SOCKS_CONNECT_COMPLETE,
STATE_TUNNEL_CONNECT,
STATE_TUNNEL_CONNECT_COMPLETE,
+ STATE_CREATE_SSL_SOCKET,
+ STATE_CHECK_FOR_RESUME,
STATE_SSL_CONNECT,
STATE_SSL_CONNECT_COMPLETE,
STATE_NONE,
@@ -142,9 +227,14 @@ class SSLConnectJob : public ConnectJob {
int DoSOCKSConnectComplete(int result);
int DoTunnelConnect();
int DoTunnelConnectComplete(int result);
+ int DoCreateSSLSocket();
+ int DoCheckForResume();
int DoSSLConnect();
int DoSSLConnectComplete(int result);
+ // Tells a waiting SSLConnectJob to resume its SSL connection.
+ void ResumeSSLConnection();
+
// Returns the initial state for the state machine based on the
// |connection_type|.
static State GetInitialState(SSLSocketParams::ConnectionType connection_type);
@@ -152,7 +242,7 @@ class SSLConnectJob : public ConnectJob {
// Starts the SSL connection process. Returns OK on success and
// ERR_IO_PENDING if it cannot immediately service the request.
// Otherwise, it returns a net error code.
- virtual int ConnectInternal() OVERRIDE;
+ int ConnectInternal() override;
scoped_refptr<SSLSocketParams> params_;
TransportClientSocketPool* const transport_pool_;
@@ -164,12 +254,17 @@ class SSLConnectJob : public ConnectJob {
const SSLClientSocketContext context_;
State next_state_;
- CompletionCallback callback_;
+ CompletionCallback io_callback_;
scoped_ptr<ClientSocketHandle> transport_socket_handle_;
scoped_ptr<SSLClientSocket> ssl_socket_;
+ SSLConnectJobMessenger* messenger_;
HttpResponseInfo error_response_info_;
+ GetMessengerCallback get_messenger_callback_;
+
+ base::WeakPtrFactory<SSLConnectJob> weak_factory_;
+
DISALLOW_COPY_AND_ASSIGN(SSLConnectJob);
};
@@ -182,85 +277,91 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
// Only the pools that will be used are required. i.e. if you never
// try to create an SSL over SOCKS socket, |socks_pool| may be NULL.
- SSLClientSocketPool(
- int max_sockets,
- int max_sockets_per_group,
- ClientSocketPoolHistograms* histograms,
- HostResolver* host_resolver,
- CertVerifier* cert_verifier,
- ServerBoundCertService* server_bound_cert_service,
- TransportSecurityState* transport_security_state,
- CTVerifier* cert_transparency_verifier,
- const std::string& ssl_session_cache_shard,
- ClientSocketFactory* client_socket_factory,
- TransportClientSocketPool* transport_pool,
- SOCKSClientSocketPool* socks_pool,
- HttpProxyClientSocketPool* http_proxy_pool,
- SSLConfigService* ssl_config_service,
- NetLog* net_log);
-
- virtual ~SSLClientSocketPool();
+ SSLClientSocketPool(int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ HostResolver* host_resolver,
+ CertVerifier* cert_verifier,
+ ChannelIDService* channel_id_service,
+ TransportSecurityState* transport_security_state,
+ CTVerifier* cert_transparency_verifier,
+ const std::string& ssl_session_cache_shard,
+ ClientSocketFactory* client_socket_factory,
+ TransportClientSocketPool* transport_pool,
+ SOCKSClientSocketPool* socks_pool,
+ HttpProxyClientSocketPool* http_proxy_pool,
+ SSLConfigService* ssl_config_service,
+ bool enable_ssl_connect_job_waiting,
+ NetLog* net_log);
+
+ ~SSLClientSocketPool() override;
// ClientSocketPool implementation.
- virtual int RequestSocket(const std::string& group_name,
- const void* connect_params,
- RequestPriority priority,
- ClientSocketHandle* handle,
- const CompletionCallback& callback,
- const BoundNetLog& net_log) OVERRIDE;
+ int RequestSocket(const std::string& group_name,
+ const void* connect_params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) override;
- virtual void RequestSockets(const std::string& group_name,
- const void* params,
- int num_sockets,
- const BoundNetLog& net_log) OVERRIDE;
+ void RequestSockets(const std::string& group_name,
+ const void* params,
+ int num_sockets,
+ const BoundNetLog& net_log) override;
- virtual void CancelRequest(const std::string& group_name,
- ClientSocketHandle* handle) OVERRIDE;
+ void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) override;
- virtual void ReleaseSocket(const std::string& group_name,
- scoped_ptr<StreamSocket> socket,
- int id) OVERRIDE;
+ void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) override;
- virtual void FlushWithError(int error) OVERRIDE;
+ void FlushWithError(int error) override;
- virtual void CloseIdleSockets() OVERRIDE;
+ void CloseIdleSockets() override;
- virtual int IdleSocketCount() const OVERRIDE;
+ int IdleSocketCount() const override;
- virtual int IdleSocketCountInGroup(
- const std::string& group_name) const OVERRIDE;
+ int IdleSocketCountInGroup(const std::string& group_name) const override;
- virtual LoadState GetLoadState(
- const std::string& group_name,
- const ClientSocketHandle* handle) const OVERRIDE;
+ LoadState GetLoadState(const std::string& group_name,
+ const ClientSocketHandle* handle) const override;
- virtual base::DictionaryValue* GetInfoAsValue(
+ base::DictionaryValue* GetInfoAsValue(
const std::string& name,
const std::string& type,
- bool include_nested_pools) const OVERRIDE;
+ bool include_nested_pools) const override;
- virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
+ base::TimeDelta ConnectionTimeout() const override;
- virtual ClientSocketPoolHistograms* histograms() const OVERRIDE;
+ ClientSocketPoolHistograms* histograms() const override;
// LowerLayeredPool implementation.
- virtual bool IsStalled() const OVERRIDE;
+ bool IsStalled() const override;
- virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
+ void AddHigherLayeredPool(HigherLayeredPool* higher_pool) override;
- virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
+ void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) override;
// HigherLayeredPool implementation.
- virtual bool CloseOneIdleConnection() OVERRIDE;
+ bool CloseOneIdleConnection() override;
+
+ // Gets the SSLConnectJobMessenger for the given ssl session |cache_key|. If
+ // none exits, it creates one and stores it in |messenger_map_|.
+ SSLConnectJobMessenger* GetOrCreateSSLConnectJobMessenger(
+ const std::string& cache_key);
+ void DeleteSSLConnectJobMessenger(const std::string& cache_key);
private:
typedef ClientSocketPoolBase<SSLSocketParams> PoolBase;
+ // Maps SSLConnectJob cache keys to SSLConnectJobMessenger objects.
+ typedef std::map<std::string, SSLConnectJobMessenger*> MessengerMap;
// SSLConfigService::Observer implementation.
// When the user changes the SSL config, we flush all idle sockets so they
// won't get re-used.
- virtual void OnSSLConfigChanged() OVERRIDE;
+ void OnSSLConfigChanged() override;
class SSLConnectJobFactory : public PoolBase::ConnectJobFactory {
public:
@@ -271,17 +372,18 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
ClientSocketFactory* client_socket_factory,
HostResolver* host_resolver,
const SSLClientSocketContext& context,
+ const SSLConnectJob::GetMessengerCallback& get_messenger_callback,
NetLog* net_log);
- virtual ~SSLConnectJobFactory() {}
+ ~SSLConnectJobFactory() override;
// ClientSocketPoolBase::ConnectJobFactory methods.
- virtual scoped_ptr<ConnectJob> NewConnectJob(
+ scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
- ConnectJob::Delegate* delegate) const OVERRIDE;
+ ConnectJob::Delegate* delegate) const override;
- virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
+ base::TimeDelta ConnectionTimeout() const override;
private:
TransportClientSocketPool* const transport_pool_;
@@ -291,6 +393,7 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
HostResolver* const host_resolver_;
const SSLClientSocketContext context_;
base::TimeDelta timeout_;
+ SSLConnectJob::GetMessengerCallback get_messenger_callback_;
NetLog* net_log_;
DISALLOW_COPY_AND_ASSIGN(SSLConnectJobFactory);
@@ -301,6 +404,8 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool
HttpProxyClientSocketPool* const http_proxy_pool_;
PoolBase base_;
const scoped_refptr<SSLConfigService> ssl_config_service_;
+ MessengerMap messenger_map_;
+ bool enable_ssl_connect_job_waiting_;
DISALLOW_COPY_AND_ASSIGN(SSLClientSocketPool);
};
diff --git a/chromium/net/socket/ssl_client_socket_pool_unittest.cc b/chromium/net/socket/ssl_client_socket_pool_unittest.cc
index 6ae07ed0bda..202cd8809e5 100644
--- a/chromium/net/socket/ssl_client_socket_pool_unittest.cc
+++ b/chromium/net/socket/ssl_client_socket_pool_unittest.cc
@@ -6,6 +6,7 @@
#include "base/callback.h"
#include "base/compiler_specific.h"
+#include "base/run_loop.h"
#include "base/strings/string_util.h"
#include "base/strings/utf_string_conversions.h"
#include "base/time/time.h"
@@ -21,6 +22,7 @@
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_server_properties_impl.h"
+#include "net/http/transport_security_state.h"
#include "net/proxy/proxy_service.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/client_socket_pool_histograms.h"
@@ -78,26 +80,31 @@ class SSLClientSocketPoolTest
public ::testing::WithParamInterface<NextProto> {
protected:
SSLClientSocketPoolTest()
- : proxy_service_(ProxyService::CreateDirect()),
+ : transport_security_state_(new TransportSecurityState),
+ proxy_service_(ProxyService::CreateDirect()),
ssl_config_service_(new SSLConfigServiceDefaults),
http_auth_handler_factory_(
HttpAuthHandlerFactory::CreateDefault(&host_resolver_)),
session_(CreateNetworkSession()),
direct_transport_socket_params_(
- new TransportSocketParams(HostPortPair("host", 443),
- false,
- false,
- OnHostResolutionCallback())),
+ new TransportSocketParams(
+ HostPortPair("host", 443),
+ false,
+ false,
+ OnHostResolutionCallback(),
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)),
transport_histograms_("MockTCP"),
transport_socket_pool_(kMaxSockets,
kMaxSocketsPerGroup,
&transport_histograms_,
&socket_factory_),
proxy_transport_socket_params_(
- new TransportSocketParams(HostPortPair("proxy", 443),
- false,
- false,
- OnHostResolutionCallback())),
+ new TransportSocketParams(
+ HostPortPair("proxy", 443),
+ false,
+ false,
+ OnHostResolutionCallback(),
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)),
socks_socket_params_(
new SOCKSSocketParams(proxy_transport_socket_params_,
true,
@@ -116,7 +123,8 @@ class SSLClientSocketPoolTest
session_->http_auth_cache(),
session_->http_auth_handler_factory(),
session_->spdy_session_pool(),
- true)),
+ true,
+ NULL)),
http_proxy_histograms_("MockHttpProxy"),
http_proxy_socket_pool_(kMaxSockets,
kMaxSocketsPerGroup,
@@ -124,7 +132,9 @@ class SSLClientSocketPoolTest
&host_resolver_,
&transport_socket_pool_,
NULL,
- NULL) {
+ NULL,
+ NULL),
+ enable_ssl_connect_job_waiting_(false) {
scoped_refptr<SSLConfigService> ssl_config_service(
new SSLConfigServiceDefaults);
ssl_config_service->GetSSLConfig(&ssl_config_);
@@ -138,7 +148,7 @@ class SSLClientSocketPoolTest
ssl_histograms_.get(),
NULL /* host_resolver */,
NULL /* cert_verifier */,
- NULL /* server_bound_cert_service */,
+ NULL /* channel_id_service */,
NULL /* transport_security_state */,
NULL /* cert_transparency_verifier */,
std::string() /* ssl_session_cache_shard */,
@@ -147,6 +157,7 @@ class SSLClientSocketPoolTest
socks_pool ? &socks_socket_pool_ : NULL,
http_proxy_pool ? &http_proxy_socket_pool_ : NULL,
NULL,
+ enable_ssl_connect_job_waiting_,
NULL));
}
@@ -221,6 +232,8 @@ class SSLClientSocketPoolTest
SSLConfig ssl_config_;
scoped_ptr<ClientSocketPoolHistograms> ssl_histograms_;
scoped_ptr<SSLClientSocketPool> pool_;
+
+ bool enable_ssl_connect_job_waiting_;
};
INSTANTIATE_TEST_CASE_P(
@@ -229,6 +242,462 @@ INSTANTIATE_TEST_CASE_P(
testing::Values(kProtoDeprecatedSPDY2,
kProtoSPDY3, kProtoSPDY31, kProtoSPDY4));
+// Tests that the final socket will connect even if all sockets
+// prior to it fail.
+//
+// All sockets should wait for the first socket to attempt to
+// connect. Once it fails to connect, all other sockets should
+// attempt to connect. All should fail, except the final socket.
+TEST_P(SSLClientSocketPoolTest, AllSocketsFailButLast) {
+ // Although we request four sockets, the first three socket connect
+ // failures cause the socket pool to create three more sockets because
+ // there are pending requests.
+ StaticSocketDataProvider data1;
+ StaticSocketDataProvider data2;
+ StaticSocketDataProvider data3;
+ StaticSocketDataProvider data4;
+ StaticSocketDataProvider data5;
+ StaticSocketDataProvider data6;
+ StaticSocketDataProvider data7;
+ socket_factory_.AddSocketDataProvider(&data1);
+ socket_factory_.AddSocketDataProvider(&data2);
+ socket_factory_.AddSocketDataProvider(&data3);
+ socket_factory_.AddSocketDataProvider(&data4);
+ socket_factory_.AddSocketDataProvider(&data5);
+ socket_factory_.AddSocketDataProvider(&data6);
+ socket_factory_.AddSocketDataProvider(&data7);
+ SSLSocketDataProvider ssl(ASYNC, ERR_SSL_PROTOCOL_ERROR);
+ ssl.is_in_session_cache = false;
+ SSLSocketDataProvider ssl2(ASYNC, ERR_SSL_PROTOCOL_ERROR);
+ ssl2.is_in_session_cache = false;
+ SSLSocketDataProvider ssl3(ASYNC, ERR_SSL_PROTOCOL_ERROR);
+ ssl3.is_in_session_cache = false;
+ SSLSocketDataProvider ssl4(ASYNC, OK);
+ ssl4.is_in_session_cache = false;
+ SSLSocketDataProvider ssl5(ASYNC, OK);
+ ssl5.is_in_session_cache = false;
+ SSLSocketDataProvider ssl6(ASYNC, OK);
+ ssl6.is_in_session_cache = false;
+ SSLSocketDataProvider ssl7(ASYNC, OK);
+ ssl7.is_in_session_cache = false;
+
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+ socket_factory_.AddSSLSocketDataProvider(&ssl2);
+ socket_factory_.AddSSLSocketDataProvider(&ssl3);
+ socket_factory_.AddSSLSocketDataProvider(&ssl4);
+ socket_factory_.AddSSLSocketDataProvider(&ssl5);
+ socket_factory_.AddSSLSocketDataProvider(&ssl6);
+ socket_factory_.AddSSLSocketDataProvider(&ssl7);
+
+ enable_ssl_connect_job_waiting_ = true;
+ CreatePool(true, false, false);
+
+ scoped_refptr<SSLSocketParams> params1 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ scoped_refptr<SSLSocketParams> params2 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ scoped_refptr<SSLSocketParams> params3 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ scoped_refptr<SSLSocketParams> params4 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ ClientSocketHandle handle1;
+ ClientSocketHandle handle2;
+ ClientSocketHandle handle3;
+ ClientSocketHandle handle4;
+ TestCompletionCallback callback1;
+ TestCompletionCallback callback2;
+ TestCompletionCallback callback3;
+ TestCompletionCallback callback4;
+
+ handle1.Init(
+ "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog());
+ handle2.Init(
+ "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog());
+ handle3.Init(
+ "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog());
+ handle4.Init(
+ "b", params4, MEDIUM, callback4.callback(), pool_.get(), BoundNetLog());
+
+ base::RunLoop().RunUntilIdle();
+
+ // Only the last socket should have connected.
+ EXPECT_FALSE(handle1.socket());
+ EXPECT_FALSE(handle2.socket());
+ EXPECT_FALSE(handle3.socket());
+ EXPECT_TRUE(handle4.socket()->IsConnected());
+}
+
+// Tests that sockets will still connect in parallel if the
+// EnableSSLConnectJobWaiting flag is not enabled.
+TEST_P(SSLClientSocketPoolTest, SocketsConnectWithoutFlag) {
+ StaticSocketDataProvider data1;
+ StaticSocketDataProvider data2;
+ StaticSocketDataProvider data3;
+ socket_factory_.AddSocketDataProvider(&data1);
+ socket_factory_.AddSocketDataProvider(&data2);
+ socket_factory_.AddSocketDataProvider(&data3);
+
+ SSLSocketDataProvider ssl(ASYNC, OK);
+ ssl.is_in_session_cache = false;
+ ssl.should_pause_on_connect = true;
+ SSLSocketDataProvider ssl2(ASYNC, OK);
+ ssl2.is_in_session_cache = false;
+ ssl2.should_pause_on_connect = true;
+ SSLSocketDataProvider ssl3(ASYNC, OK);
+ ssl3.is_in_session_cache = false;
+ ssl3.should_pause_on_connect = true;
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+ socket_factory_.AddSSLSocketDataProvider(&ssl2);
+ socket_factory_.AddSSLSocketDataProvider(&ssl3);
+
+ CreatePool(true, false, false);
+
+ scoped_refptr<SSLSocketParams> params1 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ scoped_refptr<SSLSocketParams> params2 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ scoped_refptr<SSLSocketParams> params3 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ ClientSocketHandle handle1;
+ ClientSocketHandle handle2;
+ ClientSocketHandle handle3;
+ TestCompletionCallback callback1;
+ TestCompletionCallback callback2;
+ TestCompletionCallback callback3;
+
+ handle1.Init(
+ "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog());
+ handle2.Init(
+ "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog());
+ handle3.Init(
+ "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog());
+
+ base::RunLoop().RunUntilIdle();
+
+ std::vector<MockSSLClientSocket*> sockets =
+ socket_factory_.ssl_client_sockets();
+
+ // All sockets should have started their connections.
+ for (std::vector<MockSSLClientSocket*>::iterator it = sockets.begin();
+ it != sockets.end();
+ ++it) {
+ EXPECT_TRUE((*it)->reached_connect());
+ }
+
+ // Resume connecting all of the sockets.
+ for (std::vector<MockSSLClientSocket*>::iterator it = sockets.begin();
+ it != sockets.end();
+ ++it) {
+ (*it)->RestartPausedConnect();
+ }
+
+ callback1.WaitForResult();
+ callback2.WaitForResult();
+ callback3.WaitForResult();
+
+ EXPECT_TRUE(handle1.socket()->IsConnected());
+ EXPECT_TRUE(handle2.socket()->IsConnected());
+ EXPECT_TRUE(handle3.socket()->IsConnected());
+}
+
+// Tests that the pool deleting an SSLConnectJob will not cause a crash,
+// or prevent pending sockets from connecting.
+TEST_P(SSLClientSocketPoolTest, DeletedSSLConnectJob) {
+ StaticSocketDataProvider data1;
+ StaticSocketDataProvider data2;
+ StaticSocketDataProvider data3;
+ socket_factory_.AddSocketDataProvider(&data1);
+ socket_factory_.AddSocketDataProvider(&data2);
+ socket_factory_.AddSocketDataProvider(&data3);
+
+ SSLSocketDataProvider ssl(ASYNC, OK);
+ ssl.is_in_session_cache = false;
+ ssl.should_pause_on_connect = true;
+ SSLSocketDataProvider ssl2(ASYNC, OK);
+ ssl2.is_in_session_cache = false;
+ SSLSocketDataProvider ssl3(ASYNC, OK);
+ ssl3.is_in_session_cache = false;
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+ socket_factory_.AddSSLSocketDataProvider(&ssl2);
+ socket_factory_.AddSSLSocketDataProvider(&ssl3);
+
+ enable_ssl_connect_job_waiting_ = true;
+ CreatePool(true, false, false);
+
+ scoped_refptr<SSLSocketParams> params1 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ scoped_refptr<SSLSocketParams> params2 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ scoped_refptr<SSLSocketParams> params3 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ ClientSocketHandle handle1;
+ ClientSocketHandle handle2;
+ ClientSocketHandle handle3;
+ TestCompletionCallback callback1;
+ TestCompletionCallback callback2;
+ TestCompletionCallback callback3;
+
+ handle1.Init(
+ "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog());
+ handle2.Init(
+ "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog());
+ handle3.Init(
+ "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog());
+
+ // Allow the connections to proceed until the first socket has started
+ // connecting.
+ base::RunLoop().RunUntilIdle();
+
+ std::vector<MockSSLClientSocket*> sockets =
+ socket_factory_.ssl_client_sockets();
+
+ pool_->CancelRequest("b", &handle2);
+
+ sockets[0]->RestartPausedConnect();
+
+ callback1.WaitForResult();
+ callback3.WaitForResult();
+
+ EXPECT_TRUE(handle1.socket()->IsConnected());
+ EXPECT_FALSE(handle2.socket());
+ EXPECT_TRUE(handle3.socket()->IsConnected());
+}
+
+// Tests that all pending sockets still connect when the pool deletes a pending
+// SSLConnectJob which immediately followed a failed leading connection.
+TEST_P(SSLClientSocketPoolTest, DeletedSocketAfterFail) {
+ StaticSocketDataProvider data1;
+ StaticSocketDataProvider data2;
+ StaticSocketDataProvider data3;
+ StaticSocketDataProvider data4;
+ socket_factory_.AddSocketDataProvider(&data1);
+ socket_factory_.AddSocketDataProvider(&data2);
+ socket_factory_.AddSocketDataProvider(&data3);
+ socket_factory_.AddSocketDataProvider(&data4);
+
+ SSLSocketDataProvider ssl(ASYNC, ERR_SSL_PROTOCOL_ERROR);
+ ssl.is_in_session_cache = false;
+ ssl.should_pause_on_connect = true;
+ SSLSocketDataProvider ssl2(ASYNC, OK);
+ ssl2.is_in_session_cache = false;
+ SSLSocketDataProvider ssl3(ASYNC, OK);
+ ssl3.is_in_session_cache = false;
+ SSLSocketDataProvider ssl4(ASYNC, OK);
+ ssl4.is_in_session_cache = false;
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+ socket_factory_.AddSSLSocketDataProvider(&ssl2);
+ socket_factory_.AddSSLSocketDataProvider(&ssl3);
+ socket_factory_.AddSSLSocketDataProvider(&ssl4);
+
+ enable_ssl_connect_job_waiting_ = true;
+ CreatePool(true, false, false);
+
+ scoped_refptr<SSLSocketParams> params1 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ scoped_refptr<SSLSocketParams> params2 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ scoped_refptr<SSLSocketParams> params3 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ ClientSocketHandle handle1;
+ ClientSocketHandle handle2;
+ ClientSocketHandle handle3;
+ TestCompletionCallback callback1;
+ TestCompletionCallback callback2;
+ TestCompletionCallback callback3;
+
+ handle1.Init(
+ "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog());
+ handle2.Init(
+ "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog());
+ handle3.Init(
+ "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog());
+
+ // Allow the connections to proceed until the first socket has started
+ // connecting.
+ base::RunLoop().RunUntilIdle();
+
+ std::vector<MockSSLClientSocket*> sockets =
+ socket_factory_.ssl_client_sockets();
+
+ EXPECT_EQ(3u, sockets.size());
+ EXPECT_TRUE(sockets[0]->reached_connect());
+ EXPECT_FALSE(handle1.socket());
+
+ pool_->CancelRequest("b", &handle2);
+
+ sockets[0]->RestartPausedConnect();
+
+ callback1.WaitForResult();
+ callback3.WaitForResult();
+
+ EXPECT_FALSE(handle1.socket());
+ EXPECT_FALSE(handle2.socket());
+ EXPECT_TRUE(handle3.socket()->IsConnected());
+}
+
+// Make sure that sockets still connect after the leader socket's
+// connection fails.
+TEST_P(SSLClientSocketPoolTest, SimultaneousConnectJobsFail) {
+ StaticSocketDataProvider data1;
+ StaticSocketDataProvider data2;
+ StaticSocketDataProvider data3;
+ StaticSocketDataProvider data4;
+ StaticSocketDataProvider data5;
+ socket_factory_.AddSocketDataProvider(&data1);
+ socket_factory_.AddSocketDataProvider(&data2);
+ socket_factory_.AddSocketDataProvider(&data3);
+ socket_factory_.AddSocketDataProvider(&data4);
+ socket_factory_.AddSocketDataProvider(&data5);
+ SSLSocketDataProvider ssl(ASYNC, ERR_SSL_PROTOCOL_ERROR);
+ ssl.is_in_session_cache = false;
+ ssl.should_pause_on_connect = true;
+ SSLSocketDataProvider ssl2(ASYNC, OK);
+ ssl2.is_in_session_cache = false;
+ ssl2.should_pause_on_connect = true;
+ SSLSocketDataProvider ssl3(ASYNC, OK);
+ ssl3.is_in_session_cache = false;
+ SSLSocketDataProvider ssl4(ASYNC, OK);
+ ssl4.is_in_session_cache = false;
+ SSLSocketDataProvider ssl5(ASYNC, OK);
+ ssl5.is_in_session_cache = false;
+
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+ socket_factory_.AddSSLSocketDataProvider(&ssl2);
+ socket_factory_.AddSSLSocketDataProvider(&ssl3);
+ socket_factory_.AddSSLSocketDataProvider(&ssl4);
+ socket_factory_.AddSSLSocketDataProvider(&ssl5);
+
+ enable_ssl_connect_job_waiting_ = true;
+ CreatePool(true, false, false);
+ scoped_refptr<SSLSocketParams> params1 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ scoped_refptr<SSLSocketParams> params2 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ scoped_refptr<SSLSocketParams> params3 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ scoped_refptr<SSLSocketParams> params4 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ ClientSocketHandle handle1;
+ ClientSocketHandle handle2;
+ ClientSocketHandle handle3;
+ ClientSocketHandle handle4;
+ TestCompletionCallback callback1;
+ TestCompletionCallback callback2;
+ TestCompletionCallback callback3;
+ TestCompletionCallback callback4;
+ handle1.Init(
+ "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog());
+ handle2.Init(
+ "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog());
+ handle3.Init(
+ "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog());
+ handle4.Init(
+ "b", params4, MEDIUM, callback4.callback(), pool_.get(), BoundNetLog());
+
+ base::RunLoop().RunUntilIdle();
+
+ std::vector<MockSSLClientSocket*> sockets =
+ socket_factory_.ssl_client_sockets();
+
+ std::vector<MockSSLClientSocket*>::const_iterator it = sockets.begin();
+
+ // The first socket should have had Connect called on it.
+ EXPECT_TRUE((*it)->reached_connect());
+ ++it;
+
+ // No other socket should have reached connect yet.
+ for (; it != sockets.end(); ++it)
+ EXPECT_FALSE((*it)->reached_connect());
+
+ // Allow the first socket to resume it's connection process.
+ sockets[0]->RestartPausedConnect();
+
+ base::RunLoop().RunUntilIdle();
+
+ // The second socket should have reached connect.
+ EXPECT_TRUE(sockets[1]->reached_connect());
+
+ // Allow the second socket to continue its connection.
+ sockets[1]->RestartPausedConnect();
+
+ base::RunLoop().RunUntilIdle();
+
+ EXPECT_FALSE(handle1.socket());
+ EXPECT_TRUE(handle2.socket()->IsConnected());
+ EXPECT_TRUE(handle3.socket()->IsConnected());
+ EXPECT_TRUE(handle4.socket()->IsConnected());
+}
+
+// Make sure that no sockets connect before the "leader" socket,
+// given that the leader has a successful connection.
+TEST_P(SSLClientSocketPoolTest, SimultaneousConnectJobsSuccess) {
+ StaticSocketDataProvider data1;
+ StaticSocketDataProvider data2;
+ StaticSocketDataProvider data3;
+ socket_factory_.AddSocketDataProvider(&data1);
+ socket_factory_.AddSocketDataProvider(&data2);
+ socket_factory_.AddSocketDataProvider(&data3);
+
+ SSLSocketDataProvider ssl(ASYNC, OK);
+ ssl.is_in_session_cache = false;
+ ssl.should_pause_on_connect = true;
+ SSLSocketDataProvider ssl2(ASYNC, OK);
+ ssl2.is_in_session_cache = false;
+ SSLSocketDataProvider ssl3(ASYNC, OK);
+ ssl3.is_in_session_cache = false;
+ socket_factory_.AddSSLSocketDataProvider(&ssl);
+ socket_factory_.AddSSLSocketDataProvider(&ssl2);
+ socket_factory_.AddSSLSocketDataProvider(&ssl3);
+
+ enable_ssl_connect_job_waiting_ = true;
+ CreatePool(true, false, false);
+
+ scoped_refptr<SSLSocketParams> params1 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ scoped_refptr<SSLSocketParams> params2 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ scoped_refptr<SSLSocketParams> params3 =
+ SSLParams(ProxyServer::SCHEME_DIRECT, false);
+ ClientSocketHandle handle1;
+ ClientSocketHandle handle2;
+ ClientSocketHandle handle3;
+ TestCompletionCallback callback1;
+ TestCompletionCallback callback2;
+ TestCompletionCallback callback3;
+
+ handle1.Init(
+ "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog());
+ handle2.Init(
+ "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog());
+ handle3.Init(
+ "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog());
+
+ // Allow the connections to proceed until the first socket has finished
+ // connecting.
+ base::RunLoop().RunUntilIdle();
+
+ std::vector<MockSSLClientSocket*> sockets =
+ socket_factory_.ssl_client_sockets();
+
+ std::vector<MockSSLClientSocket*>::const_iterator it = sockets.begin();
+ // The first socket should have reached connect.
+ EXPECT_TRUE((*it)->reached_connect());
+ ++it;
+ // No other socket should have reached connect yet.
+ for (; it != sockets.end(); ++it)
+ EXPECT_FALSE((*it)->reached_connect());
+
+ sockets[0]->RestartPausedConnect();
+
+ callback1.WaitForResult();
+ callback2.WaitForResult();
+ callback3.WaitForResult();
+
+ EXPECT_TRUE(handle1.socket()->IsConnected());
+ EXPECT_TRUE(handle2.socket()->IsConnected());
+ EXPECT_TRUE(handle3.socket()->IsConnected());
+}
+
TEST_P(SSLClientSocketPoolTest, TCPFail) {
StaticSocketDataProvider data;
data.set_connect_data(MockConnect(SYNCHRONOUS, ERR_CONNECTION_FAILED));
@@ -466,8 +935,7 @@ TEST_P(SSLClientSocketPoolTest, DirectGotSPDY) {
SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket());
EXPECT_TRUE(ssl_socket->WasNpnNegotiated());
std::string proto;
- std::string server_protos;
- ssl_socket->GetNextProto(&proto, &server_protos);
+ ssl_socket->GetNextProto(&proto);
EXPECT_EQ(GetParam(), SSLClientSocket::NextProtoFromString(proto));
}
@@ -498,8 +966,7 @@ TEST_P(SSLClientSocketPoolTest, DirectGotBonusSPDY) {
SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket());
EXPECT_TRUE(ssl_socket->WasNpnNegotiated());
std::string proto;
- std::string server_protos;
- ssl_socket->GetNextProto(&proto, &server_protos);
+ ssl_socket->GetNextProto(&proto);
EXPECT_EQ(GetParam(), SSLClientSocket::NextProtoFromString(proto));
}
@@ -798,8 +1265,7 @@ TEST_P(SSLClientSocketPoolTest, NeedProxyAuth) {
EXPECT_FALSE(tunnel_handle->socket()->IsConnected());
}
-// TODO(rch): re-enable this.
-TEST_P(SSLClientSocketPoolTest, DISABLED_IPPooling) {
+TEST_P(SSLClientSocketPoolTest, IPPooling) {
const int kTestPort = 80;
struct TestHosts {
std::string name;
@@ -813,7 +1279,7 @@ TEST_P(SSLClientSocketPoolTest, DISABLED_IPPooling) {
};
host_resolver_.set_synchronous_mode(true);
- for (size_t i = 0; i < ARRAYSIZE_UNSAFE(test_hosts); i++) {
+ for (size_t i = 0; i < arraysize(test_hosts); i++) {
host_resolver_.rules()->AddIPLiteralRule(
test_hosts[i].name, test_hosts[i].iplist, std::string());
@@ -873,7 +1339,7 @@ void SSLClientSocketPoolTest::TestIPPoolingDisabled(
TestCompletionCallback callback;
int rv;
- for (size_t i = 0; i < ARRAYSIZE_UNSAFE(test_hosts); i++) {
+ for (size_t i = 0; i < arraysize(test_hosts); i++) {
host_resolver_.rules()->AddIPLiteralRule(
test_hosts[i].name, test_hosts[i].iplist, std::string());
diff --git a/chromium/net/socket/ssl_client_socket_unittest.cc b/chromium/net/socket/ssl_client_socket_unittest.cc
index 51c67565dad..16e03f7eb8b 100644
--- a/chromium/net/socket/ssl_client_socket_unittest.cc
+++ b/chromium/net/socket/ssl_client_socket_unittest.cc
@@ -15,6 +15,8 @@
#include "net/base/net_log_unittest.h"
#include "net/base/test_completion_callback.h"
#include "net/base/test_data_directory.h"
+#include "net/cert/asn1_util.h"
+#include "net/cert/ct_verifier.h"
#include "net/cert/mock_cert_verifier.h"
#include "net/cert/test_root_certs.h"
#include "net/dns/host_resolver.h"
@@ -23,16 +25,22 @@
#include "net/socket/client_socket_handle.h"
#include "net/socket/socket_test_util.h"
#include "net/socket/tcp_client_socket.h"
-#include "net/ssl/default_server_bound_cert_store.h"
+#include "net/ssl/channel_id_service.h"
+#include "net/ssl/default_channel_id_store.h"
#include "net/ssl/ssl_cert_request_info.h"
#include "net/ssl/ssl_config_service.h"
#include "net/test/cert_test_util.h"
#include "net/test/spawned_test_server/spawned_test_server.h"
+#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "testing/platform_test.h"
//-----------------------------------------------------------------------------
+using testing::_;
+using testing::Return;
+using testing::Truly;
+
namespace net {
namespace {
@@ -49,65 +57,57 @@ class WrappedStreamSocket : public StreamSocket {
public:
explicit WrappedStreamSocket(scoped_ptr<StreamSocket> transport)
: transport_(transport.Pass()) {}
- virtual ~WrappedStreamSocket() {}
+ ~WrappedStreamSocket() override {}
// StreamSocket implementation:
- virtual int Connect(const CompletionCallback& callback) OVERRIDE {
+ int Connect(const CompletionCallback& callback) override {
return transport_->Connect(callback);
}
- virtual void Disconnect() OVERRIDE { transport_->Disconnect(); }
- virtual bool IsConnected() const OVERRIDE {
- return transport_->IsConnected();
- }
- virtual bool IsConnectedAndIdle() const OVERRIDE {
+ void Disconnect() override { transport_->Disconnect(); }
+ bool IsConnected() const override { return transport_->IsConnected(); }
+ bool IsConnectedAndIdle() const override {
return transport_->IsConnectedAndIdle();
}
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
+ int GetPeerAddress(IPEndPoint* address) const override {
return transport_->GetPeerAddress(address);
}
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
+ int GetLocalAddress(IPEndPoint* address) const override {
return transport_->GetLocalAddress(address);
}
- virtual const BoundNetLog& NetLog() const OVERRIDE {
- return transport_->NetLog();
- }
- virtual void SetSubresourceSpeculation() OVERRIDE {
+ const BoundNetLog& NetLog() const override { return transport_->NetLog(); }
+ void SetSubresourceSpeculation() override {
transport_->SetSubresourceSpeculation();
}
- virtual void SetOmniboxSpeculation() OVERRIDE {
- transport_->SetOmniboxSpeculation();
- }
- virtual bool WasEverUsed() const OVERRIDE {
- return transport_->WasEverUsed();
- }
- virtual bool UsingTCPFastOpen() const OVERRIDE {
+ void SetOmniboxSpeculation() override { transport_->SetOmniboxSpeculation(); }
+ bool WasEverUsed() const override { return transport_->WasEverUsed(); }
+ bool UsingTCPFastOpen() const override {
return transport_->UsingTCPFastOpen();
}
- virtual bool WasNpnNegotiated() const OVERRIDE {
+ bool WasNpnNegotiated() const override {
return transport_->WasNpnNegotiated();
}
- virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
+ NextProto GetNegotiatedProtocol() const override {
return transport_->GetNegotiatedProtocol();
}
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
+ bool GetSSLInfo(SSLInfo* ssl_info) override {
return transport_->GetSSLInfo(ssl_info);
}
// Socket implementation:
- virtual int Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE {
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override {
return transport_->Read(buf, buf_len, callback);
}
- virtual int Write(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE {
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override {
return transport_->Write(buf, buf_len, callback);
}
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE {
+ int SetReceiveBufferSize(int32 size) override {
return transport_->SetReceiveBufferSize(size);
}
- virtual int SetSendBufferSize(int32 size) OVERRIDE {
+ int SetSendBufferSize(int32 size) override {
return transport_->SetSendBufferSize(size);
}
@@ -124,12 +124,12 @@ class WrappedStreamSocket : public StreamSocket {
class ReadBufferingStreamSocket : public WrappedStreamSocket {
public:
explicit ReadBufferingStreamSocket(scoped_ptr<StreamSocket> transport);
- virtual ~ReadBufferingStreamSocket() {}
+ ~ReadBufferingStreamSocket() override {}
// Socket implementation:
- virtual int Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
// Sets the internal buffer to |size|. This must not be greater than
// the largest value supplied to Read() - that is, it does not handle
@@ -254,21 +254,21 @@ void ReadBufferingStreamSocket::OnReadCompleted(int result) {
class SynchronousErrorStreamSocket : public WrappedStreamSocket {
public:
explicit SynchronousErrorStreamSocket(scoped_ptr<StreamSocket> transport);
- virtual ~SynchronousErrorStreamSocket() {}
+ ~SynchronousErrorStreamSocket() override {}
// Socket implementation:
- virtual int Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int Write(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
// Sets the next Read() call and all future calls to return |error|.
// If there is already a pending asynchronous read, the configured error
// will not be returned until that asynchronous read has completed and Read()
// is called again.
- void SetNextReadError(Error error) {
+ void SetNextReadError(int error) {
DCHECK_GE(0, error);
have_read_error_ = true;
pending_read_error_ = error;
@@ -278,7 +278,7 @@ class SynchronousErrorStreamSocket : public WrappedStreamSocket {
// If there is already a pending asynchronous write, the configured error
// will not be returned until that asynchronous write has completed and
// Write() is called again.
- void SetNextWriteError(Error error) {
+ void SetNextWriteError(int error) {
DCHECK_GE(0, error);
have_write_error_ = true;
pending_write_error_ = error;
@@ -325,15 +325,15 @@ int SynchronousErrorStreamSocket::Write(IOBuffer* buf,
class FakeBlockingStreamSocket : public WrappedStreamSocket {
public:
explicit FakeBlockingStreamSocket(scoped_ptr<StreamSocket> transport);
- virtual ~FakeBlockingStreamSocket() {}
+ ~FakeBlockingStreamSocket() override {}
// Socket implementation:
- virtual int Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int Write(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE;
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
// Blocks read results on the socket. Reads will not complete until
// UnblockReadResult() has been called and a result is ready from the
@@ -357,6 +357,9 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket {
// Waits for the blocked Write() call to be scheduled.
void WaitForWrite();
+ // Returns the wrapped stream socket.
+ StreamSocket* transport() { return transport_.get(); }
+
private:
// Handles completion from the underlying transport read.
void OnReadCompleted(int result);
@@ -429,7 +432,7 @@ int FakeBlockingStreamSocket::Write(IOBuffer* buf,
return transport_->Write(buf, len, callback);
// Schedule the write, but do nothing.
- DCHECK(!pending_write_buf_);
+ DCHECK(!pending_write_buf_.get());
DCHECK_EQ(-1, pending_write_len_);
DCHECK(pending_write_callback_.is_null());
DCHECK(!callback.is_null());
@@ -485,11 +488,11 @@ void FakeBlockingStreamSocket::UnblockWrite() {
// Do nothing if UnblockWrite() was called after BlockWrite(),
// without a Write() in between.
- if (!pending_write_buf_)
+ if (!pending_write_buf_.get())
return;
- int rv = transport_->Write(pending_write_buf_, pending_write_len_,
- pending_write_callback_);
+ int rv = transport_->Write(
+ pending_write_buf_.get(), pending_write_len_, pending_write_callback_);
pending_write_buf_ = NULL;
pending_write_len_ = -1;
if (rv == ERR_IO_PENDING) {
@@ -503,12 +506,12 @@ void FakeBlockingStreamSocket::WaitForWrite() {
DCHECK(should_block_write_);
DCHECK(!write_loop_);
- if (pending_write_buf_)
+ if (pending_write_buf_.get())
return;
write_loop_.reset(new base::RunLoop);
write_loop_->Run();
write_loop_.reset();
- DCHECK(pending_write_buf_);
+ DCHECK(pending_write_buf_.get());
}
void FakeBlockingStreamSocket::OnReadCompleted(int result) {
@@ -538,18 +541,18 @@ class CountingStreamSocket : public WrappedStreamSocket {
: WrappedStreamSocket(transport.Pass()),
read_count_(0),
write_count_(0) {}
- virtual ~CountingStreamSocket() {}
+ ~CountingStreamSocket() override {}
// Socket implementation:
- virtual int Read(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE {
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override {
read_count_++;
return transport_->Read(buf, buf_len, callback);
}
- virtual int Write(IOBuffer* buf,
- int buf_len,
- const CompletionCallback& callback) OVERRIDE {
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override {
write_count_++;
return transport_->Write(buf, buf_len, callback);
}
@@ -570,7 +573,7 @@ class DeleteSocketCallback : public TestCompletionCallbackBase {
: socket_(socket),
callback_(base::Bind(&DeleteSocketCallback::OnComplete,
base::Unretained(this))) {}
- virtual ~DeleteSocketCallback() {}
+ ~DeleteSocketCallback() override {}
const CompletionCallback& callback() const { return callback_; }
@@ -591,33 +594,72 @@ class DeleteSocketCallback : public TestCompletionCallbackBase {
DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback);
};
-// A ServerBoundCertStore that always returns an error when asked for a
-// certificate.
-class FailingServerBoundCertStore : public ServerBoundCertStore {
- virtual int GetServerBoundCert(const std::string& server_identifier,
- base::Time* expiration_time,
- std::string* private_key_result,
- std::string* cert_result,
- const GetCertCallback& callback) OVERRIDE {
+// A ChannelIDStore that always returns an error when asked for a
+// channel id.
+class FailingChannelIDStore : public ChannelIDStore {
+ int GetChannelID(const std::string& server_identifier,
+ base::Time* expiration_time,
+ std::string* private_key_result,
+ std::string* cert_result,
+ const GetChannelIDCallback& callback) override {
return ERR_UNEXPECTED;
}
- virtual void SetServerBoundCert(const std::string& server_identifier,
- base::Time creation_time,
- base::Time expiration_time,
- const std::string& private_key,
- const std::string& cert) OVERRIDE {}
- virtual void DeleteServerBoundCert(const std::string& server_identifier,
- const base::Closure& completion_callback)
- OVERRIDE {}
- virtual void DeleteAllCreatedBetween(base::Time delete_begin,
- base::Time delete_end,
- const base::Closure& completion_callback)
- OVERRIDE {}
- virtual void DeleteAll(const base::Closure& completion_callback) OVERRIDE {}
- virtual void GetAllServerBoundCerts(const GetCertListCallback& callback)
- OVERRIDE {}
- virtual int GetCertCount() OVERRIDE { return 0; }
- virtual void SetForceKeepSessionState() OVERRIDE {}
+ void SetChannelID(const std::string& server_identifier,
+ base::Time creation_time,
+ base::Time expiration_time,
+ const std::string& private_key,
+ const std::string& cert) override {}
+ void DeleteChannelID(const std::string& server_identifier,
+ const base::Closure& completion_callback) override {}
+ void DeleteAllCreatedBetween(
+ base::Time delete_begin,
+ base::Time delete_end,
+ const base::Closure& completion_callback) override {}
+ void DeleteAll(const base::Closure& completion_callback) override {}
+ void GetAllChannelIDs(const GetChannelIDListCallback& callback) override {}
+ int GetChannelIDCount() override { return 0; }
+ void SetForceKeepSessionState() override {}
+};
+
+// A ChannelIDStore that asynchronously returns an error when asked for a
+// channel id.
+class AsyncFailingChannelIDStore : public ChannelIDStore {
+ int GetChannelID(const std::string& server_identifier,
+ base::Time* expiration_time,
+ std::string* private_key_result,
+ std::string* cert_result,
+ const GetChannelIDCallback& callback) override {
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE, base::Bind(callback, ERR_UNEXPECTED,
+ server_identifier, base::Time(), "", ""));
+ return ERR_IO_PENDING;
+ }
+ void SetChannelID(const std::string& server_identifier,
+ base::Time creation_time,
+ base::Time expiration_time,
+ const std::string& private_key,
+ const std::string& cert) override {}
+ void DeleteChannelID(const std::string& server_identifier,
+ const base::Closure& completion_callback) override {}
+ void DeleteAllCreatedBetween(
+ base::Time delete_begin,
+ base::Time delete_end,
+ const base::Closure& completion_callback) override {}
+ void DeleteAll(const base::Closure& completion_callback) override {}
+ void GetAllChannelIDs(const GetChannelIDListCallback& callback) override {}
+ int GetChannelIDCount() override { return 0; }
+ void SetForceKeepSessionState() override {}
+};
+
+// A mock CTVerifier that records every call to Verify but doesn't verify
+// anything.
+class MockCTVerifier : public CTVerifier {
+ public:
+ MOCK_METHOD5(Verify, int(X509Certificate*,
+ const std::string&,
+ const std::string&,
+ ct::CTVerifyResult*,
+ const BoundNetLog&));
};
class SSLClientSocketTest : public PlatformTest {
@@ -625,12 +667,15 @@ class SSLClientSocketTest : public PlatformTest {
SSLClientSocketTest()
: socket_factory_(ClientSocketFactory::GetDefaultFactory()),
cert_verifier_(new MockCertVerifier),
- transport_security_state_(new TransportSecurityState) {
+ transport_security_state_(new TransportSecurityState),
+ ran_handshake_completion_callback_(false) {
cert_verifier_->set_default_result(OK);
context_.cert_verifier = cert_verifier_.get();
context_.transport_security_state = transport_security_state_.get();
}
+ void RecordCompletedHandshake() { ran_handshake_completion_callback_ = true; }
+
protected:
// The address of the spawned test server, after calling StartTestServer().
const AddressList& addr() const { return addr_; }
@@ -638,6 +683,10 @@ class SSLClientSocketTest : public PlatformTest {
// The SpawnedTestServer object, after calling StartTestServer().
const SpawnedTestServer* test_server() const { return test_server_.get(); }
+ void SetCTVerifier(CTVerifier* ct_verifier) {
+ context_.cert_transparency_verifier = ct_verifier;
+ }
+
// Starts the test server with SSL configuration |ssl_options|. Returns true
// on success.
bool StartTestServer(const SpawnedTestServer::SSLOptions& ssl_options) {
@@ -707,6 +756,7 @@ class SSLClientSocketTest : public PlatformTest {
SSLClientSocketContext context_;
scoped_ptr<SSLClientSocket> sock_;
CapturingNetLog log_;
+ bool ran_handshake_completion_callback_;
private:
scoped_ptr<StreamSocket> transport_;
@@ -759,6 +809,11 @@ class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest {
};
class SSLClientSocketFalseStartTest : public SSLClientSocketTest {
+ public:
+ SSLClientSocketFalseStartTest()
+ : monitor_handshake_callback_(false),
+ fail_handshake_after_false_start_(false) {}
+
protected:
// Creates an SSLClientSocket with |client_config| attached to a
// FakeBlockingStreamSocket, returning both in |*out_raw_transport| and
@@ -780,18 +835,25 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest {
scoped_ptr<SSLClientSocket>* out_sock) {
CHECK(test_server());
- scoped_ptr<StreamSocket> real_transport(
- new TCPClientSocket(addr(), NULL, NetLog::Source()));
+ scoped_ptr<StreamSocket> real_transport(scoped_ptr<StreamSocket>(
+ new TCPClientSocket(addr(), NULL, NetLog::Source())));
+ real_transport.reset(
+ new SynchronousErrorStreamSocket(real_transport.Pass()));
+
scoped_ptr<FakeBlockingStreamSocket> transport(
new FakeBlockingStreamSocket(real_transport.Pass()));
int rv = callback->GetResult(transport->Connect(callback->callback()));
EXPECT_EQ(OK, rv);
FakeBlockingStreamSocket* raw_transport = transport.get();
- scoped_ptr<SSLClientSocket> sock =
- CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
- test_server()->host_port_pair(),
- client_config);
+ scoped_ptr<SSLClientSocket> sock = CreateSSLClientSocket(
+ transport.Pass(), test_server()->host_port_pair(), client_config);
+
+ if (monitor_handshake_callback_) {
+ sock->SetHandshakeCompletionCallback(
+ base::Bind(&SSLClientSocketTest::RecordCompletedHandshake,
+ base::Unretained(this)));
+ }
// Connect. Stop before the client processes the first server leg
// (ServerHello, etc.)
@@ -808,6 +870,12 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest {
raw_transport->UnblockReadResult();
raw_transport->WaitForWrite();
+ if (fail_handshake_after_false_start_) {
+ SynchronousErrorStreamSocket* error_socket =
+ static_cast<SynchronousErrorStreamSocket*>(
+ raw_transport->transport());
+ error_socket->SetNextReadError(ERR_CONNECTION_RESET);
+ }
// And, finally, release that and block the next server leg
// (ChangeCipherSpec, Finished).
raw_transport->BlockReadResult();
@@ -825,6 +893,7 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest {
TestCompletionCallback callback;
FakeBlockingStreamSocket* raw_transport = NULL;
scoped_ptr<SSLClientSocket> sock;
+
ASSERT_NO_FATAL_FAILURE(CreateAndConnectUntilServerFinishedReceived(
client_config, &callback, &raw_transport, &sock));
@@ -859,7 +928,10 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest {
// After releasing reads, the connection proceeds.
raw_transport->UnblockReadResult();
rv = callback.GetResult(rv);
- EXPECT_LT(0, rv);
+ if (fail_handshake_after_false_start_)
+ EXPECT_EQ(ERR_CONNECTION_RESET, rv);
+ else
+ EXPECT_LT(0, rv);
} else {
// False Start is not enabled, so the handshake will not complete because
// the server second leg is blocked.
@@ -867,25 +939,39 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest {
EXPECT_FALSE(callback.have_result());
}
}
+
+ // Indicates that the socket's handshake completion callback should
+ // be monitored.
+ bool monitor_handshake_callback_;
+ // Indicates that this test's handshake should fail after the client
+ // "finished" message is sent.
+ bool fail_handshake_after_false_start_;
};
class SSLClientSocketChannelIDTest : public SSLClientSocketTest {
protected:
void EnableChannelID() {
- cert_service_.reset(
- new ServerBoundCertService(new DefaultServerBoundCertStore(NULL),
- base::MessageLoopProxy::current()));
- context_.server_bound_cert_service = cert_service_.get();
+ channel_id_service_.reset(
+ new ChannelIDService(new DefaultChannelIDStore(NULL),
+ base::MessageLoopProxy::current()));
+ context_.channel_id_service = channel_id_service_.get();
}
void EnableFailingChannelID() {
- cert_service_.reset(new ServerBoundCertService(
- new FailingServerBoundCertStore(), base::MessageLoopProxy::current()));
- context_.server_bound_cert_service = cert_service_.get();
+ channel_id_service_.reset(new ChannelIDService(
+ new FailingChannelIDStore(), base::MessageLoopProxy::current()));
+ context_.channel_id_service = channel_id_service_.get();
+ }
+
+ void EnableAsyncFailingChannelID() {
+ channel_id_service_.reset(new ChannelIDService(
+ new AsyncFailingChannelIDStore(),
+ base::MessageLoopProxy::current()));
+ context_.channel_id_service = channel_id_service_.get();
}
private:
- scoped_ptr<ServerBoundCertService> cert_service_;
+ scoped_ptr<ChannelIDService> channel_id_service_;
};
//-----------------------------------------------------------------------------
@@ -1207,6 +1293,41 @@ TEST_F(SSLClientSocketTest, Read) {
}
}
+// Tests that SSLClientSocket properly handles when the underlying transport
+// synchronously fails a transport read in during the handshake. The error code
+// should be preserved so SSLv3 fallback logic can condition on it.
+TEST_F(SSLClientSocketTest, Connect_WithSynchronousError) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> real_transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ scoped_ptr<SynchronousErrorStreamSocket> transport(
+ new SynchronousErrorStreamSocket(real_transport.Pass()));
+ int rv = callback.GetResult(transport->Connect(callback.callback()));
+ EXPECT_EQ(OK, rv);
+
+ // Disable TLS False Start to avoid handshake non-determinism.
+ SSLConfig ssl_config;
+ ssl_config.false_start_enabled = false;
+
+ SynchronousErrorStreamSocket* raw_transport = transport.get();
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), ssl_config));
+
+ raw_transport->SetNextWriteError(ERR_CONNECTION_RESET);
+
+ rv = callback.GetResult(sock->Connect(callback.callback()));
+ EXPECT_EQ(ERR_CONNECTION_RESET, rv);
+ EXPECT_FALSE(sock->IsConnected());
+}
+
// Tests that the SSLClientSocket properly handles when the underlying transport
// synchronously returns an error code - such as if an intermediary terminates
// the socket connection uncleanly.
@@ -1233,10 +1354,8 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) {
ssl_config.false_start_enabled = false;
SynchronousErrorStreamSocket* raw_transport = transport.get();
- scoped_ptr<SSLClientSocket> sock(
- CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
- test_server.host_port_pair(),
- ssl_config));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), ssl_config));
rv = callback.GetResult(sock->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
@@ -1261,14 +1380,7 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) {
// rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING is a legitimate
// result when using a dedicated task runner for NSS.
rv = callback.GetResult(sock->Read(buf.get(), 4096, callback.callback()));
-
-#if !defined(USE_OPENSSL)
- // SSLClientSocketNSS records the error exactly
EXPECT_EQ(ERR_CONNECTION_RESET, rv);
-#else
- // SSLClientSocketOpenSSL treats any errors as a simple EOF.
- EXPECT_EQ(0, rv);
-#endif
}
// Tests that the SSLClientSocket properly handles when the underlying transport
@@ -1293,7 +1405,7 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) {
new SynchronousErrorStreamSocket(real_transport.Pass()));
SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
scoped_ptr<FakeBlockingStreamSocket> transport(
- new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>()));
+ new FakeBlockingStreamSocket(error_socket.Pass()));
FakeBlockingStreamSocket* raw_transport = transport.get();
int rv = callback.GetResult(transport->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
@@ -1302,10 +1414,8 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) {
SSLConfig ssl_config;
ssl_config.false_start_enabled = false;
- scoped_ptr<SSLClientSocket> sock(
- CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
- test_server.host_port_pair(),
- ssl_config));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), ssl_config));
rv = callback.GetResult(sock->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
@@ -1342,14 +1452,7 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) {
// checking that rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING
// is a legitimate result when using a dedicated task runner for NSS.
rv = callback.GetResult(rv);
-
-#if !defined(USE_OPENSSL)
- // SSLClientSocketNSS records the error exactly
EXPECT_EQ(ERR_CONNECTION_RESET, rv);
-#else
- // SSLClientSocketOpenSSL treats any errors as a simple EOF.
- EXPECT_EQ(0, rv);
-#endif
}
// If there is a Write failure at the transport with no follow-up Read, although
@@ -1374,7 +1477,7 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousErrorNoRead) {
new SynchronousErrorStreamSocket(real_transport.Pass()));
SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
scoped_ptr<CountingStreamSocket> counting_socket(
- new CountingStreamSocket(error_socket.PassAs<StreamSocket>()));
+ new CountingStreamSocket(error_socket.Pass()));
CountingStreamSocket* raw_counting_socket = counting_socket.get();
int rv = callback.GetResult(counting_socket->Connect(callback.callback()));
ASSERT_EQ(OK, rv);
@@ -1383,10 +1486,8 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousErrorNoRead) {
SSLConfig ssl_config;
ssl_config.false_start_enabled = false;
- scoped_ptr<SSLClientSocket> sock(
- CreateSSLClientSocket(counting_socket.PassAs<StreamSocket>(),
- test_server.host_port_pair(),
- ssl_config));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ counting_socket.Pass(), test_server.host_port_pair(), ssl_config));
rv = callback.GetResult(sock->Connect(callback.callback()));
ASSERT_EQ(OK, rv);
@@ -1504,7 +1605,7 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) {
new SynchronousErrorStreamSocket(real_transport.Pass()));
SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
scoped_ptr<FakeBlockingStreamSocket> transport(
- new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>()));
+ new FakeBlockingStreamSocket(error_socket.Pass()));
FakeBlockingStreamSocket* raw_transport = transport.get();
int rv = callback.GetResult(transport->Connect(callback.callback()));
@@ -1514,10 +1615,8 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) {
SSLConfig ssl_config;
ssl_config.false_start_enabled = false;
- scoped_ptr<SSLClientSocket> sock =
- CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
- test_server.host_port_pair(),
- ssl_config);
+ scoped_ptr<SSLClientSocket> sock = CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), ssl_config);
rv = callback.GetResult(sock->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
@@ -1592,14 +1691,7 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) {
raw_transport->UnblockWrite();
rv = read_callback.WaitForResult();
-
-#if !defined(USE_OPENSSL)
- // NSS records the error exactly.
EXPECT_EQ(ERR_CONNECTION_RESET, rv);
-#else
- // OpenSSL treats any errors as a simple EOF.
- EXPECT_EQ(0, rv);
-#endif
// The Write callback should not have been called.
EXPECT_FALSE(callback.have_result());
@@ -1627,7 +1719,7 @@ TEST_F(SSLClientSocketTest, Read_WithWriteError) {
new SynchronousErrorStreamSocket(real_transport.Pass()));
SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
scoped_ptr<FakeBlockingStreamSocket> transport(
- new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>()));
+ new FakeBlockingStreamSocket(error_socket.Pass()));
FakeBlockingStreamSocket* raw_transport = transport.get();
int rv = callback.GetResult(transport->Connect(callback.callback()));
@@ -1637,10 +1729,8 @@ TEST_F(SSLClientSocketTest, Read_WithWriteError) {
SSLConfig ssl_config;
ssl_config.false_start_enabled = false;
- scoped_ptr<SSLClientSocket> sock(
- CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
- test_server.host_port_pair(),
- ssl_config));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), ssl_config));
rv = callback.GetResult(sock->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
@@ -1694,21 +1784,141 @@ TEST_F(SSLClientSocketTest, Read_WithWriteError) {
}
} while (rv > 0);
-#if !defined(USE_OPENSSL)
- // NSS records the error exactly.
EXPECT_EQ(ERR_CONNECTION_RESET, rv);
-#else
- // OpenSSL treats the reset as a generic protocol error.
- EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, rv);
-#endif
- // Release the read. Some bytes should go through.
+ // Release the read.
raw_transport->UnblockReadResult();
rv = read_callback.WaitForResult();
- // Per the fix for http://crbug.com/249848, write failures currently break
- // reads. Change this assertion if they're changed to not collide.
+#if defined(USE_OPENSSL)
+ // Should still read bytes despite the write error.
+ EXPECT_LT(0, rv);
+#else
+ // NSS attempts to flush the write buffer in PR_Read on an SSL socket before
+ // pumping the read state machine, unless configured with SSL_ENABLE_FDX, so
+ // the write error stops future reads.
EXPECT_EQ(ERR_CONNECTION_RESET, rv);
+#endif
+}
+
+// Tests that SSLClientSocket fails the handshake if the underlying
+// transport is cleanly closed.
+TEST_F(SSLClientSocketTest, Connect_WithZeroReturn) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> real_transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ scoped_ptr<SynchronousErrorStreamSocket> transport(
+ new SynchronousErrorStreamSocket(real_transport.Pass()));
+ int rv = callback.GetResult(transport->Connect(callback.callback()));
+ EXPECT_EQ(OK, rv);
+
+ SynchronousErrorStreamSocket* raw_transport = transport.get();
+ scoped_ptr<SSLClientSocket> sock(
+ CreateSSLClientSocket(transport.Pass(),
+ test_server.host_port_pair(),
+ kDefaultSSLConfig));
+
+ raw_transport->SetNextReadError(0);
+
+ rv = callback.GetResult(sock->Connect(callback.callback()));
+ EXPECT_EQ(ERR_CONNECTION_CLOSED, rv);
+ EXPECT_FALSE(sock->IsConnected());
+}
+
+// Tests that SSLClientSocket cleanly returns a Read of size 0 if the
+// underlying socket is cleanly closed.
+// This is a regression test for https://crbug.com/422246
+TEST_F(SSLClientSocketTest, Read_WithZeroReturn) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> real_transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ scoped_ptr<SynchronousErrorStreamSocket> transport(
+ new SynchronousErrorStreamSocket(real_transport.Pass()));
+ int rv = callback.GetResult(transport->Connect(callback.callback()));
+ EXPECT_EQ(OK, rv);
+
+ // Disable TLS False Start to ensure the handshake has completed.
+ SSLConfig ssl_config;
+ ssl_config.false_start_enabled = false;
+
+ SynchronousErrorStreamSocket* raw_transport = transport.get();
+ scoped_ptr<SSLClientSocket> sock(
+ CreateSSLClientSocket(transport.Pass(),
+ test_server.host_port_pair(),
+ ssl_config));
+
+ rv = callback.GetResult(sock->Connect(callback.callback()));
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock->IsConnected());
+
+ raw_transport->SetNextReadError(0);
+ scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
+ rv = callback.GetResult(sock->Read(buf.get(), 4096, callback.callback()));
+ EXPECT_EQ(0, rv);
+}
+
+// Tests that SSLClientSocket cleanly returns a Read of size 0 if the
+// underlying socket is cleanly closed asynchronously.
+// This is a regression test for https://crbug.com/422246
+TEST_F(SSLClientSocketTest, Read_WithAsyncZeroReturn) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> real_transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ scoped_ptr<SynchronousErrorStreamSocket> error_socket(
+ new SynchronousErrorStreamSocket(real_transport.Pass()));
+ SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
+ scoped_ptr<FakeBlockingStreamSocket> transport(
+ new FakeBlockingStreamSocket(error_socket.Pass()));
+ FakeBlockingStreamSocket* raw_transport = transport.get();
+ int rv = callback.GetResult(transport->Connect(callback.callback()));
+ EXPECT_EQ(OK, rv);
+
+ // Disable TLS False Start to ensure the handshake has completed.
+ SSLConfig ssl_config;
+ ssl_config.false_start_enabled = false;
+
+ scoped_ptr<SSLClientSocket> sock(
+ CreateSSLClientSocket(transport.Pass(),
+ test_server.host_port_pair(),
+ ssl_config));
+
+ rv = callback.GetResult(sock->Connect(callback.callback()));
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock->IsConnected());
+
+ raw_error_socket->SetNextReadError(0);
+ raw_transport->BlockReadResult();
+ scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
+ rv = sock->Read(buf.get(), 4096, callback.callback());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ raw_transport->UnblockReadResult();
+ rv = callback.GetResult(rv);
+ EXPECT_EQ(0, rv);
}
TEST_F(SSLClientSocketTest, Read_SmallChunks) {
@@ -1782,10 +1992,8 @@ TEST_F(SSLClientSocketTest, Read_ManySmallRecords) {
int rv = callback.GetResult(transport->Connect(callback.callback()));
ASSERT_EQ(OK, rv);
- scoped_ptr<SSLClientSocket> sock(
- CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
- test_server.host_port_pair(),
- kDefaultSSLConfig));
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
rv = callback.GetResult(sock->Connect(callback.callback()));
ASSERT_EQ(OK, rv);
@@ -2251,7 +2459,7 @@ TEST_F(SSLClientSocketTest, VerifyReturnChainProperlyOrdered) {
// Load and install the root for the validated chain.
scoped_refptr<X509Certificate> root_cert = ImportCertFromFile(
GetTestCertsDirectory(), "redundant-validated-chain-root.pem");
- ASSERT_NE(static_cast<X509Certificate*>(NULL), root_cert);
+ ASSERT_NE(static_cast<X509Certificate*>(NULL), root_cert.get());
ScopedTestRoot scoped_root(root_cert.get());
// Set up a test server with CERT_CHAIN_WRONG_ROOT.
@@ -2389,45 +2597,45 @@ TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledTLSExtension) {
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
- CapturingNetLog log;
scoped_ptr<StreamSocket> transport(
- new TCPClientSocket(addr, &log, NetLog::Source()));
- int rv = transport->Connect(callback.callback());
- if (rv == ERR_IO_PENDING)
- rv = callback.WaitForResult();
+ new TCPClientSocket(addr, &log_, NetLog::Source()));
+ int rv = callback.GetResult(transport->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
SSLConfig ssl_config;
ssl_config.signed_cert_timestamps_enabled = true;
- scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport.Pass(), test_server.host_port_pair(), ssl_config));
+ MockCTVerifier ct_verifier;
+ SetCTVerifier(&ct_verifier);
- EXPECT_FALSE(sock->IsConnected());
+ // Check that the SCT list is extracted as expected.
+ EXPECT_CALL(ct_verifier, Verify(_, "", "test", _, _)).WillRepeatedly(
+ Return(ERR_CT_NO_SCTS_VERIFIED_OK));
- rv = sock->Connect(callback.callback());
-
- CapturingNetLog::CapturedEntryList entries;
- log.GetEntries(&entries);
- EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
- if (rv == ERR_IO_PENDING)
- rv = callback.WaitForResult();
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), ssl_config));
+ rv = callback.GetResult(sock->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
- EXPECT_TRUE(sock->IsConnected());
- log.GetEntries(&entries);
- EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
-#if !defined(USE_OPENSSL)
EXPECT_TRUE(sock->signed_cert_timestamps_received_);
-#else
- // Enabling CT for OpenSSL is currently a noop.
- EXPECT_FALSE(sock->signed_cert_timestamps_received_);
-#endif
+}
- sock->Disconnect();
- EXPECT_FALSE(sock->IsConnected());
+namespace {
+
+bool IsValidOCSPResponse(const base::StringPiece& input) {
+ base::StringPiece ocsp_response = input;
+ base::StringPiece sequence, response_status, response_bytes;
+ return asn1::GetElement(&ocsp_response, asn1::kSEQUENCE, &sequence) &&
+ ocsp_response.empty() &&
+ asn1::GetElement(&sequence, asn1::kENUMERATED, &response_status) &&
+ asn1::GetElement(&sequence,
+ asn1::kContextSpecific | asn1::kConstructed | 0,
+ &response_status) &&
+ sequence.empty();
}
+} // namespace
+
// Test that enabling Signed Certificate Timestamps enables OCSP stapling.
TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledOCSP) {
SpawnedTestServer::SSLOptions ssl_options;
@@ -2445,12 +2653,9 @@ TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledOCSP) {
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
- CapturingNetLog log;
scoped_ptr<StreamSocket> transport(
- new TCPClientSocket(addr, &log, NetLog::Source()));
- int rv = transport->Connect(callback.callback());
- if (rv == ERR_IO_PENDING)
- rv = callback.WaitForResult();
+ new TCPClientSocket(addr, &log_, NetLog::Source()));
+ int rv = callback.GetResult(transport->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
SSLConfig ssl_config;
@@ -2459,32 +2664,24 @@ TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledOCSP) {
// is able to process the OCSP status itself.
ssl_config.signed_cert_timestamps_enabled = true;
- scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
- transport.Pass(), test_server.host_port_pair(), ssl_config));
-
- EXPECT_FALSE(sock->IsConnected());
+ MockCTVerifier ct_verifier;
+ SetCTVerifier(&ct_verifier);
- rv = sock->Connect(callback.callback());
+ // Check that the OCSP response is extracted and well-formed. It should be the
+ // DER encoding of an OCSPResponse (RFC 2560), so check that it consists of a
+ // SEQUENCE of an ENUMERATED type and an element tagged with [0] EXPLICIT. In
+ // particular, it should not include the overall two-byte length prefix from
+ // TLS.
+ EXPECT_CALL(ct_verifier,
+ Verify(_, Truly(IsValidOCSPResponse), "", _, _)).WillRepeatedly(
+ Return(ERR_CT_NO_SCTS_VERIFIED_OK));
- CapturingNetLog::CapturedEntryList entries;
- log.GetEntries(&entries);
- EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
- if (rv == ERR_IO_PENDING)
- rv = callback.WaitForResult();
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), ssl_config));
+ rv = callback.GetResult(sock->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
- EXPECT_TRUE(sock->IsConnected());
- log.GetEntries(&entries);
- EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
-#if !defined(USE_OPENSSL)
EXPECT_TRUE(sock->stapled_ocsp_response_received_);
-#else
- // OCSP stapling isn't currently supported in the OpenSSL socket.
- EXPECT_FALSE(sock->stapled_ocsp_response_received_);
-#endif
-
- sock->Disconnect();
- EXPECT_FALSE(sock->IsConnected());
}
TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsDisabled) {
@@ -2500,12 +2697,9 @@ TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsDisabled) {
ASSERT_TRUE(test_server.GetAddressList(&addr));
TestCompletionCallback callback;
- CapturingNetLog log;
scoped_ptr<StreamSocket> transport(
- new TCPClientSocket(addr, &log, NetLog::Source()));
- int rv = transport->Connect(callback.callback());
- if (rv == ERR_IO_PENDING)
- rv = callback.WaitForResult();
+ new TCPClientSocket(addr, &log_, NetLog::Source()));
+ int rv = callback.GetResult(transport->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
SSLConfig ssl_config;
@@ -2513,25 +2707,10 @@ TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsDisabled) {
scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
transport.Pass(), test_server.host_port_pair(), ssl_config));
-
- EXPECT_FALSE(sock->IsConnected());
-
- rv = sock->Connect(callback.callback());
-
- CapturingNetLog::CapturedEntryList entries;
- log.GetEntries(&entries);
- EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
- if (rv == ERR_IO_PENDING)
- rv = callback.WaitForResult();
+ rv = callback.GetResult(sock->Connect(callback.callback()));
EXPECT_EQ(OK, rv);
- EXPECT_TRUE(sock->IsConnected());
- log.GetEntries(&entries);
- EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
EXPECT_FALSE(sock->signed_cert_timestamps_received_);
-
- sock->Disconnect();
- EXPECT_FALSE(sock->IsConnected());
}
// Tests that IsConnectedAndIdle and WasEverUsed behave as expected.
@@ -2590,6 +2769,148 @@ TEST_F(SSLClientSocketTest, ReuseStates) {
// attempt to read one byte extra.
}
+#if defined(USE_OPENSSL)
+
+TEST_F(SSLClientSocketTest, HandshakeCallbackIsRun_WithFailure) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ TestCompletionCallback callback;
+ scoped_ptr<StreamSocket> real_transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+ scoped_ptr<SynchronousErrorStreamSocket> transport(
+ new SynchronousErrorStreamSocket(real_transport.Pass()));
+ int rv = callback.GetResult(transport->Connect(callback.callback()));
+ EXPECT_EQ(OK, rv);
+
+ // Disable TLS False Start to avoid handshake non-determinism.
+ SSLConfig ssl_config;
+ ssl_config.false_start_enabled = false;
+
+ SynchronousErrorStreamSocket* raw_transport = transport.get();
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), ssl_config));
+
+ sock->SetHandshakeCompletionCallback(base::Bind(
+ &SSLClientSocketTest::RecordCompletedHandshake, base::Unretained(this)));
+
+ raw_transport->SetNextWriteError(ERR_CONNECTION_RESET);
+
+ rv = callback.GetResult(sock->Connect(callback.callback()));
+ EXPECT_EQ(ERR_CONNECTION_RESET, rv);
+ EXPECT_FALSE(sock->IsConnected());
+
+ EXPECT_TRUE(ran_handshake_completion_callback_);
+}
+
+// Tests that the completion callback is run when an SSL connection
+// completes successfully.
+TEST_F(SSLClientSocketTest, HandshakeCallbackIsRun_WithSuccess) {
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
+ SpawnedTestServer::kLocalhost,
+ base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+
+ TestCompletionCallback callback;
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ SSLConfig ssl_config = kDefaultSSLConfig;
+ ssl_config.false_start_enabled = false;
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), ssl_config));
+
+ sock->SetHandshakeCompletionCallback(base::Bind(
+ &SSLClientSocketTest::RecordCompletedHandshake, base::Unretained(this)));
+
+ rv = callback.GetResult(sock->Connect(callback.callback()));
+
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock->IsConnected());
+ EXPECT_TRUE(ran_handshake_completion_callback_);
+}
+
+// Tests that the completion callback is run with a server that doesn't cache
+// sessions.
+TEST_F(SSLClientSocketTest, HandshakeCallbackIsRun_WithDisabledSessionCache) {
+ SpawnedTestServer::SSLOptions ssl_options;
+ ssl_options.disable_session_cache = true;
+ SpawnedTestServer test_server(
+ SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
+ ASSERT_TRUE(test_server.Start());
+
+ AddressList addr;
+ ASSERT_TRUE(test_server.GetAddressList(&addr));
+
+ scoped_ptr<StreamSocket> transport(
+ new TCPClientSocket(addr, NULL, NetLog::Source()));
+
+ TestCompletionCallback callback;
+ int rv = transport->Connect(callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = callback.WaitForResult();
+ EXPECT_EQ(OK, rv);
+
+ SSLConfig ssl_config = kDefaultSSLConfig;
+ ssl_config.false_start_enabled = false;
+
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
+ transport.Pass(), test_server.host_port_pair(), ssl_config));
+
+ sock->SetHandshakeCompletionCallback(base::Bind(
+ &SSLClientSocketTest::RecordCompletedHandshake, base::Unretained(this)));
+
+ rv = callback.GetResult(sock->Connect(callback.callback()));
+
+ EXPECT_EQ(OK, rv);
+ EXPECT_TRUE(sock->IsConnected());
+ EXPECT_TRUE(ran_handshake_completion_callback_);
+}
+
+TEST_F(SSLClientSocketFalseStartTest,
+ HandshakeCallbackIsRun_WithFalseStartFailure) {
+ // False Start requires NPN and a forward-secret cipher suite.
+ SpawnedTestServer::SSLOptions server_options;
+ server_options.key_exchanges =
+ SpawnedTestServer::SSLOptions::KEY_EXCHANGE_DHE_RSA;
+ server_options.enable_npn = true;
+ SSLConfig client_config;
+ client_config.next_protos.push_back("http/1.1");
+ monitor_handshake_callback_ = true;
+ fail_handshake_after_false_start_ = true;
+ ASSERT_NO_FATAL_FAILURE(TestFalseStart(server_options, client_config, true));
+ ASSERT_TRUE(ran_handshake_completion_callback_);
+}
+
+TEST_F(SSLClientSocketFalseStartTest,
+ HandshakeCallbackIsRun_WithFalseStartSuccess) {
+ // False Start requires NPN and a forward-secret cipher suite.
+ SpawnedTestServer::SSLOptions server_options;
+ server_options.key_exchanges =
+ SpawnedTestServer::SSLOptions::KEY_EXCHANGE_DHE_RSA;
+ server_options.enable_npn = true;
+ SSLConfig client_config;
+ client_config.next_protos.push_back("http/1.1");
+ monitor_handshake_callback_ = true;
+ ASSERT_NO_FATAL_FAILURE(TestFalseStart(server_options, client_config, true));
+ ASSERT_TRUE(ran_handshake_completion_callback_);
+}
+#endif // defined(USE_OPENSSL)
+
TEST_F(SSLClientSocketFalseStartTest, FalseStartEnabled) {
// False Start requires NPN and a forward-secret cipher suite.
SpawnedTestServer::SSLOptions server_options;
@@ -2718,8 +3039,8 @@ TEST_F(SSLClientSocketChannelIDTest, SendChannelID) {
EXPECT_FALSE(sock_->IsConnected());
}
-// Connect to a server using channel id but without sending a key. It should
-// fail.
+// Connect to a server using Channel ID but failing to look up the Channel
+// ID. It should fail.
TEST_F(SSLClientSocketChannelIDTest, FailingChannelID) {
SpawnedTestServer::SSLOptions ssl_options;
@@ -2740,4 +3061,22 @@ TEST_F(SSLClientSocketChannelIDTest, FailingChannelID) {
EXPECT_FALSE(sock_->IsConnected());
}
+// Connect to a server using Channel ID but asynchronously failing to look up
+// the Channel ID. It should fail.
+TEST_F(SSLClientSocketChannelIDTest, FailingChannelIDAsync) {
+ SpawnedTestServer::SSLOptions ssl_options;
+
+ ASSERT_TRUE(ConnectToTestServer(ssl_options));
+
+ EnableAsyncFailingChannelID();
+ SSLConfig ssl_config = kDefaultSSLConfig;
+ ssl_config.channel_id_enabled = true;
+
+ int rv;
+ ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv));
+
+ EXPECT_EQ(ERR_UNEXPECTED, rv);
+ EXPECT_FALSE(sock_->IsConnected());
+}
+
} // namespace net
diff --git a/chromium/net/socket/ssl_error_params.cc b/chromium/net/socket/ssl_error_params.cc
deleted file mode 100644
index 37561f0de48..00000000000
--- a/chromium/net/socket/ssl_error_params.cc
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright (c) 2012 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "net/socket/ssl_error_params.h"
-
-#include "base/bind.h"
-#include "base/values.h"
-
-namespace net {
-
-namespace {
-
-base::Value* NetLogSSLErrorCallback(int net_error,
- int ssl_lib_error,
- NetLog::LogLevel /* log_level */) {
- base::DictionaryValue* dict = new base::DictionaryValue();
- dict->SetInteger("net_error", net_error);
- if (ssl_lib_error)
- dict->SetInteger("ssl_lib_error", ssl_lib_error);
- return dict;
-}
-
-} // namespace
-
-NetLog::ParametersCallback CreateNetLogSSLErrorCallback(int net_error,
- int ssl_lib_error) {
- return base::Bind(&NetLogSSLErrorCallback, net_error, ssl_lib_error);
-}
-
-} // namespace net
diff --git a/chromium/net/socket/ssl_error_params.h b/chromium/net/socket/ssl_error_params.h
deleted file mode 100644
index 07a1c4d99d9..00000000000
--- a/chromium/net/socket/ssl_error_params.h
+++ /dev/null
@@ -1,18 +0,0 @@
-// Copyright (c) 2012 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef NET_SOCKET_SSL_ERROR_PARAMS_H_
-#define NET_SOCKET_SSL_ERROR_PARAMS_H_
-
-#include "net/base/net_log.h"
-
-namespace net {
-
-// Creates NetLog callback for when we receive an SSL error.
-NetLog::ParametersCallback CreateNetLogSSLErrorCallback(int net_error,
- int ssl_lib_error);
-
-} // namespace net
-
-#endif // NET_SOCKET_SSL_ERROR_PARAMS_H_
diff --git a/chromium/net/socket/ssl_server_socket.h b/chromium/net/socket/ssl_server_socket.h
index 8b607bf80cf..88f7f941439 100644
--- a/chromium/net/socket/ssl_server_socket.h
+++ b/chromium/net/socket/ssl_server_socket.h
@@ -23,7 +23,7 @@ class X509Certificate;
class SSLServerSocket : public SSLSocket {
public:
- virtual ~SSLServerSocket() {}
+ ~SSLServerSocket() override {}
// Perform the SSL server handshake, and notify the supplied callback
// if the process completes asynchronously. If Disconnect is called before
diff --git a/chromium/net/socket/ssl_server_socket_nss.cc b/chromium/net/socket/ssl_server_socket_nss.cc
index cfdc5c545ef..7fa5835b430 100644
--- a/chromium/net/socket/ssl_server_socket_nss.cc
+++ b/chromium/net/socket/ssl_server_socket_nss.cc
@@ -38,7 +38,6 @@
#include "net/base/net_errors.h"
#include "net/base/net_log.h"
#include "net/socket/nss_ssl_util.h"
-#include "net/socket/ssl_error_params.h"
// SSL plaintext fragments are shorter than 16KB. Although the record layer
// overhead is allowed to be 2K + 5 bytes, in practice the overhead is much
@@ -60,7 +59,7 @@ class NSSSSLServerInitSingleton {
NSSSSLServerInitSingleton() {
EnsureNSSSSLInit();
- SSL_ConfigServerSessionIDCache(1024, 5, 5, NULL);
+ SSL_ConfigServerSessionIDCache(64, 28800, 28800, NULL);
g_nss_server_sockets_init = true;
}
@@ -85,7 +84,7 @@ scoped_ptr<SSLServerSocket> CreateSSLServerSocket(
crypto::RSAPrivateKey* key,
const SSLConfig& ssl_config) {
DCHECK(g_nss_server_sockets_init) << "EnableSSLServerSockets() has not been"
- << "called yet!";
+ << " called yet!";
return scoped_ptr<SSLServerSocket>(
new SSLServerSocketNSS(socket.Pass(), cert, key, ssl_config));
@@ -446,7 +445,7 @@ int SSLServerSocketNSS::InitializeSSLOptions() {
}
SECKEYPrivateKeyStr* private_key = NULL;
- PK11SlotInfo* slot = crypto::GetPrivateNSSKeySlot();
+ PK11SlotInfo* slot = PK11_GetInternalSlot();
if (!slot) {
CERT_DestroyCertificate(cert);
return ERR_UNEXPECTED;
diff --git a/chromium/net/socket/ssl_server_socket_nss.h b/chromium/net/socket/ssl_server_socket_nss.h
index bc5b65d5368..d40b096577c 100644
--- a/chromium/net/socket/ssl_server_socket_nss.h
+++ b/chromium/net/socket/ssl_server_socket_nss.h
@@ -28,42 +28,44 @@ class SSLServerSocketNSS : public SSLServerSocket {
scoped_refptr<X509Certificate> certificate,
crypto::RSAPrivateKey* key,
const SSLConfig& ssl_config);
- virtual ~SSLServerSocketNSS();
+ ~SSLServerSocketNSS() override;
// SSLServerSocket interface.
- virtual int Handshake(const CompletionCallback& callback) OVERRIDE;
+ int Handshake(const CompletionCallback& callback) override;
// SSLSocket interface.
- virtual int ExportKeyingMaterial(const base::StringPiece& label,
- bool has_context,
- const base::StringPiece& context,
- unsigned char* out,
- unsigned int outlen) OVERRIDE;
- virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE;
+ int ExportKeyingMaterial(const base::StringPiece& label,
+ bool has_context,
+ const base::StringPiece& context,
+ unsigned char* out,
+ unsigned int outlen) override;
+ int GetTLSUniqueChannelBinding(std::string* out) override;
// Socket interface (via StreamSocket).
- virtual int Read(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int Write(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE;
- virtual int SetSendBufferSize(int32 size) OVERRIDE;
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int SetReceiveBufferSize(int32 size) override;
+ int SetSendBufferSize(int32 size) override;
// StreamSocket implementation.
- virtual int Connect(const CompletionCallback& callback) OVERRIDE;
- virtual void Disconnect() OVERRIDE;
- virtual bool IsConnected() const OVERRIDE;
- virtual bool IsConnectedAndIdle() const OVERRIDE;
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
- virtual const BoundNetLog& NetLog() const OVERRIDE;
- virtual void SetSubresourceSpeculation() OVERRIDE;
- virtual void SetOmniboxSpeculation() OVERRIDE;
- virtual bool WasEverUsed() const OVERRIDE;
- virtual bool UsingTCPFastOpen() const OVERRIDE;
- virtual bool WasNpnNegotiated() const OVERRIDE;
- virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+ int Connect(const CompletionCallback& callback) override;
+ void Disconnect() override;
+ bool IsConnected() const override;
+ bool IsConnectedAndIdle() const override;
+ int GetPeerAddress(IPEndPoint* address) const override;
+ int GetLocalAddress(IPEndPoint* address) const override;
+ const BoundNetLog& NetLog() const override;
+ void SetSubresourceSpeculation() override;
+ void SetOmniboxSpeculation() override;
+ bool WasEverUsed() const override;
+ bool UsingTCPFastOpen() const override;
+ bool WasNpnNegotiated() const override;
+ NextProto GetNegotiatedProtocol() const override;
+ bool GetSSLInfo(SSLInfo* ssl_info) override;
private:
enum State {
diff --git a/chromium/net/socket/ssl_server_socket_openssl.cc b/chromium/net/socket/ssl_server_socket_openssl.cc
index f6bd0cd8f81..29d1ffa0508 100644
--- a/chromium/net/socket/ssl_server_socket_openssl.cc
+++ b/chromium/net/socket/ssl_server_socket_openssl.cc
@@ -11,9 +11,9 @@
#include "base/logging.h"
#include "crypto/openssl_util.h"
#include "crypto/rsa_private_key.h"
+#include "crypto/scoped_openssl_types.h"
#include "net/base/net_errors.h"
-#include "net/socket/openssl_ssl_util.h"
-#include "net/socket/ssl_error_params.h"
+#include "net/ssl/openssl_ssl_util.h"
#define GotoState(s) next_handshake_state_ = s
@@ -304,7 +304,7 @@ int SSLServerSocketOpenSSL::BufferSend() {
if (!send_buffer_.get()) {
// Get a fresh send buffer out of the send BIO.
- size_t max_read = BIO_ctrl_pending(transport_bio_);
+ size_t max_read = BIO_pending(transport_bio_);
if (!max_read)
return 0; // Nothing pending in the OpenSSL write BIO.
send_buffer_ = new DrainableIOBuffer(new IOBuffer(max_read), max_read);
@@ -456,10 +456,13 @@ int SSLServerSocketOpenSSL::DoPayloadRead() {
if (rv >= 0)
return rv;
int ssl_error = SSL_get_error(ssl_, rv);
- int net_error = MapOpenSSLError(ssl_error, err_tracer);
+ OpenSSLErrorInfo error_info;
+ int net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer,
+ &error_info);
if (net_error != ERR_IO_PENDING) {
- net_log_.AddEvent(NetLog::TYPE_SSL_READ_ERROR,
- CreateNetLogSSLErrorCallback(net_error, ssl_error));
+ net_log_.AddEvent(
+ NetLog::TYPE_SSL_READ_ERROR,
+ CreateNetLogOpenSSLErrorCallback(net_error, ssl_error, error_info));
}
return net_error;
}
@@ -471,10 +474,13 @@ int SSLServerSocketOpenSSL::DoPayloadWrite() {
if (rv >= 0)
return rv;
int ssl_error = SSL_get_error(ssl_, rv);
- int net_error = MapOpenSSLError(ssl_error, err_tracer);
+ OpenSSLErrorInfo error_info;
+ int net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer,
+ &error_info);
if (net_error != ERR_IO_PENDING) {
- net_log_.AddEvent(NetLog::TYPE_SSL_WRITE_ERROR,
- CreateNetLogSSLErrorCallback(net_error, ssl_error));
+ net_log_.AddEvent(
+ NetLog::TYPE_SSL_WRITE_ERROR,
+ CreateNetLogOpenSSLErrorCallback(net_error, ssl_error, error_info));
}
return net_error;
}
@@ -553,7 +559,8 @@ int SSLServerSocketOpenSSL::DoHandshake() {
completed_handshake_ = true;
} else {
int ssl_error = SSL_get_error(ssl_, rv);
- net_error = MapOpenSSLError(ssl_error, err_tracer);
+ OpenSSLErrorInfo error_info;
+ net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer, &error_info);
// If not done, stay in this state
if (net_error == ERR_IO_PENDING) {
@@ -562,8 +569,9 @@ int SSLServerSocketOpenSSL::DoHandshake() {
LOG(ERROR) << "handshake failed; returned " << rv
<< ", SSL error code " << ssl_error
<< ", net_error " << net_error;
- net_log_.AddEvent(NetLog::TYPE_SSL_HANDSHAKE_ERROR,
- CreateNetLogSSLErrorCallback(net_error, ssl_error));
+ net_log_.AddEvent(
+ NetLog::TYPE_SSL_HANDSHAKE_ERROR,
+ CreateNetLogOpenSSLErrorCallback(net_error, ssl_error, error_info));
}
}
return net_error;
@@ -598,7 +606,7 @@ int SSLServerSocketOpenSSL::Init() {
crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
- crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free> ssl_ctx(
+ crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free>::Type ssl_ctx(
// It support SSLv2, SSLv3, and TLSv1.
SSL_CTX_new(SSLv23_server_method()));
ssl_ = SSL_new(ssl_ctx.get());
@@ -630,8 +638,8 @@ int SSLServerSocketOpenSSL::Init() {
const unsigned char* der_string_array =
reinterpret_cast<const unsigned char*>(der_string.data());
- crypto::ScopedOpenSSL<X509, X509_free>
- x509(d2i_X509(NULL, &der_string_array, der_string.length()));
+ crypto::ScopedOpenSSL<X509, X509_free>::Type x509(
+ d2i_X509(NULL, &der_string_array, der_string.length()));
if (!x509.get())
return ERR_UNEXPECTED;
diff --git a/chromium/net/socket/ssl_server_socket_openssl.h b/chromium/net/socket/ssl_server_socket_openssl.h
index e1c8aad7aea..c58bd569352 100644
--- a/chromium/net/socket/ssl_server_socket_openssl.h
+++ b/chromium/net/socket/ssl_server_socket_openssl.h
@@ -30,42 +30,44 @@ class SSLServerSocketOpenSSL : public SSLServerSocket {
scoped_refptr<X509Certificate> certificate,
crypto::RSAPrivateKey* key,
const SSLConfig& ssl_config);
- virtual ~SSLServerSocketOpenSSL();
+ ~SSLServerSocketOpenSSL() override;
// SSLServerSocket interface.
- virtual int Handshake(const CompletionCallback& callback) OVERRIDE;
+ int Handshake(const CompletionCallback& callback) override;
// SSLSocket interface.
- virtual int ExportKeyingMaterial(const base::StringPiece& label,
- bool has_context,
- const base::StringPiece& context,
- unsigned char* out,
- unsigned int outlen) OVERRIDE;
- virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE;
+ int ExportKeyingMaterial(const base::StringPiece& label,
+ bool has_context,
+ const base::StringPiece& context,
+ unsigned char* out,
+ unsigned int outlen) override;
+ int GetTLSUniqueChannelBinding(std::string* out) override;
// Socket interface (via StreamSocket).
- virtual int Read(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int Write(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE;
- virtual int SetSendBufferSize(int32 size) OVERRIDE;
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int SetReceiveBufferSize(int32 size) override;
+ int SetSendBufferSize(int32 size) override;
// StreamSocket implementation.
- virtual int Connect(const CompletionCallback& callback) OVERRIDE;
- virtual void Disconnect() OVERRIDE;
- virtual bool IsConnected() const OVERRIDE;
- virtual bool IsConnectedAndIdle() const OVERRIDE;
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
- virtual const BoundNetLog& NetLog() const OVERRIDE;
- virtual void SetSubresourceSpeculation() OVERRIDE;
- virtual void SetOmniboxSpeculation() OVERRIDE;
- virtual bool WasEverUsed() const OVERRIDE;
- virtual bool UsingTCPFastOpen() const OVERRIDE;
- virtual bool WasNpnNegotiated() const OVERRIDE;
- virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+ int Connect(const CompletionCallback& callback) override;
+ void Disconnect() override;
+ bool IsConnected() const override;
+ bool IsConnectedAndIdle() const override;
+ int GetPeerAddress(IPEndPoint* address) const override;
+ int GetLocalAddress(IPEndPoint* address) const override;
+ const BoundNetLog& NetLog() const override;
+ void SetSubresourceSpeculation() override;
+ void SetOmniboxSpeculation() override;
+ bool WasEverUsed() const override;
+ bool UsingTCPFastOpen() const override;
+ bool WasNpnNegotiated() const override;
+ NextProto GetNegotiatedProtocol() const override;
+ bool GetSSLInfo(SSLInfo* ssl_info) override;
private:
enum State {
diff --git a/chromium/net/socket/ssl_server_socket_unittest.cc b/chromium/net/socket/ssl_server_socket_unittest.cc
index 9da10f6bbd0..5fabdd2e90e 100644
--- a/chromium/net/socket/ssl_server_socket_unittest.cc
+++ b/chromium/net/socket/ssl_server_socket_unittest.cc
@@ -20,10 +20,9 @@
#include <queue>
#include "base/compiler_specific.h"
-#include "base/file_util.h"
#include "base/files/file_path.h"
+#include "base/files/file_util.h"
#include "base/message_loop/message_loop.h"
-#include "base/path_service.h"
#include "crypto/nss_util.h"
#include "crypto/rsa_private_key.h"
#include "net/base/address_list.h"
@@ -62,29 +61,35 @@ class FakeDataChannel {
}
int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
+ DCHECK(read_callback_.is_null());
+ DCHECK(!read_buf_.get());
if (closed_)
return 0;
if (data_.empty()) {
read_callback_ = callback;
read_buf_ = buf;
read_buf_len_ = buf_len;
- return net::ERR_IO_PENDING;
+ return ERR_IO_PENDING;
}
return PropogateData(buf, buf_len);
}
int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
+ DCHECK(write_callback_.is_null());
if (closed_) {
if (write_called_after_close_)
- return net::ERR_CONNECTION_RESET;
+ return ERR_CONNECTION_RESET;
write_called_after_close_ = true;
write_callback_ = callback;
base::MessageLoop::current()->PostTask(
FROM_HERE, base::Bind(&FakeDataChannel::DoWriteCallback,
weak_factory_.GetWeakPtr()));
- return net::ERR_IO_PENDING;
+ return ERR_IO_PENDING;
}
- data_.push(new net::DrainableIOBuffer(buf, buf_len));
+ // This function returns synchronously, so make a copy of the buffer.
+ data_.push(new DrainableIOBuffer(
+ new StringIOBuffer(std::string(buf->data(), buf_len)),
+ buf_len));
base::MessageLoop::current()->PostTask(
FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback,
weak_factory_.GetWeakPtr()));
@@ -118,11 +123,11 @@ class FakeDataChannel {
CompletionCallback callback = write_callback_;
write_callback_.Reset();
- callback.Run(net::ERR_CONNECTION_RESET);
+ callback.Run(ERR_CONNECTION_RESET);
}
- int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) {
- scoped_refptr<net::DrainableIOBuffer> buf = data_.front();
+ int PropogateData(scoped_refptr<IOBuffer> read_buf, int read_buf_len) {
+ scoped_refptr<DrainableIOBuffer> buf = data_.front();
int copied = std::min(buf->BytesRemaining(), read_buf_len);
memcpy(read_buf->data(), buf->data(), copied);
buf->DidConsume(copied);
@@ -133,12 +138,12 @@ class FakeDataChannel {
}
CompletionCallback read_callback_;
- scoped_refptr<net::IOBuffer> read_buf_;
+ scoped_refptr<IOBuffer> read_buf_;
int read_buf_len_;
CompletionCallback write_callback_;
- std::queue<scoped_refptr<net::DrainableIOBuffer> > data_;
+ std::queue<scoped_refptr<DrainableIOBuffer> > data_;
// True if Close() has been called.
bool closed_;
@@ -161,90 +166,68 @@ class FakeSocket : public StreamSocket {
outgoing_(outgoing_channel) {
}
- virtual ~FakeSocket() {
- }
+ ~FakeSocket() override {}
- virtual int Read(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE {
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override {
// Read random number of bytes.
buf_len = rand() % buf_len + 1;
return incoming_->Read(buf, buf_len, callback);
}
- virtual int Write(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE {
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override {
// Write random number of bytes.
buf_len = rand() % buf_len + 1;
return outgoing_->Write(buf, buf_len, callback);
}
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE {
- return net::OK;
- }
+ int SetReceiveBufferSize(int32 size) override { return OK; }
- virtual int SetSendBufferSize(int32 size) OVERRIDE {
- return net::OK;
- }
+ int SetSendBufferSize(int32 size) override { return OK; }
- virtual int Connect(const CompletionCallback& callback) OVERRIDE {
- return net::OK;
- }
+ int Connect(const CompletionCallback& callback) override { return OK; }
- virtual void Disconnect() OVERRIDE {
+ void Disconnect() override {
incoming_->Close();
outgoing_->Close();
}
- virtual bool IsConnected() const OVERRIDE {
- return true;
- }
+ bool IsConnected() const override { return true; }
- virtual bool IsConnectedAndIdle() const OVERRIDE {
- return true;
- }
+ bool IsConnectedAndIdle() const override { return true; }
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
- net::IPAddressNumber ip_address(net::kIPv4AddressSize);
- *address = net::IPEndPoint(ip_address, 0 /*port*/);
- return net::OK;
+ int GetPeerAddress(IPEndPoint* address) const override {
+ IPAddressNumber ip_address(kIPv4AddressSize);
+ *address = IPEndPoint(ip_address, 0 /*port*/);
+ return OK;
}
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
- net::IPAddressNumber ip_address(4);
- *address = net::IPEndPoint(ip_address, 0);
- return net::OK;
+ int GetLocalAddress(IPEndPoint* address) const override {
+ IPAddressNumber ip_address(4);
+ *address = IPEndPoint(ip_address, 0);
+ return OK;
}
- virtual const BoundNetLog& NetLog() const OVERRIDE {
- return net_log_;
- }
+ const BoundNetLog& NetLog() const override { return net_log_; }
- virtual void SetSubresourceSpeculation() OVERRIDE {}
- virtual void SetOmniboxSpeculation() OVERRIDE {}
+ void SetSubresourceSpeculation() override {}
+ void SetOmniboxSpeculation() override {}
- virtual bool WasEverUsed() const OVERRIDE {
- return true;
- }
-
- virtual bool UsingTCPFastOpen() const OVERRIDE {
- return false;
- }
+ bool WasEverUsed() const override { return true; }
+ bool UsingTCPFastOpen() const override { return false; }
- virtual bool WasNpnNegotiated() const OVERRIDE {
- return false;
- }
+ bool WasNpnNegotiated() const override { return false; }
- virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
- return kProtoUnknown;
- }
+ NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
- return false;
- }
+ bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
private:
- net::BoundNetLog net_log_;
+ BoundNetLog net_log_;
FakeDataChannel* incoming_;
FakeDataChannel* outgoing_;
@@ -264,8 +247,8 @@ TEST(FakeSocketTest, DataTransfer) {
const char kTestData[] = "testing123";
const int kTestDataSize = strlen(kTestData);
const int kReadBufSize = 1024;
- scoped_refptr<net::IOBuffer> write_buf = new net::StringIOBuffer(kTestData);
- scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize);
+ scoped_refptr<IOBuffer> write_buf = new StringIOBuffer(kTestData);
+ scoped_refptr<IOBuffer> read_buf = new IOBuffer(kReadBufSize);
// Write then read.
int written =
@@ -280,7 +263,7 @@ TEST(FakeSocketTest, DataTransfer) {
// Read then write.
TestCompletionCallback callback;
- EXPECT_EQ(net::ERR_IO_PENDING,
+ EXPECT_EQ(ERR_IO_PENDING,
server.Read(read_buf.get(), kReadBufSize, callback.callback()));
written = client.Write(write_buf.get(), kTestDataSize, CompletionCallback());
@@ -296,10 +279,10 @@ TEST(FakeSocketTest, DataTransfer) {
class SSLServerSocketTest : public PlatformTest {
public:
SSLServerSocketTest()
- : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()),
+ : socket_factory_(ClientSocketFactory::GetDefaultFactory()),
cert_verifier_(new MockCertVerifier()),
transport_security_state_(new TransportSecurityState) {
- cert_verifier_->set_default_result(net::CERT_STATUS_AUTHORITY_INVALID);
+ cert_verifier_->set_default_result(CERT_STATUS_AUTHORITY_INVALID);
}
protected:
@@ -316,7 +299,7 @@ class SSLServerSocketTest : public PlatformTest {
std::string cert_der;
ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der));
- scoped_refptr<net::X509Certificate> cert =
+ scoped_refptr<X509Certificate> cert =
X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size());
base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
@@ -330,35 +313,35 @@ class SSLServerSocketTest : public PlatformTest {
scoped_ptr<crypto::RSAPrivateKey> private_key(
crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
- net::SSLConfig ssl_config;
+ SSLConfig ssl_config;
ssl_config.false_start_enabled = false;
ssl_config.channel_id_enabled = false;
// Certificate provided by the host doesn't need authority.
- net::SSLConfig::CertAndStatus cert_and_status;
+ SSLConfig::CertAndStatus cert_and_status;
cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID;
cert_and_status.der_cert = cert_der;
ssl_config.allowed_bad_certs.push_back(cert_and_status);
- net::HostPortPair host_and_pair("unittest", 0);
- net::SSLClientSocketContext context;
+ HostPortPair host_and_pair("unittest", 0);
+ SSLClientSocketContext context;
context.cert_verifier = cert_verifier_.get();
context.transport_security_state = transport_security_state_.get();
client_socket_ =
socket_factory_->CreateSSLClientSocket(
client_connection.Pass(), host_and_pair, ssl_config, context);
- server_socket_ = net::CreateSSLServerSocket(
+ server_socket_ = CreateSSLServerSocket(
server_socket.Pass(),
- cert.get(), private_key.get(), net::SSLConfig());
+ cert.get(), private_key.get(), SSLConfig());
}
FakeDataChannel channel_1_;
FakeDataChannel channel_2_;
- scoped_ptr<net::SSLClientSocket> client_socket_;
- scoped_ptr<net::SSLServerSocket> server_socket_;
- net::ClientSocketFactory* socket_factory_;
- scoped_ptr<net::MockCertVerifier> cert_verifier_;
- scoped_ptr<net::TransportSecurityState> transport_security_state_;
+ scoped_ptr<SSLClientSocket> client_socket_;
+ scoped_ptr<SSLServerSocket> server_socket_;
+ ClientSocketFactory* socket_factory_;
+ scoped_ptr<MockCertVerifier> cert_verifier_;
+ scoped_ptr<TransportSecurityState> transport_security_state_;
};
// This test only executes creation of client and server sockets. This is to
@@ -378,16 +361,16 @@ TEST_F(SSLServerSocketTest, Handshake) {
TestCompletionCallback handshake_callback;
int server_ret = server_socket_->Handshake(handshake_callback.callback());
- EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
+ EXPECT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
int client_ret = client_socket_->Connect(connect_callback.callback());
- EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
+ EXPECT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
- if (client_ret == net::ERR_IO_PENDING) {
- EXPECT_EQ(net::OK, connect_callback.WaitForResult());
+ if (client_ret == ERR_IO_PENDING) {
+ EXPECT_EQ(OK, connect_callback.WaitForResult());
}
- if (server_ret == net::ERR_IO_PENDING) {
- EXPECT_EQ(net::OK, handshake_callback.WaitForResult());
+ if (server_ret == ERR_IO_PENDING) {
+ EXPECT_EQ(OK, handshake_callback.WaitForResult());
}
// Make sure the cert status is expected.
@@ -404,32 +387,31 @@ TEST_F(SSLServerSocketTest, DataTransfer) {
// Establish connection.
int client_ret = client_socket_->Connect(connect_callback.callback());
- ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
+ ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
int server_ret = server_socket_->Handshake(handshake_callback.callback());
- ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
+ ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
client_ret = connect_callback.GetResult(client_ret);
- ASSERT_EQ(net::OK, client_ret);
+ ASSERT_EQ(OK, client_ret);
server_ret = handshake_callback.GetResult(server_ret);
- ASSERT_EQ(net::OK, server_ret);
+ ASSERT_EQ(OK, server_ret);
const int kReadBufSize = 1024;
- scoped_refptr<net::StringIOBuffer> write_buf =
- new net::StringIOBuffer("testing123");
- scoped_refptr<net::DrainableIOBuffer> read_buf =
- new net::DrainableIOBuffer(new net::IOBuffer(kReadBufSize),
- kReadBufSize);
+ scoped_refptr<StringIOBuffer> write_buf =
+ new StringIOBuffer("testing123");
+ scoped_refptr<DrainableIOBuffer> read_buf =
+ new DrainableIOBuffer(new IOBuffer(kReadBufSize), kReadBufSize);
// Write then read.
TestCompletionCallback write_callback;
TestCompletionCallback read_callback;
server_ret = server_socket_->Write(
write_buf.get(), write_buf->size(), write_callback.callback());
- EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
+ EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
client_ret = client_socket_->Read(
read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
- EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
+ EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
server_ret = write_callback.GetResult(server_ret);
EXPECT_GT(server_ret, 0);
@@ -440,7 +422,7 @@ TEST_F(SSLServerSocketTest, DataTransfer) {
while (read_buf->BytesConsumed() < write_buf->size()) {
client_ret = client_socket_->Read(
read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
- EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
+ EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
client_ret = read_callback.GetResult(client_ret);
ASSERT_GT(client_ret, 0);
read_buf->DidConsume(client_ret);
@@ -450,13 +432,13 @@ TEST_F(SSLServerSocketTest, DataTransfer) {
EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
// Read then write.
- write_buf = new net::StringIOBuffer("hello123");
+ write_buf = new StringIOBuffer("hello123");
server_ret = server_socket_->Read(
read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
- EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
+ EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
client_ret = client_socket_->Write(
write_buf.get(), write_buf->size(), write_callback.callback());
- EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
+ EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
server_ret = read_callback.GetResult(server_ret);
ASSERT_GT(server_ret, 0);
@@ -467,7 +449,7 @@ TEST_F(SSLServerSocketTest, DataTransfer) {
while (read_buf->BytesConsumed() < write_buf->size()) {
server_ret = server_socket_->Read(
read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
- EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
+ EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
server_ret = read_callback.GetResult(server_ret);
ASSERT_GT(server_ret, 0);
read_buf->DidConsume(server_ret);
@@ -477,17 +459,11 @@ TEST_F(SSLServerSocketTest, DataTransfer) {
EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
}
-// Flaky on Android: http://crbug.com/381147
-#if defined(OS_ANDROID)
-#define MAYBE_ClientWriteAfterServerClose DISABLED_ClientWriteAfterServerClose
-#else
-#define MAYBE_ClientWriteAfterServerClose ClientWriteAfterServerClose
-#endif
// A regression test for bug 127822 (http://crbug.com/127822).
// If the server closes the connection after the handshake is finished,
// the client's Write() call should not cause an infinite loop.
// NOTE: this is a test for SSLClientSocket rather than SSLServerSocket.
-TEST_F(SSLServerSocketTest, MAYBE_ClientWriteAfterServerClose) {
+TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) {
Initialize();
TestCompletionCallback connect_callback;
@@ -495,18 +471,17 @@ TEST_F(SSLServerSocketTest, MAYBE_ClientWriteAfterServerClose) {
// Establish connection.
int client_ret = client_socket_->Connect(connect_callback.callback());
- ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
+ ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
int server_ret = server_socket_->Handshake(handshake_callback.callback());
- ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
+ ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
client_ret = connect_callback.GetResult(client_ret);
- ASSERT_EQ(net::OK, client_ret);
+ ASSERT_EQ(OK, client_ret);
server_ret = handshake_callback.GetResult(server_ret);
- ASSERT_EQ(net::OK, server_ret);
+ ASSERT_EQ(OK, server_ret);
- scoped_refptr<net::StringIOBuffer> write_buf =
- new net::StringIOBuffer("testing123");
+ scoped_refptr<StringIOBuffer> write_buf = new StringIOBuffer("testing123");
// The server closes the connection. The server needs to write some
// data first so that the client's Read() calls from the transport
@@ -516,7 +491,7 @@ TEST_F(SSLServerSocketTest, MAYBE_ClientWriteAfterServerClose) {
server_ret = server_socket_->Write(
write_buf.get(), write_buf->size(), write_callback.callback());
- EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
+ EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
server_ret = write_callback.GetResult(server_ret);
EXPECT_GT(server_ret, 0);
@@ -526,7 +501,7 @@ TEST_F(SSLServerSocketTest, MAYBE_ClientWriteAfterServerClose) {
// The client writes some data. This should not cause an infinite loop.
client_ret = client_socket_->Write(
write_buf.get(), write_buf->size(), write_callback.callback());
- EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
+ EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
client_ret = write_callback.GetResult(client_ret);
EXPECT_GT(client_ret, 0);
@@ -547,16 +522,16 @@ TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
TestCompletionCallback handshake_callback;
int client_ret = client_socket_->Connect(connect_callback.callback());
- ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
+ ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
int server_ret = server_socket_->Handshake(handshake_callback.callback());
- ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
+ ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
- if (client_ret == net::ERR_IO_PENDING) {
- ASSERT_EQ(net::OK, connect_callback.WaitForResult());
+ if (client_ret == ERR_IO_PENDING) {
+ ASSERT_EQ(OK, connect_callback.WaitForResult());
}
- if (server_ret == net::ERR_IO_PENDING) {
- ASSERT_EQ(net::OK, handshake_callback.WaitForResult());
+ if (server_ret == ERR_IO_PENDING) {
+ ASSERT_EQ(OK, handshake_callback.WaitForResult());
}
const int kKeyingMaterialSize = 32;
@@ -566,13 +541,13 @@ TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel,
false, kKeyingContext,
server_out, sizeof(server_out));
- ASSERT_EQ(net::OK, rv);
+ ASSERT_EQ(OK, rv);
unsigned char client_out[kKeyingMaterialSize];
rv = client_socket_->ExportKeyingMaterial(kKeyingLabel,
false, kKeyingContext,
client_out, sizeof(client_out));
- ASSERT_EQ(net::OK, rv);
+ ASSERT_EQ(OK, rv);
EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out)));
const char* kKeyingLabelBad = "EXPERIMENTAL-server-socket-test-bad";
@@ -580,7 +555,7 @@ TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad,
false, kKeyingContext,
client_bad, sizeof(client_bad));
- ASSERT_EQ(rv, net::OK);
+ ASSERT_EQ(rv, OK);
EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out)));
}
diff --git a/chromium/net/socket/ssl_session_cache_openssl.cc b/chromium/net/socket/ssl_session_cache_openssl.cc
index d16bb8d6325..92ae44b9ac5 100644
--- a/chromium/net/socket/ssl_session_cache_openssl.cc
+++ b/chromium/net/socket/ssl_session_cache_openssl.cc
@@ -87,8 +87,9 @@ struct SessionId {
// this one is just simple enough to do the job.
size_t ComputeHash(const unsigned char* id, unsigned id_len) {
size_t result = 0;
- for (unsigned n = 0; n < id_len; ++n)
- result += 131 * id[n];
+ for (unsigned n = 0; n < id_len; ++n) {
+ result = (result * 131) + id[n];
+ }
return result;
}
};
@@ -236,10 +237,25 @@ class SSLSessionCacheOpenSSLImpl {
return SSL_set_session(ssl, session) == 1;
}
+ // Return true iff a cached session was associated with the given |cache_key|.
+ bool SSLSessionIsInCache(const std::string& cache_key) const {
+ base::AutoLock locked(lock_);
+ KeyIndex::const_iterator it = key_index_.find(cache_key);
+ if (it == key_index_.end())
+ return false;
+
+ SSL_SESSION* session = *it->second;
+ DCHECK(session);
+
+ void* session_is_good =
+ SSL_SESSION_get_ex_data(session, GetSSLSessionExIndex());
+
+ return session_is_good != NULL;
+ }
+
void MarkSSLSessionAsGood(SSL* ssl) {
SSL_SESSION* session = SSL_get_session(ssl);
- if (!session)
- return;
+ CHECK(session);
// Mark the session as good, allowing it to be used for future connections.
SSL_SESSION_set_ex_data(
@@ -342,7 +358,8 @@ class SSLSessionCacheOpenSSLImpl {
// to indicate that it took ownership of the session, i.e. that the caller
// should not decrement its reference count after completion.
static int NewSessionCallbackStatic(SSL* ssl, SSL_SESSION* session) {
- GetCache(ssl->ctx)->OnSessionAdded(ssl, session);
+ SSLSessionCacheOpenSSLImpl* cache = GetCache(ssl->ctx);
+ cache->OnSessionAdded(ssl, session);
return 1;
}
@@ -469,7 +486,7 @@ class SSLSessionCacheOpenSSLImpl {
// method to get the index which can later be used with SSL_CTX_get_ex_data()
// or SSL_CTX_set_ex_data().
- base::Lock lock_; // Protects access to containers below.
+ mutable base::Lock lock_; // Protects access to containers below.
MRUSessionList ordering_;
KeyIndex key_index_;
@@ -499,6 +516,11 @@ bool SSLSessionCacheOpenSSL::SetSSLSessionWithKey(
return impl_->SetSSLSessionWithKey(ssl, cache_key);
}
+bool SSLSessionCacheOpenSSL::SSLSessionIsInCache(
+ const std::string& cache_key) const {
+ return impl_->SSLSessionIsInCache(cache_key);
+}
+
void SSLSessionCacheOpenSSL::MarkSSLSessionAsGood(SSL* ssl) {
return impl_->MarkSSLSessionAsGood(ssl);
}
diff --git a/chromium/net/socket/ssl_session_cache_openssl.h b/chromium/net/socket/ssl_session_cache_openssl.h
index bbd9659641d..abf5eab78cb 100644
--- a/chromium/net/socket/ssl_session_cache_openssl.h
+++ b/chromium/net/socket/ssl_session_cache_openssl.h
@@ -113,6 +113,9 @@ class NET_EXPORT SSLSessionCacheOpenSSL {
// Return true iff a cached session was associated with the |ssl| connection.
bool SetSSLSessionWithKey(SSL* ssl, const std::string& cache_key);
+ // Return true iff a cached session was associated with the given |cache_key|.
+ bool SSLSessionIsInCache(const std::string& cache_key) const;
+
// Indicates that the SSL session associated with |ssl| is "good" - that is,
// that all associated cryptographic parameters that were negotiated,
// including the peer's certificate, were successfully validated. Because
diff --git a/chromium/net/socket/ssl_session_cache_openssl_unittest.cc b/chromium/net/socket/ssl_session_cache_openssl_unittest.cc
index 22c4fbaeb9c..78bac63ccb2 100644
--- a/chromium/net/socket/ssl_session_cache_openssl_unittest.cc
+++ b/chromium/net/socket/ssl_session_cache_openssl_unittest.cc
@@ -10,6 +10,7 @@
#include "base/logging.h"
#include "base/strings/stringprintf.h"
#include "crypto/openssl_util.h"
+#include "crypto/scoped_openssl_types.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -19,18 +20,19 @@
// |s| is the target SSL connection handle.
// |session| is non-0 to ask for the creation of a new session. If 0,
// this will set an empty session with no ID instead.
-extern "C" int ssl_get_new_session(SSL* s, int session);
+extern "C" OPENSSL_EXPORT int ssl_get_new_session(SSL* s, int session);
// This is an internal OpenSSL function which is used internally to add
// a new session to the cache. It is normally triggered by a succesful
// connection. However, this unit test does not use the network at all.
-extern "C" void ssl_update_cache(SSL* s, int mode);
+extern "C" OPENSSL_EXPORT void ssl_update_cache(SSL* s, int mode);
namespace net {
namespace {
-typedef crypto::ScopedOpenSSL<SSL, SSL_free> ScopedSSL;
+typedef crypto::ScopedOpenSSL<SSL, SSL_free>::Type ScopedSSL;
+typedef crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free>::Type ScopedSSL_CTX;
// Helper class used to associate arbitrary std::string keys with SSL objects.
class SSLKeyHelper {
@@ -73,8 +75,8 @@ class SSLKeyHelper {
// Called when an SSL object is copied through SSL_dup(). This needs to copy
// the value as well.
static int KeyDup(CRYPTO_EX_DATA* to,
- CRYPTO_EX_DATA* from,
- void* from_fd,
+ const CRYPTO_EX_DATA* from,
+ void** from_fd,
int idx,
long argl,
void* argp) {
@@ -142,7 +144,7 @@ class SSLSessionCacheOpenSSLTest : public testing::Test {
static const SSLSessionCacheOpenSSL::Config kDefaultConfig;
protected:
- crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free> ctx_;
+ ScopedSSL_CTX ctx_;
// |cache_| must be destroyed before |ctx_| and thus appears after it.
SSLSessionCacheOpenSSL cache_;
};
diff --git a/chromium/net/socket/ssl_socket.h b/chromium/net/socket/ssl_socket.h
index 68d1e4a2bfe..0dc817becdd 100644
--- a/chromium/net/socket/ssl_socket.h
+++ b/chromium/net/socket/ssl_socket.h
@@ -15,7 +15,7 @@ namespace net {
// and server SSL sockets.
class NET_EXPORT SSLSocket : public StreamSocket {
public:
- virtual ~SSLSocket() {}
+ ~SSLSocket() override {}
// Exports data derived from the SSL master-secret (see RFC 5705).
// If |has_context| is false, uses the no-context construction from the
diff --git a/chromium/net/socket/stream_listen_socket.cc b/chromium/net/socket/stream_listen_socket.cc
index fd164a556d5..abb5fbc6b52 100644
--- a/chromium/net/socket/stream_listen_socket.cc
+++ b/chromium/net/socket/stream_listen_socket.cc
@@ -21,6 +21,7 @@
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
#include "base/posix/eintr_wrapper.h"
+#include "base/profiler/scoped_tracker.h"
#include "base/sys_byteorder.h"
#include "base/threading/platform_thread.h"
#include "build/build_config.h"
@@ -246,6 +247,10 @@ void StreamListenSocket::UnwatchSocket() {
#if defined(OS_WIN)
// MessageLoop watcher callback.
void StreamListenSocket::OnObjectSignaled(HANDLE object) {
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION("StreamListenSocket_OnObjectSignaled"));
+
WSANETWORKEVENTS ev;
if (kSocketError == WSAEnumNetworkEvents(socket_, socket_event_, &ev)) {
// TODO
diff --git a/chromium/net/socket/stream_listen_socket.h b/chromium/net/socket/stream_listen_socket.h
index 813d96a22be..f8f9419484d 100644
--- a/chromium/net/socket/stream_listen_socket.h
+++ b/chromium/net/socket/stream_listen_socket.h
@@ -47,7 +47,7 @@ class NET_EXPORT StreamListenSocket
#endif
public:
- virtual ~StreamListenSocket();
+ ~StreamListenSocket() override;
// TODO(erikkay): this delegate should really be split into two parts
// to split up the listener from the connected socket. Perhaps this class
@@ -116,8 +116,8 @@ class NET_EXPORT StreamListenSocket
HANDLE socket_event_;
#elif defined(OS_POSIX)
// Called by MessagePumpLibevent when the socket is ready to do I/O.
- virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE;
- virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE;
+ void OnFileCanReadWithoutBlocking(int fd) override;
+ void OnFileCanWriteWithoutBlocking(int fd) override;
WaitState wait_state_;
// The socket's libevent wrapper.
base::MessageLoopForIO::FileDescriptorWatcher watcher_;
diff --git a/chromium/net/socket/stream_socket.h b/chromium/net/socket/stream_socket.h
index 72088103a3d..b41fed8b51b 100644
--- a/chromium/net/socket/stream_socket.h
+++ b/chromium/net/socket/stream_socket.h
@@ -17,7 +17,7 @@ class SSLInfo;
class NET_EXPORT_PRIVATE StreamSocket : public Socket {
public:
- virtual ~StreamSocket() {}
+ ~StreamSocket() override {}
// Called to establish a connection. Returns OK if the connection could be
// established synchronously. Otherwise, ERR_IO_PENDING is returned and the
@@ -75,10 +75,15 @@ class NET_EXPORT_PRIVATE StreamSocket : public Socket {
// Write() methods had been called, not the underlying transport's.
virtual bool WasEverUsed() const = 0;
+ // TODO(jri): Clean up -- remove this method.
// Returns true if the underlying transport socket is using TCP FastOpen.
// TCP FastOpen is an experiment with sending data in the TCP SYN packet.
virtual bool UsingTCPFastOpen() const = 0;
+ // TODO(jri): Clean up -- rename to a more general EnableAutoConnectOnWrite.
+ // Enables use of TCP FastOpen for the underlying transport socket.
+ virtual void EnableTCPFastOpenIfSupported() {}
+
// Returns true if NPN was negotiated during the connection of this socket.
virtual bool WasNpnNegotiated() const = 0;
diff --git a/chromium/net/socket/tcp_client_socket.cc b/chromium/net/socket/tcp_client_socket.cc
index 53d8d1fe259..dcf124bfd6b 100644
--- a/chromium/net/socket/tcp_client_socket.cc
+++ b/chromium/net/socket/tcp_client_socket.cc
@@ -6,6 +6,7 @@
#include "base/callback_helpers.h"
#include "base/logging.h"
+#include "base/profiler/scoped_tracker.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
@@ -225,6 +226,10 @@ bool TCPClientSocket::UsingTCPFastOpen() const {
return socket_->UsingTCPFastOpen();
}
+void TCPClientSocket::EnableTCPFastOpenIfSupported() {
+ socket_->EnableTCPFastOpenIfSupported();
+}
+
bool TCPClientSocket::WasNpnNegotiated() const {
return false;
}
@@ -302,6 +307,10 @@ void TCPClientSocket::DidCompleteReadWrite(const CompletionCallback& callback,
if (result > 0)
use_history_.set_was_used_to_convey_data();
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "TCPClientSocket::DidCompleteReadWrite"));
callback.Run(result);
}
diff --git a/chromium/net/socket/tcp_client_socket.h b/chromium/net/socket/tcp_client_socket.h
index 970da2a026f..0deec2a0c9f 100644
--- a/chromium/net/socket/tcp_client_socket.h
+++ b/chromium/net/socket/tcp_client_socket.h
@@ -32,36 +32,39 @@ class NET_EXPORT TCPClientSocket : public StreamSocket {
TCPClientSocket(scoped_ptr<TCPSocket> connected_socket,
const IPEndPoint& peer_address);
- virtual ~TCPClientSocket();
+ ~TCPClientSocket() override;
// Binds the socket to a local IP address and port.
int Bind(const IPEndPoint& address);
// StreamSocket implementation.
- virtual int Connect(const CompletionCallback& callback) OVERRIDE;
- virtual void Disconnect() OVERRIDE;
- virtual bool IsConnected() const OVERRIDE;
- virtual bool IsConnectedAndIdle() const OVERRIDE;
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
- virtual const BoundNetLog& NetLog() const OVERRIDE;
- virtual void SetSubresourceSpeculation() OVERRIDE;
- virtual void SetOmniboxSpeculation() OVERRIDE;
- virtual bool WasEverUsed() const OVERRIDE;
- virtual bool UsingTCPFastOpen() const OVERRIDE;
- virtual bool WasNpnNegotiated() const OVERRIDE;
- virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
+ int Connect(const CompletionCallback& callback) override;
+ void Disconnect() override;
+ bool IsConnected() const override;
+ bool IsConnectedAndIdle() const override;
+ int GetPeerAddress(IPEndPoint* address) const override;
+ int GetLocalAddress(IPEndPoint* address) const override;
+ const BoundNetLog& NetLog() const override;
+ void SetSubresourceSpeculation() override;
+ void SetOmniboxSpeculation() override;
+ bool WasEverUsed() const override;
+ bool UsingTCPFastOpen() const override;
+ void EnableTCPFastOpenIfSupported() override;
+ bool WasNpnNegotiated() const override;
+ NextProto GetNegotiatedProtocol() const override;
+ bool GetSSLInfo(SSLInfo* ssl_info) override;
// Socket implementation.
// Multiple outstanding requests are not supported.
// Full duplex mode (reading and writing at the same time) is supported.
- virtual int Read(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int Write(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE;
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE;
- virtual int SetSendBufferSize(int32 size) OVERRIDE;
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int SetReceiveBufferSize(int32 size) override;
+ int SetSendBufferSize(int32 size) override;
virtual bool SetKeepAlive(bool enable, int delay);
virtual bool SetNoDelay(bool no_delay);
diff --git a/chromium/net/socket/tcp_listen_socket.cc b/chromium/net/socket/tcp_listen_socket.cc
index 223abee2cba..585c41292de 100644
--- a/chromium/net/socket/tcp_listen_socket.cc
+++ b/chromium/net/socket/tcp_listen_socket.cc
@@ -107,7 +107,7 @@ void TCPListenSocket::Accept() {
#if defined(OS_POSIX)
sock->WatchSocket(WAITING_READ);
#endif
- socket_delegate_->DidAccept(this, sock.PassAs<StreamListenSocket>());
+ socket_delegate_->DidAccept(this, sock.Pass());
}
TCPListenSocketFactory::TCPListenSocketFactory(const string& ip, int port)
@@ -119,8 +119,7 @@ TCPListenSocketFactory::~TCPListenSocketFactory() {}
scoped_ptr<StreamListenSocket> TCPListenSocketFactory::CreateAndListen(
StreamListenSocket::Delegate* delegate) const {
- return TCPListenSocket::CreateAndListen(ip_, port_, delegate)
- .PassAs<StreamListenSocket>();
+ return TCPListenSocket::CreateAndListen(ip_, port_, delegate);
}
} // namespace net
diff --git a/chromium/net/socket/tcp_listen_socket.h b/chromium/net/socket/tcp_listen_socket.h
index 54a91de59bb..1702e50e8ed 100644
--- a/chromium/net/socket/tcp_listen_socket.h
+++ b/chromium/net/socket/tcp_listen_socket.h
@@ -17,7 +17,7 @@ namespace net {
// Implements a TCP socket.
class NET_EXPORT TCPListenSocket : public StreamListenSocket {
public:
- virtual ~TCPListenSocket();
+ ~TCPListenSocket() override;
// Listen on port for the specified IP address. Use 127.0.0.1 to only
// accept local connections.
static scoped_ptr<TCPListenSocket> CreateAndListen(
@@ -34,7 +34,7 @@ class NET_EXPORT TCPListenSocket : public StreamListenSocket {
TCPListenSocket(SocketDescriptor s, StreamListenSocket::Delegate* del);
// Implements StreamListenSocket::Accept.
- virtual void Accept() OVERRIDE;
+ void Accept() override;
private:
DISALLOW_COPY_AND_ASSIGN(TCPListenSocket);
@@ -44,11 +44,11 @@ class NET_EXPORT TCPListenSocket : public StreamListenSocket {
class NET_EXPORT TCPListenSocketFactory : public StreamListenSocketFactory {
public:
TCPListenSocketFactory(const std::string& ip, int port);
- virtual ~TCPListenSocketFactory();
+ ~TCPListenSocketFactory() override;
// StreamListenSocketFactory overrides.
- virtual scoped_ptr<StreamListenSocket> CreateAndListen(
- StreamListenSocket::Delegate* delegate) const OVERRIDE;
+ scoped_ptr<StreamListenSocket> CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const override;
private:
const std::string ip_;
diff --git a/chromium/net/socket/tcp_listen_socket_unittest.cc b/chromium/net/socket/tcp_listen_socket_unittest.cc
index 41c41f81fe7..b58642643ca 100644
--- a/chromium/net/socket/tcp_listen_socket_unittest.cc
+++ b/chromium/net/socket/tcp_listen_socket_unittest.cc
@@ -275,13 +275,13 @@ class TCPListenSocketTest : public PlatformTest {
tester_ = NULL;
}
- virtual void SetUp() {
+ void SetUp() override {
PlatformTest::SetUp();
tester_ = new TCPListenSocketTester();
tester_->SetUp();
}
- virtual void TearDown() {
+ void TearDown() override {
PlatformTest::TearDown();
tester_->TearDown();
tester_ = NULL;
diff --git a/chromium/net/socket/tcp_listen_socket_unittest.h b/chromium/net/socket/tcp_listen_socket_unittest.h
index 1bc31a8d1ce..984442afdc0 100644
--- a/chromium/net/socket/tcp_listen_socket_unittest.h
+++ b/chromium/net/socket/tcp_listen_socket_unittest.h
@@ -90,11 +90,12 @@ class TCPListenSocketTester :
virtual bool Send(SocketDescriptor sock, const std::string& str);
// StreamListenSocket::Delegate:
- virtual void DidAccept(StreamListenSocket* server,
- scoped_ptr<StreamListenSocket> connection) OVERRIDE;
- virtual void DidRead(StreamListenSocket* connection, const char* data,
- int len) OVERRIDE;
- virtual void DidClose(StreamListenSocket* sock) OVERRIDE;
+ void DidAccept(StreamListenSocket* server,
+ scoped_ptr<StreamListenSocket> connection) override;
+ void DidRead(StreamListenSocket* connection,
+ const char* data,
+ int len) override;
+ void DidClose(StreamListenSocket* sock) override;
scoped_ptr<base::Thread> thread_;
base::MessageLoopForIO* loop_;
@@ -111,7 +112,7 @@ class TCPListenSocketTester :
private:
friend class base::RefCountedThreadSafe<TCPListenSocketTester>;
- virtual ~TCPListenSocketTester();
+ ~TCPListenSocketTester() override;
virtual scoped_ptr<TCPListenSocket> DoListen();
diff --git a/chromium/net/socket/tcp_server_socket.h b/chromium/net/socket/tcp_server_socket.h
index faff9ad826a..a3919e6845a 100644
--- a/chromium/net/socket/tcp_server_socket.h
+++ b/chromium/net/socket/tcp_server_socket.h
@@ -19,13 +19,13 @@ namespace net {
class NET_EXPORT_PRIVATE TCPServerSocket : public ServerSocket {
public:
TCPServerSocket(NetLog* net_log, const NetLog::Source& source);
- virtual ~TCPServerSocket();
+ ~TCPServerSocket() override;
// net::ServerSocket implementation.
- virtual int Listen(const IPEndPoint& address, int backlog) OVERRIDE;
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
- virtual int Accept(scoped_ptr<StreamSocket>* socket,
- const CompletionCallback& callback) OVERRIDE;
+ int Listen(const IPEndPoint& address, int backlog) override;
+ int GetLocalAddress(IPEndPoint* address) const override;
+ int Accept(scoped_ptr<StreamSocket>* socket,
+ const CompletionCallback& callback) override;
private:
// Converts |accepted_socket_| and stores the result in
diff --git a/chromium/net/socket/tcp_server_socket_unittest.cc b/chromium/net/socket/tcp_server_socket_unittest.cc
index fd81e550d08..01bae9ff188 100644
--- a/chromium/net/socket/tcp_server_socket_unittest.cc
+++ b/chromium/net/socket/tcp_server_socket_unittest.cc
@@ -215,8 +215,8 @@ TEST_F(TCPServerSocketTest, AcceptIO) {
size_t bytes_written = 0;
while (bytes_written < message.size()) {
- scoped_refptr<net::IOBufferWithSize> write_buffer(
- new net::IOBufferWithSize(message.size() - bytes_written));
+ scoped_refptr<IOBufferWithSize> write_buffer(
+ new IOBufferWithSize(message.size() - bytes_written));
memmove(write_buffer->data(), message.data(), message.size());
TestCompletionCallback write_callback;
@@ -230,8 +230,8 @@ TEST_F(TCPServerSocketTest, AcceptIO) {
size_t bytes_read = 0;
while (bytes_read < message.size()) {
- scoped_refptr<net::IOBufferWithSize> read_buffer(
- new net::IOBufferWithSize(message.size() - bytes_read));
+ scoped_refptr<IOBufferWithSize> read_buffer(
+ new IOBufferWithSize(message.size() - bytes_read));
TestCompletionCallback read_callback;
int read_result = connecting_socket.Read(
read_buffer.get(), read_buffer->size(), read_callback.callback());
diff --git a/chromium/net/socket/tcp_socket.cc b/chromium/net/socket/tcp_socket.cc
deleted file mode 100644
index 63703627134..00000000000
--- a/chromium/net/socket/tcp_socket.cc
+++ /dev/null
@@ -1,82 +0,0 @@
-// Copyright 2013 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "net/socket/tcp_socket.h"
-
-#include "base/file_util.h"
-#include "base/files/file_path.h"
-#include "base/memory/ref_counted.h"
-#include "base/threading/worker_pool.h"
-
-namespace net {
-
-namespace {
-
-bool g_tcp_fastopen_enabled = false;
-
-#if defined(OS_LINUX) || defined(OS_ANDROID)
-
-typedef base::RefCountedData<bool> SharedBoolean;
-
-// Checks to see if the system supports TCP FastOpen. Notably, it requires
-// kernel support. Additionally, this checks system configuration to ensure that
-// it's enabled.
-void SystemSupportsTCPFastOpen(scoped_refptr<SharedBoolean> supported) {
- supported->data = false;
- static const base::FilePath::CharType kTCPFastOpenProcFilePath[] =
- "/proc/sys/net/ipv4/tcp_fastopen";
- std::string system_enabled_tcp_fastopen;
- if (!base::ReadFileToString(base::FilePath(kTCPFastOpenProcFilePath),
- &system_enabled_tcp_fastopen)) {
- return;
- }
-
- // As per http://lxr.linux.no/linux+v3.7.7/include/net/tcp.h#L225
- // TFO_CLIENT_ENABLE is the LSB
- if (system_enabled_tcp_fastopen.empty() ||
- (system_enabled_tcp_fastopen[0] & 0x1) == 0) {
- return;
- }
-
- supported->data = true;
-}
-
-void EnableCallback(scoped_refptr<SharedBoolean> supported) {
- g_tcp_fastopen_enabled = supported->data;
-}
-
-// This is asynchronous because it needs to do file IO, and it isn't allowed to
-// do that on the IO thread.
-void EnableFastOpenIfSupported() {
- scoped_refptr<SharedBoolean> supported = new SharedBoolean;
- base::WorkerPool::PostTaskAndReply(
- FROM_HERE,
- base::Bind(SystemSupportsTCPFastOpen, supported),
- base::Bind(EnableCallback, supported),
- false);
-}
-
-#else
-
-void EnableFastOpenIfSupported() {
- g_tcp_fastopen_enabled = false;
-}
-
-#endif
-
-} // namespace
-
-void SetTCPFastOpenEnabled(bool value) {
- if (value) {
- EnableFastOpenIfSupported();
- } else {
- g_tcp_fastopen_enabled = false;
- }
-}
-
-bool IsTCPFastOpenEnabled() {
- return g_tcp_fastopen_enabled;
-}
-
-} // namespace net
diff --git a/chromium/net/socket/tcp_socket.h b/chromium/net/socket/tcp_socket.h
index 8b36fade758..04fd7d2ba6f 100644
--- a/chromium/net/socket/tcp_socket.h
+++ b/chromium/net/socket/tcp_socket.h
@@ -16,13 +16,6 @@
namespace net {
-// Enable/disable experimental TCP FastOpen option.
-// Not thread safe. Must be called during initialization/startup only.
-NET_EXPORT void SetTCPFastOpenEnabled(bool value);
-
-// Check if the TCP FastOpen option is enabled.
-bool IsTCPFastOpenEnabled();
-
// TCPSocket provides a platform-independent interface for TCP sockets.
//
// It is recommended to use TCPClientSocket/TCPServerSocket instead of this
@@ -35,6 +28,17 @@ typedef TCPSocketWin TCPSocket;
typedef TCPSocketLibevent TCPSocket;
#endif
+// Check if TCP FastOpen is supported by the OS.
+bool IsTCPFastOpenSupported();
+
+// Check if TCP FastOpen is enabled by the user.
+bool IsTCPFastOpenUserEnabled();
+
+// Checks if TCP FastOpen is supported by the kernel. Also enables TFO for all
+// connections if indicated by user.
+// Not thread safe. Must be called during initialization/startup only.
+NET_EXPORT void CheckSupportAndMaybeEnableTCPFastOpen(bool user_enabled);
+
} // namespace net
#endif // NET_SOCKET_TCP_SOCKET_H_
diff --git a/chromium/net/socket/tcp_socket_libevent.cc b/chromium/net/socket/tcp_socket_libevent.cc
index 72ae5809e10..cc2376590f5 100644
--- a/chromium/net/socket/tcp_socket_libevent.cc
+++ b/chromium/net/socket/tcp_socket_libevent.cc
@@ -5,18 +5,16 @@
#include "net/socket/tcp_socket.h"
#include <errno.h>
-#include <fcntl.h>
-#include <netdb.h>
-#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
-#include "base/callback_helpers.h"
+#include "base/bind.h"
#include "base/logging.h"
#include "base/metrics/histogram.h"
#include "base/metrics/stats_counters.h"
#include "base/posix/eintr_wrapper.h"
-#include "build/build_config.h"
+#include "base/task_runner_util.h"
+#include "base/threading/worker_pool.h"
#include "net/base/address_list.h"
#include "net/base/connection_type_histograms.h"
#include "net/base/io_buffer.h"
@@ -24,6 +22,7 @@
#include "net/base/net_errors.h"
#include "net/base/net_util.h"
#include "net/base/network_change_notifier.h"
+#include "net/socket/socket_libevent.h"
#include "net/socket/socket_net_log_params.h"
// If we don't have a definition for TCPI_OPT_SYN_DATA, create one.
@@ -35,6 +34,14 @@ namespace net {
namespace {
+// True if OS supports TCP FastOpen.
+bool g_tcp_fastopen_supported = false;
+// True if TCP FastOpen is user-enabled for all connections.
+// TODO(jri): Change global variable to param in HttpNetworkSession::Params.
+bool g_tcp_fastopen_user_enabled = false;
+// True if TCP FastOpen connect-with-write has failed at least once.
+bool g_tcp_fastopen_has_failed = false;
+
// SetTCPNoDelay turns on/off buffering in the kernel. By default, TCP sockets
// will wait up to 200ms for more data to complete a packet before transmitting.
// After calling this function, the kernel will not wait. See TCP_NODELAY in
@@ -72,90 +79,61 @@ bool SetTCPKeepAlive(int fd, bool enable, int delay) {
return true;
}
-int MapAcceptError(int os_error) {
- switch (os_error) {
- // If the client aborts the connection before the server calls accept,
- // POSIX specifies accept should fail with ECONNABORTED. The server can
- // ignore the error and just call accept again, so we map the error to
- // ERR_IO_PENDING. See UNIX Network Programming, Vol. 1, 3rd Ed., Sec.
- // 5.11, "Connection Abort before accept Returns".
- case ECONNABORTED:
- return ERR_IO_PENDING;
- default:
- return MapSystemError(os_error);
+#if defined(OS_LINUX) || defined(OS_ANDROID)
+// Checks if the kernel supports TCP FastOpen.
+bool SystemSupportsTCPFastOpen() {
+ const base::FilePath::CharType kTCPFastOpenProcFilePath[] =
+ "/proc/sys/net/ipv4/tcp_fastopen";
+ std::string system_supports_tcp_fastopen;
+ if (!base::ReadFileToString(base::FilePath(kTCPFastOpenProcFilePath),
+ &system_supports_tcp_fastopen)) {
+ return false;
}
+ // The read from /proc should return '1' if TCP FastOpen is enabled in the OS.
+ if (system_supports_tcp_fastopen.empty() ||
+ (system_supports_tcp_fastopen[0] != '1')) {
+ return false;
+ }
+ return true;
}
-int MapConnectError(int os_error) {
- switch (os_error) {
- case EACCES:
- return ERR_NETWORK_ACCESS_DENIED;
- case ETIMEDOUT:
- return ERR_CONNECTION_TIMED_OUT;
- default: {
- int net_error = MapSystemError(os_error);
- if (net_error == ERR_FAILED)
- return ERR_CONNECTION_FAILED; // More specific than ERR_FAILED.
-
- // Give a more specific error when the user is offline.
- if (net_error == ERR_ADDRESS_UNREACHABLE &&
- NetworkChangeNotifier::IsOffline()) {
- return ERR_INTERNET_DISCONNECTED;
- }
- return net_error;
- }
- }
+void RegisterTCPFastOpenIntentAndSupport(bool user_enabled,
+ bool system_supported) {
+ g_tcp_fastopen_supported = system_supported;
+ g_tcp_fastopen_user_enabled = user_enabled;
}
+#endif
} // namespace
//-----------------------------------------------------------------------------
-TCPSocketLibevent::Watcher::Watcher(
- const base::Closure& read_ready_callback,
- const base::Closure& write_ready_callback)
- : read_ready_callback_(read_ready_callback),
- write_ready_callback_(write_ready_callback) {
+bool IsTCPFastOpenSupported() {
+ return g_tcp_fastopen_supported;
}
-TCPSocketLibevent::Watcher::~Watcher() {
+bool IsTCPFastOpenUserEnabled() {
+ return g_tcp_fastopen_user_enabled;
}
-void TCPSocketLibevent::Watcher::OnFileCanReadWithoutBlocking(int /* fd */) {
- if (!read_ready_callback_.is_null())
- read_ready_callback_.Run();
- else
- NOTREACHED();
-}
-
-void TCPSocketLibevent::Watcher::OnFileCanWriteWithoutBlocking(int /* fd */) {
- if (!write_ready_callback_.is_null())
- write_ready_callback_.Run();
- else
- NOTREACHED();
+// This is asynchronous because it needs to do file IO, and it isn't allowed to
+// do that on the IO thread.
+void CheckSupportAndMaybeEnableTCPFastOpen(bool user_enabled) {
+#if defined(OS_LINUX) || defined(OS_ANDROID)
+ base::PostTaskAndReplyWithResult(
+ base::WorkerPool::GetTaskRunner(/*task_is_slow=*/false).get(),
+ FROM_HERE,
+ base::Bind(SystemSupportsTCPFastOpen),
+ base::Bind(RegisterTCPFastOpenIntentAndSupport, user_enabled));
+#endif
}
TCPSocketLibevent::TCPSocketLibevent(NetLog* net_log,
const NetLog::Source& source)
- : socket_(kInvalidSocket),
- accept_watcher_(base::Bind(&TCPSocketLibevent::DidCompleteAccept,
- base::Unretained(this)),
- base::Closure()),
- accept_socket_(NULL),
- accept_address_(NULL),
- read_watcher_(base::Bind(&TCPSocketLibevent::DidCompleteRead,
- base::Unretained(this)),
- base::Closure()),
- write_watcher_(base::Closure(),
- base::Bind(&TCPSocketLibevent::DidCompleteConnectOrWrite,
- base::Unretained(this))),
- read_buf_len_(0),
- write_buf_len_(0),
- use_tcp_fastopen_(IsTCPFastOpenEnabled()),
+ : use_tcp_fastopen_(false),
+ tcp_fastopen_write_attempted_(false),
tcp_fastopen_connected_(false),
- fast_open_status_(FAST_OPEN_STATUS_UNKNOWN),
- waiting_connect_(false),
- connect_os_error_(0),
+ tcp_fastopen_status_(TCP_FASTOPEN_STATUS_UNKNOWN),
logging_multiple_connect_attempts_(false),
net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE,
@@ -164,274 +142,173 @@ TCPSocketLibevent::TCPSocketLibevent(NetLog* net_log,
TCPSocketLibevent::~TCPSocketLibevent() {
net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE);
- if (tcp_fastopen_connected_) {
- UMA_HISTOGRAM_ENUMERATION("Net.TcpFastOpenSocketConnection",
- fast_open_status_, FAST_OPEN_MAX_VALUE);
- }
Close();
}
int TCPSocketLibevent::Open(AddressFamily family) {
- DCHECK(CalledOnValidThread());
- DCHECK_EQ(socket_, kInvalidSocket);
-
- socket_ = CreatePlatformSocket(ConvertAddressFamily(family), SOCK_STREAM,
- IPPROTO_TCP);
- if (socket_ < 0) {
- PLOG(ERROR) << "CreatePlatformSocket() returned an error";
- return MapSystemError(errno);
- }
-
- if (SetNonBlocking(socket_)) {
- int result = MapSystemError(errno);
- Close();
- return result;
- }
-
- return OK;
+ DCHECK(!socket_);
+ socket_.reset(new SocketLibevent);
+ int rv = socket_->Open(ConvertAddressFamily(family));
+ if (rv != OK)
+ socket_.reset();
+ return rv;
}
-int TCPSocketLibevent::AdoptConnectedSocket(int socket,
+int TCPSocketLibevent::AdoptConnectedSocket(int socket_fd,
const IPEndPoint& peer_address) {
- DCHECK(CalledOnValidThread());
- DCHECK_EQ(socket_, kInvalidSocket);
+ DCHECK(!socket_);
- socket_ = socket;
-
- if (SetNonBlocking(socket_)) {
- int result = MapSystemError(errno);
- Close();
- return result;
+ SockaddrStorage storage;
+ if (!peer_address.ToSockAddr(storage.addr, &storage.addr_len) &&
+ // For backward compatibility, allows the empty address.
+ !(peer_address == IPEndPoint())) {
+ return ERR_ADDRESS_INVALID;
}
- peer_address_.reset(new IPEndPoint(peer_address));
-
- return OK;
+ socket_.reset(new SocketLibevent);
+ int rv = socket_->AdoptConnectedSocket(socket_fd, storage);
+ if (rv != OK)
+ socket_.reset();
+ return rv;
}
int TCPSocketLibevent::Bind(const IPEndPoint& address) {
- DCHECK(CalledOnValidThread());
- DCHECK_NE(socket_, kInvalidSocket);
+ DCHECK(socket_);
SockaddrStorage storage;
if (!address.ToSockAddr(storage.addr, &storage.addr_len))
return ERR_ADDRESS_INVALID;
- int result = bind(socket_, storage.addr, storage.addr_len);
- if (result < 0) {
- PLOG(ERROR) << "bind() returned an error";
- return MapSystemError(errno);
- }
-
- return OK;
+ return socket_->Bind(storage);
}
int TCPSocketLibevent::Listen(int backlog) {
- DCHECK(CalledOnValidThread());
- DCHECK_GT(backlog, 0);
- DCHECK_NE(socket_, kInvalidSocket);
-
- int result = listen(socket_, backlog);
- if (result < 0) {
- PLOG(ERROR) << "listen() returned an error";
- return MapSystemError(errno);
- }
-
- return OK;
+ DCHECK(socket_);
+ return socket_->Listen(backlog);
}
-int TCPSocketLibevent::Accept(scoped_ptr<TCPSocketLibevent>* socket,
+int TCPSocketLibevent::Accept(scoped_ptr<TCPSocketLibevent>* tcp_socket,
IPEndPoint* address,
const CompletionCallback& callback) {
- DCHECK(CalledOnValidThread());
- DCHECK(socket);
- DCHECK(address);
+ DCHECK(tcp_socket);
DCHECK(!callback.is_null());
- DCHECK(accept_callback_.is_null());
+ DCHECK(socket_);
+ DCHECK(!accept_socket_);
net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT);
- int result = AcceptInternal(socket, address);
-
- if (result == ERR_IO_PENDING) {
- if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
- socket_, true, base::MessageLoopForIO::WATCH_READ,
- &accept_socket_watcher_, &accept_watcher_)) {
- PLOG(ERROR) << "WatchFileDescriptor failed on read";
- return MapSystemError(errno);
- }
-
- accept_socket_ = socket;
- accept_address_ = address;
- accept_callback_ = callback;
- }
-
- return result;
+ int rv = socket_->Accept(
+ &accept_socket_,
+ base::Bind(&TCPSocketLibevent::AcceptCompleted,
+ base::Unretained(this), tcp_socket, address, callback));
+ if (rv != ERR_IO_PENDING)
+ rv = HandleAcceptCompleted(tcp_socket, address, rv);
+ return rv;
}
int TCPSocketLibevent::Connect(const IPEndPoint& address,
const CompletionCallback& callback) {
- DCHECK(CalledOnValidThread());
- DCHECK_NE(socket_, kInvalidSocket);
- DCHECK(!waiting_connect_);
-
- // |peer_address_| will be non-NULL if Connect() has been called. Unless
- // Close() is called to reset the internal state, a second call to Connect()
- // is not allowed.
- // Please note that we don't allow a second Connect() even if the previous
- // Connect() has failed. Connecting the same |socket_| again after a
- // connection attempt failed results in unspecified behavior according to
- // POSIX.
- DCHECK(!peer_address_);
+ DCHECK(socket_);
if (!logging_multiple_connect_attempts_)
LogConnectBegin(AddressList(address));
- peer_address_.reset(new IPEndPoint(address));
+ net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT,
+ CreateNetLogIPEndPointCallback(&address));
- int rv = DoConnect();
- if (rv == ERR_IO_PENDING) {
- // Synchronous operation not supported.
- DCHECK(!callback.is_null());
- write_callback_ = callback;
- waiting_connect_ = true;
- } else {
- DoConnectComplete(rv);
+ SockaddrStorage storage;
+ if (!address.ToSockAddr(storage.addr, &storage.addr_len))
+ return ERR_ADDRESS_INVALID;
+
+ if (use_tcp_fastopen_) {
+ // With TCP FastOpen, we pretend that the socket is connected.
+ DCHECK(!tcp_fastopen_write_attempted_);
+ socket_->SetPeerAddress(storage);
+ return OK;
}
+ int rv = socket_->Connect(storage,
+ base::Bind(&TCPSocketLibevent::ConnectCompleted,
+ base::Unretained(this), callback));
+ if (rv != ERR_IO_PENDING)
+ rv = HandleConnectCompleted(rv);
return rv;
}
bool TCPSocketLibevent::IsConnected() const {
- DCHECK(CalledOnValidThread());
-
- if (socket_ == kInvalidSocket || waiting_connect_)
+ if (!socket_)
return false;
- if (use_tcp_fastopen_ && !tcp_fastopen_connected_ && peer_address_) {
+ if (use_tcp_fastopen_ && !tcp_fastopen_write_attempted_ &&
+ socket_->HasPeerAddress()) {
// With TCP FastOpen, we pretend that the socket is connected.
// This allows GetPeerAddress() to return peer_address_.
return true;
}
- // Check if connection is alive.
- char c;
- int rv = HANDLE_EINTR(recv(socket_, &c, 1, MSG_PEEK));
- if (rv == 0)
- return false;
- if (rv == -1 && errno != EAGAIN && errno != EWOULDBLOCK)
- return false;
-
- return true;
+ return socket_->IsConnected();
}
bool TCPSocketLibevent::IsConnectedAndIdle() const {
- DCHECK(CalledOnValidThread());
-
- if (socket_ == kInvalidSocket || waiting_connect_)
- return false;
-
// TODO(wtc): should we also handle the TCP FastOpen case here,
// as we do in IsConnected()?
-
- // Check if connection is alive and we haven't received any data
- // unexpectedly.
- char c;
- int rv = HANDLE_EINTR(recv(socket_, &c, 1, MSG_PEEK));
- if (rv >= 0)
- return false;
- if (errno != EAGAIN && errno != EWOULDBLOCK)
- return false;
-
- return true;
+ return socket_ && socket_->IsConnectedAndIdle();
}
int TCPSocketLibevent::Read(IOBuffer* buf,
int buf_len,
const CompletionCallback& callback) {
- DCHECK(CalledOnValidThread());
- DCHECK_NE(kInvalidSocket, socket_);
- DCHECK(!waiting_connect_);
- DCHECK(read_callback_.is_null());
- // Synchronous operation not supported
+ DCHECK(socket_);
DCHECK(!callback.is_null());
- DCHECK_GT(buf_len, 0);
-
- int nread = HANDLE_EINTR(read(socket_, buf->data(), buf_len));
- if (nread >= 0) {
- base::StatsCounter read_bytes("tcp.read_bytes");
- read_bytes.Add(nread);
- net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, nread,
- buf->data());
- RecordFastOpenStatus();
- return nread;
- }
- if (errno != EAGAIN && errno != EWOULDBLOCK) {
- int net_error = MapSystemError(errno);
- net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR,
- CreateNetLogSocketErrorCallback(net_error, errno));
- return net_error;
- }
- if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
- socket_, true, base::MessageLoopForIO::WATCH_READ,
- &read_socket_watcher_, &read_watcher_)) {
- DVLOG(1) << "WatchFileDescriptor failed on read, errno " << errno;
- return MapSystemError(errno);
- }
-
- read_buf_ = buf;
- read_buf_len_ = buf_len;
- read_callback_ = callback;
- return ERR_IO_PENDING;
+ int rv = socket_->Read(
+ buf, buf_len,
+ base::Bind(&TCPSocketLibevent::ReadCompleted,
+ // Grab a reference to |buf| so that ReadCompleted() can still
+ // use it when Read() completes, as otherwise, this transfers
+ // ownership of buf to socket.
+ base::Unretained(this), make_scoped_refptr(buf), callback));
+ if (rv != ERR_IO_PENDING)
+ rv = HandleReadCompleted(buf, rv);
+ return rv;
}
int TCPSocketLibevent::Write(IOBuffer* buf,
int buf_len,
const CompletionCallback& callback) {
- DCHECK(CalledOnValidThread());
- DCHECK_NE(kInvalidSocket, socket_);
- DCHECK(!waiting_connect_);
- DCHECK(write_callback_.is_null());
- // Synchronous operation not supported
+ DCHECK(socket_);
DCHECK(!callback.is_null());
- DCHECK_GT(buf_len, 0);
-
- int nwrite = InternalWrite(buf, buf_len);
- if (nwrite >= 0) {
- base::StatsCounter write_bytes("tcp.write_bytes");
- write_bytes.Add(nwrite);
- net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, nwrite,
- buf->data());
- return nwrite;
- }
- if (errno != EAGAIN && errno != EWOULDBLOCK) {
- int net_error = MapSystemError(errno);
- net_log_.AddEvent(NetLog::TYPE_SOCKET_WRITE_ERROR,
- CreateNetLogSocketErrorCallback(net_error, errno));
- return net_error;
- }
- if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
- socket_, true, base::MessageLoopForIO::WATCH_WRITE,
- &write_socket_watcher_, &write_watcher_)) {
- DVLOG(1) << "WatchFileDescriptor failed on write, errno " << errno;
- return MapSystemError(errno);
+ CompletionCallback write_callback =
+ base::Bind(&TCPSocketLibevent::WriteCompleted,
+ // Grab a reference to |buf| so that WriteCompleted() can still
+ // use it when Write() completes, as otherwise, this transfers
+ // ownership of buf to socket.
+ base::Unretained(this), make_scoped_refptr(buf), callback);
+ int rv;
+
+ if (use_tcp_fastopen_ && !tcp_fastopen_write_attempted_) {
+ rv = TcpFastOpenWrite(buf, buf_len, write_callback);
+ } else {
+ rv = socket_->Write(buf, buf_len, write_callback);
}
- write_buf_ = buf;
- write_buf_len_ = buf_len;
- write_callback_ = callback;
- return ERR_IO_PENDING;
+ if (rv != ERR_IO_PENDING)
+ rv = HandleWriteCompleted(buf, rv);
+ return rv;
}
int TCPSocketLibevent::GetLocalAddress(IPEndPoint* address) const {
- DCHECK(CalledOnValidThread());
DCHECK(address);
+ if (!socket_)
+ return ERR_SOCKET_NOT_CONNECTED;
+
SockaddrStorage storage;
- if (getsockname(socket_, storage.addr, &storage.addr_len) < 0)
- return MapSystemError(errno);
+ int rv = socket_->GetLocalAddress(&storage);
+ if (rv != OK)
+ return rv;
+
if (!address->FromSockAddr(storage.addr, storage.addr_len))
return ERR_ADDRESS_INVALID;
@@ -439,25 +316,34 @@ int TCPSocketLibevent::GetLocalAddress(IPEndPoint* address) const {
}
int TCPSocketLibevent::GetPeerAddress(IPEndPoint* address) const {
- DCHECK(CalledOnValidThread());
DCHECK(address);
+
if (!IsConnected())
return ERR_SOCKET_NOT_CONNECTED;
- *address = *peer_address_;
+
+ SockaddrStorage storage;
+ int rv = socket_->GetPeerAddress(&storage);
+ if (rv != OK)
+ return rv;
+
+ if (!address->FromSockAddr(storage.addr, storage.addr_len))
+ return ERR_ADDRESS_INVALID;
+
return OK;
}
int TCPSocketLibevent::SetDefaultOptionsForServer() {
- DCHECK(CalledOnValidThread());
+ DCHECK(socket_);
return SetAddressReuse(true);
}
void TCPSocketLibevent::SetDefaultOptionsForClient() {
- DCHECK(CalledOnValidThread());
+ DCHECK(socket_);
// This mirrors the behaviour on Windows. See the comment in
// tcp_socket_win.cc after searching for "NODELAY".
- SetTCPNoDelay(socket_, true); // If SetTCPNoDelay fails, we don't care.
+ // If SetTCPNoDelay fails, we don't care.
+ SetTCPNoDelay(socket_->socket_fd(), true);
// TCP keep alive wakes up the radio, which is expensive on mobile. Do not
// enable it there. It's useful to prevent TCP middleboxes from timing out
@@ -473,12 +359,12 @@ void TCPSocketLibevent::SetDefaultOptionsForClient() {
#if !defined(OS_ANDROID) && !defined(OS_IOS)
const int kTCPKeepAliveSeconds = 45;
- SetTCPKeepAlive(socket_, true, kTCPKeepAliveSeconds);
+ SetTCPKeepAlive(socket_->socket_fd(), true, kTCPKeepAliveSeconds);
#endif
}
int TCPSocketLibevent::SetAddressReuse(bool allow) {
- DCHECK(CalledOnValidThread());
+ DCHECK(socket_);
// SO_REUSEADDR is useful for server sockets to bind to a recently unbound
// port. When a socket is closed, the end point changes its state to TIME_WAIT
@@ -494,82 +380,74 @@ int TCPSocketLibevent::SetAddressReuse(bool allow) {
//
// SO_REUSEPORT is provided in MacOS X and iOS.
int boolean_value = allow ? 1 : 0;
- int rv = setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &boolean_value,
- sizeof(boolean_value));
+ int rv = setsockopt(socket_->socket_fd(), SOL_SOCKET, SO_REUSEADDR,
+ &boolean_value, sizeof(boolean_value));
if (rv < 0)
return MapSystemError(errno);
return OK;
}
int TCPSocketLibevent::SetReceiveBufferSize(int32 size) {
- DCHECK(CalledOnValidThread());
- int rv = setsockopt(socket_, SOL_SOCKET, SO_RCVBUF,
+ DCHECK(socket_);
+ int rv = setsockopt(socket_->socket_fd(), SOL_SOCKET, SO_RCVBUF,
reinterpret_cast<const char*>(&size), sizeof(size));
return (rv == 0) ? OK : MapSystemError(errno);
}
int TCPSocketLibevent::SetSendBufferSize(int32 size) {
- DCHECK(CalledOnValidThread());
- int rv = setsockopt(socket_, SOL_SOCKET, SO_SNDBUF,
+ DCHECK(socket_);
+ int rv = setsockopt(socket_->socket_fd(), SOL_SOCKET, SO_SNDBUF,
reinterpret_cast<const char*>(&size), sizeof(size));
return (rv == 0) ? OK : MapSystemError(errno);
}
bool TCPSocketLibevent::SetKeepAlive(bool enable, int delay) {
- DCHECK(CalledOnValidThread());
- return SetTCPKeepAlive(socket_, enable, delay);
+ DCHECK(socket_);
+ return SetTCPKeepAlive(socket_->socket_fd(), enable, delay);
}
bool TCPSocketLibevent::SetNoDelay(bool no_delay) {
- DCHECK(CalledOnValidThread());
- return SetTCPNoDelay(socket_, no_delay);
+ DCHECK(socket_);
+ return SetTCPNoDelay(socket_->socket_fd(), no_delay);
}
void TCPSocketLibevent::Close() {
- DCHECK(CalledOnValidThread());
-
- bool ok = accept_socket_watcher_.StopWatchingFileDescriptor();
- DCHECK(ok);
- ok = read_socket_watcher_.StopWatchingFileDescriptor();
- DCHECK(ok);
- ok = write_socket_watcher_.StopWatchingFileDescriptor();
- DCHECK(ok);
-
- if (socket_ != kInvalidSocket) {
- if (IGNORE_EINTR(close(socket_)) < 0)
- PLOG(ERROR) << "close";
- socket_ = kInvalidSocket;
- }
-
- if (!accept_callback_.is_null()) {
- accept_socket_ = NULL;
- accept_address_ = NULL;
- accept_callback_.Reset();
- }
-
- if (!read_callback_.is_null()) {
- read_buf_ = NULL;
- read_buf_len_ = 0;
- read_callback_.Reset();
- }
+ socket_.reset();
- if (!write_callback_.is_null()) {
- write_buf_ = NULL;
- write_buf_len_ = 0;
- write_callback_.Reset();
+ // Record and reset TCP FastOpen state.
+ if (tcp_fastopen_write_attempted_ ||
+ tcp_fastopen_status_ == TCP_FASTOPEN_PREVIOUSLY_FAILED) {
+ UMA_HISTOGRAM_ENUMERATION("Net.TcpFastOpenSocketConnection",
+ tcp_fastopen_status_, TCP_FASTOPEN_MAX_VALUE);
}
-
+ use_tcp_fastopen_ = false;
tcp_fastopen_connected_ = false;
- fast_open_status_ = FAST_OPEN_STATUS_UNKNOWN;
- waiting_connect_ = false;
- peer_address_.reset();
- connect_os_error_ = 0;
+ tcp_fastopen_write_attempted_ = false;
+ tcp_fastopen_status_ = TCP_FASTOPEN_STATUS_UNKNOWN;
}
bool TCPSocketLibevent::UsingTCPFastOpen() const {
return use_tcp_fastopen_;
}
+void TCPSocketLibevent::EnableTCPFastOpenIfSupported() {
+ if (!IsTCPFastOpenSupported())
+ return;
+
+ // Do not enable TCP FastOpen if it had previously failed.
+ // This check conservatively avoids middleboxes that may blackhole
+ // TCP FastOpen SYN+Data packets; on such a failure, subsequent sockets
+ // should not use TCP FastOpen.
+ if(!g_tcp_fastopen_has_failed)
+ use_tcp_fastopen_ = true;
+ else
+ tcp_fastopen_status_ = TCP_FASTOPEN_PREVIOUSLY_FAILED;
+}
+
+bool TCPSocketLibevent::IsValid() const {
+ return socket_ != NULL && socket_->socket_fd() != kInvalidSocket;
+}
+
void TCPSocketLibevent::StartLoggingMultipleConnectAttempts(
const AddressList& addresses) {
if (!logging_multiple_connect_attempts_) {
@@ -589,98 +467,76 @@ void TCPSocketLibevent::EndLoggingMultipleConnectAttempts(int net_error) {
}
}
-int TCPSocketLibevent::AcceptInternal(scoped_ptr<TCPSocketLibevent>* socket,
- IPEndPoint* address) {
- SockaddrStorage storage;
- int new_socket = HANDLE_EINTR(accept(socket_,
- storage.addr,
- &storage.addr_len));
- if (new_socket < 0) {
- int net_error = MapAcceptError(errno);
- if (net_error != ERR_IO_PENDING)
- net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error);
- return net_error;
- }
-
- IPEndPoint ip_end_point;
- if (!ip_end_point.FromSockAddr(storage.addr, storage.addr_len)) {
- NOTREACHED();
- if (IGNORE_EINTR(close(new_socket)) < 0)
- PLOG(ERROR) << "close";
- int net_error = ERR_ADDRESS_INVALID;
- net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error);
- return net_error;
- }
- scoped_ptr<TCPSocketLibevent> tcp_socket(new TCPSocketLibevent(
- net_log_.net_log(), net_log_.source()));
- int adopt_result = tcp_socket->AdoptConnectedSocket(new_socket, ip_end_point);
- if (adopt_result != OK) {
- net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result);
- return adopt_result;
- }
- *socket = tcp_socket.Pass();
- *address = ip_end_point;
- net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT,
- CreateNetLogIPEndPointCallback(&ip_end_point));
- return OK;
+void TCPSocketLibevent::AcceptCompleted(
+ scoped_ptr<TCPSocketLibevent>* tcp_socket,
+ IPEndPoint* address,
+ const CompletionCallback& callback,
+ int rv) {
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ callback.Run(HandleAcceptCompleted(tcp_socket, address, rv));
}
-int TCPSocketLibevent::DoConnect() {
- DCHECK_EQ(0, connect_os_error_);
+int TCPSocketLibevent::HandleAcceptCompleted(
+ scoped_ptr<TCPSocketLibevent>* tcp_socket,
+ IPEndPoint* address,
+ int rv) {
+ if (rv == OK)
+ rv = BuildTcpSocketLibevent(tcp_socket, address);
- net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT,
- CreateNetLogIPEndPointCallback(peer_address_.get()));
-
- // Connect the socket.
- if (!use_tcp_fastopen_) {
- SockaddrStorage storage;
- if (!peer_address_->ToSockAddr(storage.addr, &storage.addr_len))
- return ERR_ADDRESS_INVALID;
-
- if (!HANDLE_EINTR(connect(socket_, storage.addr, storage.addr_len))) {
- // Connected without waiting!
- return OK;
- }
+ if (rv == OK) {
+ net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT,
+ CreateNetLogIPEndPointCallback(address));
} else {
- // With TCP FastOpen, we pretend that the socket is connected.
- DCHECK(!tcp_fastopen_connected_);
- return OK;
+ net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, rv);
}
- // Check if the connect() failed synchronously.
- connect_os_error_ = errno;
- if (connect_os_error_ != EINPROGRESS)
- return MapConnectError(connect_os_error_);
-
- // Otherwise the connect() is going to complete asynchronously, so watch
- // for its completion.
- if (!base::MessageLoopForIO::current()->WatchFileDescriptor(
- socket_, true, base::MessageLoopForIO::WATCH_WRITE,
- &write_socket_watcher_, &write_watcher_)) {
- connect_os_error_ = errno;
- DVLOG(1) << "WatchFileDescriptor failed: " << connect_os_error_;
- return MapSystemError(connect_os_error_);
+ return rv;
+}
+
+int TCPSocketLibevent::BuildTcpSocketLibevent(
+ scoped_ptr<TCPSocketLibevent>* tcp_socket,
+ IPEndPoint* address) {
+ DCHECK(accept_socket_);
+
+ SockaddrStorage storage;
+ if (accept_socket_->GetPeerAddress(&storage) != OK ||
+ !address->FromSockAddr(storage.addr, storage.addr_len)) {
+ accept_socket_.reset();
+ return ERR_ADDRESS_INVALID;
}
- return ERR_IO_PENDING;
+ tcp_socket->reset(new TCPSocketLibevent(net_log_.net_log(),
+ net_log_.source()));
+ (*tcp_socket)->socket_.reset(accept_socket_.release());
+ return OK;
+}
+
+void TCPSocketLibevent::ConnectCompleted(const CompletionCallback& callback,
+ int rv) const {
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ callback.Run(HandleConnectCompleted(rv));
}
-void TCPSocketLibevent::DoConnectComplete(int result) {
+int TCPSocketLibevent::HandleConnectCompleted(int rv) const {
// Log the end of this attempt (and any OS error it threw).
- int os_error = connect_os_error_;
- connect_os_error_ = 0;
- if (result != OK) {
+ if (rv != OK) {
net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT,
- NetLog::IntegerCallback("os_error", os_error));
+ NetLog::IntegerCallback("os_error", errno));
} else {
net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT);
}
+ // Give a more specific error when the user is offline.
+ if (rv == ERR_ADDRESS_UNREACHABLE && NetworkChangeNotifier::IsOffline())
+ rv = ERR_INTERNET_DISCONNECTED;
+
if (!logging_multiple_connect_attempts_)
- LogConnectEnd(result);
+ LogConnectEnd(rv);
+
+ return rv;
}
-void TCPSocketLibevent::LogConnectBegin(const AddressList& addresses) {
+void TCPSocketLibevent::LogConnectBegin(const AddressList& addresses) const {
base::StatsCounter connects("tcp.connect");
connects.Increment();
@@ -688,19 +544,18 @@ void TCPSocketLibevent::LogConnectBegin(const AddressList& addresses) {
addresses.CreateNetLogCallback());
}
-void TCPSocketLibevent::LogConnectEnd(int net_error) {
- if (net_error == OK)
- UpdateConnectionTypeHistograms(CONNECTION_ANY);
-
+void TCPSocketLibevent::LogConnectEnd(int net_error) const {
if (net_error != OK) {
net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_CONNECT, net_error);
return;
}
+ UpdateConnectionTypeHistograms(CONNECTION_ANY);
+
SockaddrStorage storage;
- int rv = getsockname(socket_, storage.addr, &storage.addr_len);
- if (rv != 0) {
- PLOG(ERROR) << "getsockname() [rv: " << rv << "] error: ";
+ int rv = socket_->GetLocalAddress(&storage);
+ if (rv != OK) {
+ PLOG(ERROR) << "GetLocalAddress() [rv: " << rv << "] error: ";
NOTREACHED();
net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_CONNECT, rv);
return;
@@ -711,191 +566,179 @@ void TCPSocketLibevent::LogConnectEnd(int net_error) {
storage.addr_len));
}
-void TCPSocketLibevent::DidCompleteRead() {
- RecordFastOpenStatus();
- if (read_callback_.is_null())
- return;
-
- int bytes_transferred;
- bytes_transferred = HANDLE_EINTR(read(socket_, read_buf_->data(),
- read_buf_len_));
-
- int result;
- if (bytes_transferred >= 0) {
- result = bytes_transferred;
- base::StatsCounter read_bytes("tcp.read_bytes");
- read_bytes.Add(bytes_transferred);
- net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, result,
- read_buf_->data());
- } else {
- result = MapSystemError(errno);
- if (result != ERR_IO_PENDING) {
- net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR,
- CreateNetLogSocketErrorCallback(result, errno));
- }
+void TCPSocketLibevent::ReadCompleted(const scoped_refptr<IOBuffer>& buf,
+ const CompletionCallback& callback,
+ int rv) {
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ callback.Run(HandleReadCompleted(buf.get(), rv));
+}
+
+int TCPSocketLibevent::HandleReadCompleted(IOBuffer* buf, int rv) {
+ if (tcp_fastopen_write_attempted_ && !tcp_fastopen_connected_) {
+ // A TCP FastOpen connect-with-write was attempted. This read was a
+ // subsequent read, which either succeeded or failed. If the read
+ // succeeded, the socket is considered connected via TCP FastOpen.
+ // If the read failed, TCP FastOpen is (conservatively) turned off for all
+ // subsequent connections. TCP FastOpen status is recorded in both cases.
+ // TODO (jri): This currently results in conservative behavior, where TCP
+ // FastOpen is turned off on _any_ error. Implement optimizations,
+ // such as turning off TCP FastOpen on more specific errors, and
+ // re-attempting TCP FastOpen after a certain amount of time has passed.
+ if (rv >= 0)
+ tcp_fastopen_connected_ = true;
+ else
+ g_tcp_fastopen_has_failed = true;
+ UpdateTCPFastOpenStatusAfterRead();
+ }
+
+ if (rv < 0) {
+ net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR,
+ CreateNetLogSocketErrorCallback(rv, errno));
+ return rv;
}
- if (result != ERR_IO_PENDING) {
- read_buf_ = NULL;
- read_buf_len_ = 0;
- bool ok = read_socket_watcher_.StopWatchingFileDescriptor();
- DCHECK(ok);
- base::ResetAndReturn(&read_callback_).Run(result);
- }
+ base::StatsCounter read_bytes("tcp.read_bytes");
+ read_bytes.Add(rv);
+ net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, rv,
+ buf->data());
+ return rv;
}
-void TCPSocketLibevent::DidCompleteWrite() {
- if (write_callback_.is_null())
- return;
-
- int bytes_transferred;
- bytes_transferred = HANDLE_EINTR(write(socket_, write_buf_->data(),
- write_buf_len_));
-
- int result;
- if (bytes_transferred >= 0) {
- result = bytes_transferred;
- base::StatsCounter write_bytes("tcp.write_bytes");
- write_bytes.Add(bytes_transferred);
- net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, result,
- write_buf_->data());
- } else {
- result = MapSystemError(errno);
- if (result != ERR_IO_PENDING) {
- net_log_.AddEvent(NetLog::TYPE_SOCKET_WRITE_ERROR,
- CreateNetLogSocketErrorCallback(result, errno));
+void TCPSocketLibevent::WriteCompleted(const scoped_refptr<IOBuffer>& buf,
+ const CompletionCallback& callback,
+ int rv) {
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ callback.Run(HandleWriteCompleted(buf.get(), rv));
+}
+
+int TCPSocketLibevent::HandleWriteCompleted(IOBuffer* buf, int rv) {
+ if (rv < 0) {
+ if (tcp_fastopen_write_attempted_ && !tcp_fastopen_connected_) {
+ // TCP FastOpen connect-with-write was attempted, and the write failed
+ // for unknown reasons. Record status and (conservatively) turn off
+ // TCP FastOpen for all subsequent connections.
+ // TODO (jri): This currently results in conservative behavior, where TCP
+ // FastOpen is turned off on _any_ error. Implement optimizations,
+ // such as turning off TCP FastOpen on more specific errors, and
+ // re-attempting TCP FastOpen after a certain amount of time has passed.
+ tcp_fastopen_status_ = TCP_FASTOPEN_ERROR;
+ g_tcp_fastopen_has_failed = true;
}
+ net_log_.AddEvent(NetLog::TYPE_SOCKET_WRITE_ERROR,
+ CreateNetLogSocketErrorCallback(rv, errno));
+ return rv;
}
- if (result != ERR_IO_PENDING) {
- write_buf_ = NULL;
- write_buf_len_ = 0;
- write_socket_watcher_.StopWatchingFileDescriptor();
- base::ResetAndReturn(&write_callback_).Run(result);
- }
+ base::StatsCounter write_bytes("tcp.write_bytes");
+ write_bytes.Add(rv);
+ net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, rv,
+ buf->data());
+ return rv;
}
-void TCPSocketLibevent::DidCompleteConnect() {
- DCHECK(waiting_connect_);
-
- // Get the error that connect() completed with.
- int os_error = 0;
- socklen_t len = sizeof(os_error);
- if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &os_error, &len) < 0)
- os_error = errno;
+int TCPSocketLibevent::TcpFastOpenWrite(
+ IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) {
+ SockaddrStorage storage;
+ int rv = socket_->GetPeerAddress(&storage);
+ if (rv != OK)
+ return rv;
- int result = MapConnectError(os_error);
- connect_os_error_ = os_error;
- if (result != ERR_IO_PENDING) {
- DoConnectComplete(result);
- waiting_connect_ = false;
- write_socket_watcher_.StopWatchingFileDescriptor();
- base::ResetAndReturn(&write_callback_).Run(result);
+ int flags = 0x20000000; // Magic flag to enable TCP_FASTOPEN.
+#if defined(OS_LINUX) || defined(OS_ANDROID)
+ // sendto() will fail with EPIPE when the system doesn't implement TCP
+ // FastOpen, and with EOPNOTSUPP when the system implements TCP FastOpen
+ // but it is disabled. Theoretically these shouldn't happen
+ // since the caller should check for system support on startup, but
+ // users may dynamically disable TCP FastOpen via sysctl.
+ flags |= MSG_NOSIGNAL;
+#endif // defined(OS_LINUX) || defined(OS_ANDROID)
+ rv = HANDLE_EINTR(sendto(socket_->socket_fd(),
+ buf->data(),
+ buf_len,
+ flags,
+ storage.addr,
+ storage.addr_len));
+ tcp_fastopen_write_attempted_ = true;
+
+ if (rv >= 0) {
+ tcp_fastopen_status_ = TCP_FASTOPEN_FAST_CONNECT_RETURN;
+ return rv;
+ }
+
+ DCHECK_NE(EPIPE, errno);
+
+ // If errno == EINPROGRESS, that means the kernel didn't have a cookie
+ // and would block. The kernel is internally doing a connect() though.
+ // Remap EINPROGRESS to EAGAIN so we treat this the same as our other
+ // asynchronous cases. Note that the user buffer has not been copied to
+ // kernel space.
+ if (errno == EINPROGRESS) {
+ rv = ERR_IO_PENDING;
+ } else {
+ rv = MapSystemError(errno);
}
-}
-
-void TCPSocketLibevent::DidCompleteConnectOrWrite() {
- if (waiting_connect_)
- DidCompleteConnect();
- else
- DidCompleteWrite();
-}
-void TCPSocketLibevent::DidCompleteAccept() {
- DCHECK(CalledOnValidThread());
-
- int result = AcceptInternal(accept_socket_, accept_address_);
- if (result != ERR_IO_PENDING) {
- accept_socket_ = NULL;
- accept_address_ = NULL;
- bool ok = accept_socket_watcher_.StopWatchingFileDescriptor();
- DCHECK(ok);
- CompletionCallback callback = accept_callback_;
- accept_callback_.Reset();
- callback.Run(result);
+ if (rv != ERR_IO_PENDING) {
+ // TCP FastOpen connect-with-write was attempted, and the write failed
+ // since TCP FastOpen was not implemented or disabled in the OS.
+ // Record status and turn off TCP FastOpen for all subsequent connections.
+ // TODO (jri): This is almost certainly too conservative, since it blanket
+ // turns off TCP FastOpen on any write error. Two things need to be done
+ // here: (i) record a histogram of write errors; in particular, record
+ // occurrences of EOPNOTSUPP and EPIPE, and (ii) afterwards, consider
+ // turning off TCP FastOpen on more specific errors.
+ tcp_fastopen_status_ = TCP_FASTOPEN_ERROR;
+ g_tcp_fastopen_has_failed = true;
+ return rv;
}
+
+ tcp_fastopen_status_ = TCP_FASTOPEN_SLOW_CONNECT_RETURN;
+ return socket_->WaitForWrite(buf, buf_len, callback);
}
-int TCPSocketLibevent::InternalWrite(IOBuffer* buf, int buf_len) {
- int nwrite;
- if (use_tcp_fastopen_ && !tcp_fastopen_connected_) {
- SockaddrStorage storage;
- if (!peer_address_->ToSockAddr(storage.addr, &storage.addr_len)) {
- // Set errno to EADDRNOTAVAIL so that MapSystemError will map it to
- // ERR_ADDRESS_INVALID later.
- errno = EADDRNOTAVAIL;
- return -1;
- }
+void TCPSocketLibevent::UpdateTCPFastOpenStatusAfterRead() {
+ DCHECK(tcp_fastopen_status_ == TCP_FASTOPEN_FAST_CONNECT_RETURN ||
+ tcp_fastopen_status_ == TCP_FASTOPEN_SLOW_CONNECT_RETURN);
- int flags = 0x20000000; // Magic flag to enable TCP_FASTOPEN.
-#if defined(OS_LINUX)
- // sendto() will fail with EPIPE when the system doesn't support TCP Fast
- // Open. Theoretically that shouldn't happen since the caller should check
- // for system support on startup, but users may dynamically disable TCP Fast
- // Open via sysctl.
- flags |= MSG_NOSIGNAL;
-#endif // defined(OS_LINUX)
- nwrite = HANDLE_EINTR(sendto(socket_,
- buf->data(),
- buf_len,
- flags,
- storage.addr,
- storage.addr_len));
- tcp_fastopen_connected_ = true;
-
- if (nwrite < 0) {
- DCHECK_NE(EPIPE, errno);
-
- // If errno == EINPROGRESS, that means the kernel didn't have a cookie
- // and would block. The kernel is internally doing a connect() though.
- // Remap EINPROGRESS to EAGAIN so we treat this the same as our other
- // asynchronous cases. Note that the user buffer has not been copied to
- // kernel space.
- if (errno == EINPROGRESS) {
- errno = EAGAIN;
- fast_open_status_ = FAST_OPEN_SLOW_CONNECT_RETURN;
- } else {
- fast_open_status_ = FAST_OPEN_ERROR;
- }
- } else {
- fast_open_status_ = FAST_OPEN_FAST_CONNECT_RETURN;
- }
- } else {
- nwrite = HANDLE_EINTR(write(socket_, buf->data(), buf_len));
+ if (tcp_fastopen_write_attempted_ && !tcp_fastopen_connected_) {
+ // TCP FastOpen connect-with-write was attempted, and failed.
+ tcp_fastopen_status_ =
+ (tcp_fastopen_status_ == TCP_FASTOPEN_FAST_CONNECT_RETURN ?
+ TCP_FASTOPEN_FAST_CONNECT_READ_FAILED :
+ TCP_FASTOPEN_SLOW_CONNECT_READ_FAILED);
+ return;
}
- return nwrite;
-}
-void TCPSocketLibevent::RecordFastOpenStatus() {
- if (use_tcp_fastopen_ &&
- (fast_open_status_ == FAST_OPEN_FAST_CONNECT_RETURN ||
- fast_open_status_ == FAST_OPEN_SLOW_CONNECT_RETURN)) {
- DCHECK_NE(FAST_OPEN_STATUS_UNKNOWN, fast_open_status_);
- bool getsockopt_success(false);
- bool server_acked_data(false);
+ bool getsockopt_success = false;
+ bool server_acked_data = false;
#if defined(TCP_INFO)
- // Probe to see the if the socket used TCP Fast Open.
- tcp_info info;
- socklen_t info_len = sizeof(tcp_info);
- getsockopt_success =
- getsockopt(socket_, IPPROTO_TCP, TCP_INFO, &info, &info_len) == 0 &&
- info_len == sizeof(tcp_info);
- server_acked_data = getsockopt_success &&
- (info.tcpi_options & TCPI_OPT_SYN_DATA);
+ // Probe to see the if the socket used TCP FastOpen.
+ tcp_info info;
+ socklen_t info_len = sizeof(tcp_info);
+ getsockopt_success = getsockopt(socket_->socket_fd(), IPPROTO_TCP, TCP_INFO,
+ &info, &info_len) == 0 &&
+ info_len == sizeof(tcp_info);
+ server_acked_data = getsockopt_success &&
+ (info.tcpi_options & TCPI_OPT_SYN_DATA);
#endif
- if (getsockopt_success) {
- if (fast_open_status_ == FAST_OPEN_FAST_CONNECT_RETURN) {
- fast_open_status_ = (server_acked_data ? FAST_OPEN_SYN_DATA_ACK :
- FAST_OPEN_SYN_DATA_NACK);
- } else {
- fast_open_status_ = (server_acked_data ? FAST_OPEN_NO_SYN_DATA_ACK :
- FAST_OPEN_NO_SYN_DATA_NACK);
- }
+
+ if (getsockopt_success) {
+ if (tcp_fastopen_status_ == TCP_FASTOPEN_FAST_CONNECT_RETURN) {
+ tcp_fastopen_status_ = (server_acked_data ?
+ TCP_FASTOPEN_SYN_DATA_ACK :
+ TCP_FASTOPEN_SYN_DATA_NACK);
} else {
- fast_open_status_ = (fast_open_status_ == FAST_OPEN_FAST_CONNECT_RETURN ?
- FAST_OPEN_SYN_DATA_FAILED :
- FAST_OPEN_NO_SYN_DATA_FAILED);
+ tcp_fastopen_status_ = (server_acked_data ?
+ TCP_FASTOPEN_NO_SYN_DATA_ACK :
+ TCP_FASTOPEN_NO_SYN_DATA_NACK);
}
+ } else {
+ tcp_fastopen_status_ =
+ (tcp_fastopen_status_ == TCP_FASTOPEN_FAST_CONNECT_RETURN ?
+ TCP_FASTOPEN_SYN_DATA_GETSOCKOPT_FAILED :
+ TCP_FASTOPEN_NO_SYN_DATA_GETSOCKOPT_FAILED);
}
}
diff --git a/chromium/net/socket/tcp_socket_libevent.h b/chromium/net/socket/tcp_socket_libevent.h
index 9ef235be85d..0958b6d25d0 100644
--- a/chromium/net/socket/tcp_socket_libevent.h
+++ b/chromium/net/socket/tcp_socket_libevent.h
@@ -8,30 +8,27 @@
#include "base/basictypes.h"
#include "base/callback.h"
#include "base/compiler_specific.h"
-#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
-#include "base/message_loop/message_loop.h"
-#include "base/threading/non_thread_safe.h"
#include "net/base/address_family.h"
#include "net/base/completion_callback.h"
#include "net/base/net_export.h"
#include "net/base/net_log.h"
-#include "net/socket/socket_descriptor.h"
namespace net {
class AddressList;
class IOBuffer;
class IPEndPoint;
+class SocketLibevent;
-class NET_EXPORT TCPSocketLibevent : public base::NonThreadSafe {
+class NET_EXPORT TCPSocketLibevent {
public:
TCPSocketLibevent(NetLog* net_log, const NetLog::Source& source);
virtual ~TCPSocketLibevent();
int Open(AddressFamily family);
- // Takes ownership of |socket|.
- int AdoptConnectedSocket(int socket, const IPEndPoint& peer_address);
+ // Takes ownership of |socket_fd|.
+ int AdoptConnectedSocket(int socket_fd, const IPEndPoint& peer_address);
int Bind(const IPEndPoint& address);
@@ -68,8 +65,11 @@ class NET_EXPORT TCPSocketLibevent : public base::NonThreadSafe {
void Close();
+ // Setter/Getter methods for TCP FastOpen socket option.
bool UsingTCPFastOpen() const;
- bool IsValid() const { return socket_ != kInvalidSocket; }
+ void EnableTCPFastOpenIfSupported();
+
+ bool IsValid() const;
// Marks the start/end of a series of connect attempts for logging purpose.
//
@@ -87,141 +87,120 @@ class NET_EXPORT TCPSocketLibevent : public base::NonThreadSafe {
const BoundNetLog& net_log() const { return net_log_; }
private:
- // States that a fast open socket attempt can result in.
- enum FastOpenStatus {
- FAST_OPEN_STATUS_UNKNOWN,
+ // States that using a socket with TCP FastOpen can lead to.
+ enum TCPFastOpenStatus {
+ TCP_FASTOPEN_STATUS_UNKNOWN,
- // The initial fast open connect attempted returned synchronously,
+ // The initial FastOpen connect attempted returned synchronously,
// indicating that we had and sent a cookie along with the initial data.
- FAST_OPEN_FAST_CONNECT_RETURN,
+ TCP_FASTOPEN_FAST_CONNECT_RETURN,
- // The initial fast open connect attempted returned asynchronously,
+ // The initial FastOpen connect attempted returned asynchronously,
// indicating that we did not have a cookie for the server.
- FAST_OPEN_SLOW_CONNECT_RETURN,
+ TCP_FASTOPEN_SLOW_CONNECT_RETURN,
// Some other error occurred on connection, so we couldn't tell if
- // fast open would have worked.
- FAST_OPEN_ERROR,
+ // FastOpen would have worked.
+ TCP_FASTOPEN_ERROR,
- // An attempt to do a fast open succeeded immediately
- // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server
+ // An attempt to do a FastOpen succeeded immediately
+ // (TCP_FASTOPEN_FAST_CONNECT_RETURN) and we later confirmed that the server
// had acked the data we sent.
- FAST_OPEN_SYN_DATA_ACK,
+ TCP_FASTOPEN_SYN_DATA_ACK,
- // An attempt to do a fast open succeeded immediately
- // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server
+ // An attempt to do a FastOpen succeeded immediately
+ // (TCP_FASTOPEN_FAST_CONNECT_RETURN) and we later confirmed that the server
// had nacked the data we sent.
- FAST_OPEN_SYN_DATA_NACK,
+ TCP_FASTOPEN_SYN_DATA_NACK,
- // An attempt to do a fast open succeeded immediately
- // (FAST_OPEN_FAST_CONNECT_RETURN) and our probe to determine if the
- // socket was using fast open failed.
- FAST_OPEN_SYN_DATA_FAILED,
+ // An attempt to do a FastOpen succeeded immediately
+ // (TCP_FASTOPEN_FAST_CONNECT_RETURN) and our probe to determine if the
+ // socket was using FastOpen failed.
+ TCP_FASTOPEN_SYN_DATA_GETSOCKOPT_FAILED,
- // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN)
+ // An attempt to do a FastOpen failed (TCP_FASTOPEN_SLOW_CONNECT_RETURN)
// and we later confirmed that the server had acked initial data. This
// should never happen (we didn't send data, so it shouldn't have
// been acked).
- FAST_OPEN_NO_SYN_DATA_ACK,
+ TCP_FASTOPEN_NO_SYN_DATA_ACK,
- // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN)
+ // An attempt to do a FastOpen failed (TCP_FASTOPEN_SLOW_CONNECT_RETURN)
// and we later discovered that the server had nacked initial data. This
- // is the expected case results for FAST_OPEN_SLOW_CONNECT_RETURN.
- FAST_OPEN_NO_SYN_DATA_NACK,
+ // is the expected case results for TCP_FASTOPEN_SLOW_CONNECT_RETURN.
+ TCP_FASTOPEN_NO_SYN_DATA_NACK,
- // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN)
+ // An attempt to do a FastOpen failed (TCP_FASTOPEN_SLOW_CONNECT_RETURN)
// and our later probe for ack/nack state failed.
- FAST_OPEN_NO_SYN_DATA_FAILED,
-
- FAST_OPEN_MAX_VALUE
- };
-
- // Watcher simply forwards notifications to Closure objects set via the
- // constructor.
- class Watcher: public base::MessageLoopForIO::Watcher {
- public:
- Watcher(const base::Closure& read_ready_callback,
- const base::Closure& write_ready_callback);
- virtual ~Watcher();
-
- // base::MessageLoopForIO::Watcher methods.
- virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE;
- virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE;
-
- private:
- base::Closure read_ready_callback_;
- base::Closure write_ready_callback_;
-
- DISALLOW_COPY_AND_ASSIGN(Watcher);
+ TCP_FASTOPEN_NO_SYN_DATA_GETSOCKOPT_FAILED,
+
+ // The initial FastOpen connect+write succeeded immediately
+ // (TCP_FASTOPEN_FAST_CONNECT_RETURN) and a subsequent attempt to read from
+ // the connection failed.
+ TCP_FASTOPEN_FAST_CONNECT_READ_FAILED,
+
+ // The initial FastOpen connect+write failed
+ // (TCP_FASTOPEN_SLOW_CONNECT_RETURN)
+ // and a subsequent attempt to read from the connection failed.
+ TCP_FASTOPEN_SLOW_CONNECT_READ_FAILED,
+
+ // We didn't try FastOpen because it had failed in the past
+ // (g_tcp_fastopen_has_failed was true.)
+ // NOTE: This status is currently registered before a connect/write call
+ // is attempted, and may capture some cases where the status is registered
+ // but no connect is subsequently attempted.
+ // TODO(jri): The expectation is that such cases are not the common case
+ // with TCP FastOpen for SSL sockets however. Change code to be more
+ // accurate when TCP FastOpen is used for more than just SSL sockets.
+ TCP_FASTOPEN_PREVIOUSLY_FAILED,
+
+ TCP_FASTOPEN_MAX_VALUE
};
- int AcceptInternal(scoped_ptr<TCPSocketLibevent>* socket,
- IPEndPoint* address);
-
- int DoConnect();
- void DoConnectComplete(int result);
-
- void LogConnectBegin(const AddressList& addresses);
- void LogConnectEnd(int net_error);
-
- void DidCompleteRead();
- void DidCompleteWrite();
- void DidCompleteConnect();
- void DidCompleteConnectOrWrite();
- void DidCompleteAccept();
-
- // Internal function to write to a socket. Returns an OS error.
- int InternalWrite(IOBuffer* buf, int buf_len);
-
- // Called when the socket is known to be in a connected state.
- void RecordFastOpenStatus();
-
- int socket_;
-
- base::MessageLoopForIO::FileDescriptorWatcher accept_socket_watcher_;
- Watcher accept_watcher_;
-
- scoped_ptr<TCPSocketLibevent>* accept_socket_;
- IPEndPoint* accept_address_;
- CompletionCallback accept_callback_;
-
- // The socket's libevent wrappers for reads and writes.
- base::MessageLoopForIO::FileDescriptorWatcher read_socket_watcher_;
- base::MessageLoopForIO::FileDescriptorWatcher write_socket_watcher_;
-
- // The corresponding watchers for reads and writes.
- Watcher read_watcher_;
- Watcher write_watcher_;
-
- // The buffer used for reads.
- scoped_refptr<IOBuffer> read_buf_;
- int read_buf_len_;
-
- // The buffer used for writes.
- scoped_refptr<IOBuffer> write_buf_;
- int write_buf_len_;
-
- // External callback; called when read is complete.
- CompletionCallback read_callback_;
-
- // External callback; called when write or connect is complete.
- CompletionCallback write_callback_;
+ void AcceptCompleted(scoped_ptr<TCPSocketLibevent>* tcp_socket,
+ IPEndPoint* address,
+ const CompletionCallback& callback,
+ int rv);
+ int HandleAcceptCompleted(scoped_ptr<TCPSocketLibevent>* tcp_socket,
+ IPEndPoint* address,
+ int rv);
+ int BuildTcpSocketLibevent(scoped_ptr<TCPSocketLibevent>* tcp_socket,
+ IPEndPoint* address);
+
+ void ConnectCompleted(const CompletionCallback& callback, int rv) const;
+ int HandleConnectCompleted(int rv) const;
+ void LogConnectBegin(const AddressList& addresses) const;
+ void LogConnectEnd(int net_error) const;
+
+ void ReadCompleted(const scoped_refptr<IOBuffer>& buf,
+ const CompletionCallback& callback,
+ int rv);
+ int HandleReadCompleted(IOBuffer* buf, int rv);
+
+ void WriteCompleted(const scoped_refptr<IOBuffer>& buf,
+ const CompletionCallback& callback,
+ int rv);
+ int HandleWriteCompleted(IOBuffer* buf, int rv);
+ int TcpFastOpenWrite(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback);
+
+ // Called after the first read completes on a TCP FastOpen socket.
+ void UpdateTCPFastOpenStatusAfterRead();
+
+ scoped_ptr<SocketLibevent> socket_;
+ scoped_ptr<SocketLibevent> accept_socket_;
// Enables experimental TCP FastOpen option.
- const bool use_tcp_fastopen_;
+ bool use_tcp_fastopen_;
+
+ // True when TCP FastOpen is in use and we have attempted the
+ // connect with write.
+ bool tcp_fastopen_write_attempted_;
// True when TCP FastOpen is in use and we have done the connect.
bool tcp_fastopen_connected_;
- FastOpenStatus fast_open_status_;
-
- // A connect operation is pending. In this case, |write_callback_| needs to be
- // called when connect is complete.
- bool waiting_connect_;
-
- scoped_ptr<IPEndPoint> peer_address_;
- // The OS error that a connect attempt last completed with.
- int connect_os_error_;
+ TCPFastOpenStatus tcp_fastopen_status_;
bool logging_multiple_connect_attempts_;
diff --git a/chromium/net/socket/tcp_socket_win.cc b/chromium/net/socket/tcp_socket_win.cc
index 88db36fd41c..d5565ad669c 100644
--- a/chromium/net/socket/tcp_socket_win.cc
+++ b/chromium/net/socket/tcp_socket_win.cc
@@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
+#include "net/socket/tcp_socket.h"
#include "net/socket/tcp_socket_win.h"
#include <mstcpip.h>
@@ -9,6 +10,7 @@
#include "base/callback_helpers.h"
#include "base/logging.h"
#include "base/metrics/stats_counters.h"
+#include "base/profiler/scoped_tracker.h"
#include "base/win/windows_version.h"
#include "net/base/address_list.h"
#include "net/base/connection_type_histograms.h"
@@ -123,6 +125,12 @@ int MapConnectError(int os_error) {
//-----------------------------------------------------------------------------
+// Nothing to do for Windows since it doesn't support TCP FastOpen.
+// TODO(jri): Remove these along with the corresponding global variables.
+bool IsTCPFastOpenSupported() { return false; }
+bool IsTCPFastOpenUserEnabled() { return false; }
+void CheckSupportAndMaybeEnableTCPFastOpen(bool user_enabled) {}
+
// This class encapsulates all the state that has to be preserved as long as
// there is a network IO operation in progress. If the owner TCPSocketWin is
// destroyed while an operation is in progress, the Core is detached and it
@@ -238,6 +246,11 @@ void TCPSocketWin::Core::WatchForWrite() {
}
void TCPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) {
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "TCPSocketWin_Core_ReadDelegate_OnObjectSignaled"));
+
DCHECK_EQ(object, core_->read_overlapped_.hEvent);
if (core_->socket_) {
if (core_->socket_->waiting_connect_)
@@ -251,6 +264,11 @@ void TCPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) {
void TCPSocketWin::Core::WriteDelegate::OnObjectSignaled(
HANDLE object) {
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "TCPSocketWin_Core_WriteDelegate_OnObjectSignaled"));
+
DCHECK_EQ(object, core_->write_overlapped_.hEvent);
if (core_->socket_)
core_->socket_->DidCompleteWrite();
@@ -485,7 +503,7 @@ int TCPSocketWin::Read(IOBuffer* buf,
DCHECK(CalledOnValidThread());
DCHECK_NE(socket_, INVALID_SOCKET);
DCHECK(!waiting_read_);
- DCHECK(read_callback_.is_null());
+ CHECK(read_callback_.is_null());
DCHECK(!core_->read_iobuffer_);
return DoRead(buf, buf_len, callback);
@@ -497,7 +515,7 @@ int TCPSocketWin::Write(IOBuffer* buf,
DCHECK(CalledOnValidThread());
DCHECK_NE(socket_, INVALID_SOCKET);
DCHECK(!waiting_write_);
- DCHECK(write_callback_.is_null());
+ CHECK(write_callback_.is_null());
DCHECK_GT(buf_len, 0);
DCHECK(!core_->write_iobuffer_);
@@ -694,11 +712,6 @@ void TCPSocketWin::Close() {
connect_os_error_ = 0;
}
-bool TCPSocketWin::UsingTCPFastOpen() const {
- // Not supported on windows.
- return false;
-}
-
void TCPSocketWin::StartLoggingMultipleConnectAttempts(
const AddressList& addresses) {
if (!logging_multiple_connect_attempts_) {
@@ -753,6 +766,10 @@ int TCPSocketWin::AcceptInternal(scoped_ptr<TCPSocketWin>* socket,
}
void TCPSocketWin::OnObjectSignaled(HANDLE object) {
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION("TCPSocketWin_OnObjectSignaled"));
+
WSANETWORKEVENTS ev;
if (WSAEnumNetworkEvents(socket_, accept_event_, &ev) == SOCKET_ERROR) {
PLOG(ERROR) << "WSAEnumNetworkEvents()";
@@ -1017,8 +1034,10 @@ void TCPSocketWin::DidSignalRead() {
core_->read_buffer_length_ = 0;
DCHECK_NE(rv, ERR_IO_PENDING);
+ // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION("TCPSocketWin::DidSignalRead"));
base::ResetAndReturn(&read_callback_).Run(rv);
}
} // namespace net
-
diff --git a/chromium/net/socket/tcp_socket_win.h b/chromium/net/socket/tcp_socket_win.h
index 2ddd538f211..80174adea9c 100644
--- a/chromium/net/socket/tcp_socket_win.h
+++ b/chromium/net/socket/tcp_socket_win.h
@@ -75,7 +75,11 @@ class NET_EXPORT TCPSocketWin : NON_EXPORTED_BASE(public base::NonThreadSafe),
void Close();
- bool UsingTCPFastOpen() const;
+ // Setter/Getter methods for TCP FastOpen socket option.
+ // NOOPs since TCP FastOpen is not implemented in Windows.
+ bool UsingTCPFastOpen() const { return false; }
+ void EnableTCPFastOpenIfSupported() {}
+
bool IsValid() const { return socket_ != INVALID_SOCKET; }
// Marks the start/end of a series of connect attempts for logging purpose.
@@ -97,7 +101,7 @@ class NET_EXPORT TCPSocketWin : NON_EXPORTED_BASE(public base::NonThreadSafe),
class Core;
// base::ObjectWatcher::Delegate implementation.
- virtual void OnObjectSignaled(HANDLE object) OVERRIDE;
+ virtual void OnObjectSignaled(HANDLE object) override;
int AcceptInternal(scoped_ptr<TCPSocketWin>* socket,
IPEndPoint* address);
@@ -152,4 +156,3 @@ class NET_EXPORT TCPSocketWin : NON_EXPORTED_BASE(public base::NonThreadSafe),
} // namespace net
#endif // NET_SOCKET_TCP_SOCKET_WIN_H_
-
diff --git a/chromium/net/socket/transport_client_socket_pool.cc b/chromium/net/socket/transport_client_socket_pool.cc
index dc481fa4b6d..06202f13d81 100644
--- a/chromium/net/socket/transport_client_socket_pool.cc
+++ b/chromium/net/socket/transport_client_socket_pool.cc
@@ -31,7 +31,7 @@ namespace net {
// TODO(willchan): Base this off RTT instead of statically setting it. Note we
// choose a timeout that is different from the backup connect job timer so they
// don't synchronize.
-const int TransportConnectJob::kIPv6FallbackTimerInMs = 300;
+const int TransportConnectJobHelper::kIPv6FallbackTimerInMs = 300;
namespace {
@@ -60,12 +60,21 @@ TransportSocketParams::TransportSocketParams(
const HostPortPair& host_port_pair,
bool disable_resolver_cache,
bool ignore_limits,
- const OnHostResolutionCallback& host_resolution_callback)
+ const OnHostResolutionCallback& host_resolution_callback,
+ CombineConnectAndWritePolicy combine_connect_and_write_if_supported)
: destination_(host_port_pair),
ignore_limits_(ignore_limits),
- host_resolution_callback_(host_resolution_callback) {
+ host_resolution_callback_(host_resolution_callback),
+ combine_connect_and_write_(combine_connect_and_write_if_supported) {
if (disable_resolver_cache)
destination_.set_allow_cached_response(false);
+ // combine_connect_and_write currently translates to TCP FastOpen.
+ // Enable TCP FastOpen if user wants it.
+ if (combine_connect_and_write_ == COMBINE_CONNECT_AND_WRITE_DEFAULT) {
+ IsTCPFastOpenUserEnabled() ? combine_connect_and_write_ =
+ COMBINE_CONNECT_AND_WRITE_DESIRED :
+ COMBINE_CONNECT_AND_WRITE_PROHIBITED;
+ }
}
TransportSocketParams::~TransportSocketParams() {}
@@ -81,6 +90,107 @@ TransportSocketParams::~TransportSocketParams() {}
// See comment #12 at http://crbug.com/23364 for specifics.
static const int kTransportConnectJobTimeoutInSeconds = 240; // 4 minutes.
+TransportConnectJobHelper::TransportConnectJobHelper(
+ const scoped_refptr<TransportSocketParams>& params,
+ ClientSocketFactory* client_socket_factory,
+ HostResolver* host_resolver,
+ LoadTimingInfo::ConnectTiming* connect_timing)
+ : params_(params),
+ client_socket_factory_(client_socket_factory),
+ resolver_(host_resolver),
+ next_state_(STATE_NONE),
+ connect_timing_(connect_timing) {}
+
+TransportConnectJobHelper::~TransportConnectJobHelper() {}
+
+int TransportConnectJobHelper::DoResolveHost(RequestPriority priority,
+ const BoundNetLog& net_log) {
+ next_state_ = STATE_RESOLVE_HOST_COMPLETE;
+ connect_timing_->dns_start = base::TimeTicks::Now();
+
+ return resolver_.Resolve(
+ params_->destination(), priority, &addresses_, on_io_complete_, net_log);
+}
+
+int TransportConnectJobHelper::DoResolveHostComplete(
+ int result,
+ const BoundNetLog& net_log) {
+ connect_timing_->dns_end = base::TimeTicks::Now();
+ // Overwrite connection start time, since for connections that do not go
+ // through proxies, |connect_start| should not include dns lookup time.
+ connect_timing_->connect_start = connect_timing_->dns_end;
+
+ if (result == OK) {
+ // Invoke callback, and abort if it fails.
+ if (!params_->host_resolution_callback().is_null())
+ result = params_->host_resolution_callback().Run(addresses_, net_log);
+
+ if (result == OK)
+ next_state_ = STATE_TRANSPORT_CONNECT;
+ }
+ return result;
+}
+
+base::TimeDelta TransportConnectJobHelper::HistogramDuration(
+ ConnectionLatencyHistogram race_result) {
+ DCHECK(!connect_timing_->connect_start.is_null());
+ DCHECK(!connect_timing_->dns_start.is_null());
+ base::TimeTicks now = base::TimeTicks::Now();
+ base::TimeDelta total_duration = now - connect_timing_->dns_start;
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.DNS_Resolution_And_TCP_Connection_Latency2",
+ total_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(10),
+ 100);
+
+ base::TimeDelta connect_duration = now - connect_timing_->connect_start;
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(10),
+ 100);
+
+ switch (race_result) {
+ case CONNECTION_LATENCY_IPV4_WINS_RACE:
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_Wins_Race",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(10),
+ 100);
+ break;
+
+ case CONNECTION_LATENCY_IPV4_NO_RACE:
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_No_Race",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(10),
+ 100);
+ break;
+
+ case CONNECTION_LATENCY_IPV6_RACEABLE:
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Raceable",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(10),
+ 100);
+ break;
+
+ case CONNECTION_LATENCY_IPV6_SOLO:
+ UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Solo",
+ connect_duration,
+ base::TimeDelta::FromMilliseconds(1),
+ base::TimeDelta::FromMinutes(10),
+ 100);
+ break;
+
+ default:
+ NOTREACHED();
+ break;
+ }
+
+ return connect_duration;
+}
+
TransportConnectJob::TransportConnectJob(
const std::string& group_name,
RequestPriority priority,
@@ -92,11 +202,9 @@ TransportConnectJob::TransportConnectJob(
NetLog* net_log)
: ConnectJob(group_name, timeout_duration, priority, delegate,
BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)),
- params_(params),
- client_socket_factory_(client_socket_factory),
- resolver_(host_resolver),
- next_state_(STATE_NONE),
+ helper_(params, client_socket_factory, host_resolver, &connect_timing_),
interval_between_connects_(CONNECT_INTERVAL_GT_20MS) {
+ helper_.SetOnIOComplete(this);
}
TransportConnectJob::~TransportConnectJob() {
@@ -105,14 +213,14 @@ TransportConnectJob::~TransportConnectJob() {
}
LoadState TransportConnectJob::GetLoadState() const {
- switch (next_state_) {
- case STATE_RESOLVE_HOST:
- case STATE_RESOLVE_HOST_COMPLETE:
+ switch (helper_.next_state()) {
+ case TransportConnectJobHelper::STATE_RESOLVE_HOST:
+ case TransportConnectJobHelper::STATE_RESOLVE_HOST_COMPLETE:
return LOAD_STATE_RESOLVING_HOST;
- case STATE_TRANSPORT_CONNECT:
- case STATE_TRANSPORT_CONNECT_COMPLETE:
+ case TransportConnectJobHelper::STATE_TRANSPORT_CONNECT:
+ case TransportConnectJobHelper::STATE_TRANSPORT_CONNECT_COMPLETE:
return LOAD_STATE_CONNECTING;
- case STATE_NONE:
+ case TransportConnectJobHelper::STATE_NONE:
return LOAD_STATE_IDLE;
}
NOTREACHED();
@@ -129,71 +237,12 @@ void TransportConnectJob::MakeAddressListStartWithIPv4(AddressList* list) {
}
}
-void TransportConnectJob::OnIOComplete(int result) {
- int rv = DoLoop(result);
- if (rv != ERR_IO_PENDING)
- NotifyDelegateOfCompletion(rv); // Deletes |this|
-}
-
-int TransportConnectJob::DoLoop(int result) {
- DCHECK_NE(next_state_, STATE_NONE);
-
- int rv = result;
- do {
- State state = next_state_;
- next_state_ = STATE_NONE;
- switch (state) {
- case STATE_RESOLVE_HOST:
- DCHECK_EQ(OK, rv);
- rv = DoResolveHost();
- break;
- case STATE_RESOLVE_HOST_COMPLETE:
- rv = DoResolveHostComplete(rv);
- break;
- case STATE_TRANSPORT_CONNECT:
- DCHECK_EQ(OK, rv);
- rv = DoTransportConnect();
- break;
- case STATE_TRANSPORT_CONNECT_COMPLETE:
- rv = DoTransportConnectComplete(rv);
- break;
- default:
- NOTREACHED();
- rv = ERR_FAILED;
- break;
- }
- } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
-
- return rv;
-}
-
int TransportConnectJob::DoResolveHost() {
- next_state_ = STATE_RESOLVE_HOST_COMPLETE;
- connect_timing_.dns_start = base::TimeTicks::Now();
-
- return resolver_.Resolve(
- params_->destination(),
- priority(),
- &addresses_,
- base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this)),
- net_log());
+ return helper_.DoResolveHost(priority(), net_log());
}
int TransportConnectJob::DoResolveHostComplete(int result) {
- connect_timing_.dns_end = base::TimeTicks::Now();
- // Overwrite connection start time, since for connections that do not go
- // through proxies, |connect_start| should not include dns lookup time.
- connect_timing_.connect_start = connect_timing_.dns_end;
-
- if (result == OK) {
- // Invoke callback, and abort if it fails.
- if (!params_->host_resolution_callback().is_null())
- result = params_->host_resolution_callback().Run(addresses_, net_log());
-
- if (result == OK)
- next_state_ = STATE_TRANSPORT_CONNECT;
- }
- return result;
+ return helper_.DoResolveHostComplete(result, net_log());
}
int TransportConnectJob::DoTransportConnect() {
@@ -216,42 +265,57 @@ int TransportConnectJob::DoTransportConnect() {
interval_between_connects_ = CONNECT_INTERVAL_GT_20MS;
}
- next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
- transport_socket_ = client_socket_factory_->CreateTransportClientSocket(
- addresses_, net_log().net_log(), net_log().source());
- int rv = transport_socket_->Connect(
- base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this)));
- if (rv == ERR_IO_PENDING &&
- addresses_.front().GetFamily() == ADDRESS_FAMILY_IPV6 &&
- !AddressListOnlyContainsIPv6(addresses_)) {
- fallback_timer_.Start(FROM_HERE,
- base::TimeDelta::FromMilliseconds(kIPv6FallbackTimerInMs),
- this, &TransportConnectJob::DoIPv6FallbackTransportConnect);
+ helper_.set_next_state(
+ TransportConnectJobHelper::STATE_TRANSPORT_CONNECT_COMPLETE);
+ transport_socket_ =
+ helper_.client_socket_factory()->CreateTransportClientSocket(
+ helper_.addresses(), net_log().net_log(), net_log().source());
+
+ // If the list contains IPv6 and IPv4 addresses, the first address will
+ // be IPv6, and the IPv4 addresses will be tried as fallback addresses,
+ // per "Happy Eyeballs" (RFC 6555).
+ bool try_ipv6_connect_with_ipv4_fallback =
+ helper_.addresses().front().GetFamily() == ADDRESS_FAMILY_IPV6 &&
+ !AddressListOnlyContainsIPv6(helper_.addresses());
+
+ // Enable TCP FastOpen if indicated by transport socket params.
+ // Note: We currently do not turn on TCP FastOpen for destinations where
+ // we try a TCP connect over IPv6 with fallback to IPv4.
+ if (!try_ipv6_connect_with_ipv4_fallback &&
+ helper_.params()->combine_connect_and_write() ==
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DESIRED) {
+ transport_socket_->EnableTCPFastOpenIfSupported();
+ }
+
+ int rv = transport_socket_->Connect(helper_.on_io_complete());
+ if (rv == ERR_IO_PENDING && try_ipv6_connect_with_ipv4_fallback) {
+ fallback_timer_.Start(
+ FROM_HERE,
+ base::TimeDelta::FromMilliseconds(
+ TransportConnectJobHelper::kIPv6FallbackTimerInMs),
+ this,
+ &TransportConnectJob::DoIPv6FallbackTransportConnect);
}
return rv;
}
int TransportConnectJob::DoTransportConnectComplete(int result) {
if (result == OK) {
- bool is_ipv4 = addresses_.front().GetFamily() == ADDRESS_FAMILY_IPV4;
- DCHECK(!connect_timing_.connect_start.is_null());
- DCHECK(!connect_timing_.dns_start.is_null());
- base::TimeTicks now = base::TimeTicks::Now();
- base::TimeDelta total_duration = now - connect_timing_.dns_start;
- UMA_HISTOGRAM_CUSTOM_TIMES(
- "Net.DNS_Resolution_And_TCP_Connection_Latency2",
- total_duration,
- base::TimeDelta::FromMilliseconds(1),
- base::TimeDelta::FromMinutes(10),
- 100);
-
- base::TimeDelta connect_duration = now - connect_timing_.connect_start;
- UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency",
- connect_duration,
- base::TimeDelta::FromMilliseconds(1),
- base::TimeDelta::FromMinutes(10),
- 100);
-
+ bool is_ipv4 =
+ helper_.addresses().front().GetFamily() == ADDRESS_FAMILY_IPV4;
+ TransportConnectJobHelper::ConnectionLatencyHistogram race_result =
+ TransportConnectJobHelper::CONNECTION_LATENCY_UNKNOWN;
+ if (is_ipv4) {
+ race_result = TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_NO_RACE;
+ } else {
+ if (AddressListOnlyContainsIPv6(helper_.addresses())) {
+ race_result = TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_SOLO;
+ } else {
+ race_result =
+ TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_RACEABLE;
+ }
+ }
+ base::TimeDelta connect_duration = helper_.HistogramDuration(race_result);
switch (interval_between_connects_) {
case CONNECT_INTERVAL_LE_10MS:
UMA_HISTOGRAM_CUSTOM_TIMES(
@@ -282,27 +346,6 @@ int TransportConnectJob::DoTransportConnectComplete(int result) {
break;
}
- if (is_ipv4) {
- UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_No_Race",
- connect_duration,
- base::TimeDelta::FromMilliseconds(1),
- base::TimeDelta::FromMinutes(10),
- 100);
- } else {
- if (AddressListOnlyContainsIPv6(addresses_)) {
- UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Solo",
- connect_duration,
- base::TimeDelta::FromMilliseconds(1),
- base::TimeDelta::FromMinutes(10),
- 100);
- } else {
- UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Raceable",
- connect_duration,
- base::TimeDelta::FromMilliseconds(1),
- base::TimeDelta::FromMinutes(10),
- 100);
- }
- }
SetSocket(transport_socket_.Pass());
fallback_timer_.Stop();
} else {
@@ -317,7 +360,8 @@ int TransportConnectJob::DoTransportConnectComplete(int result) {
void TransportConnectJob::DoIPv6FallbackTransportConnect() {
// The timer should only fire while we're waiting for the main connect to
// succeed.
- if (next_state_ != STATE_TRANSPORT_CONNECT_COMPLETE) {
+ if (helper_.next_state() !=
+ TransportConnectJobHelper::STATE_TRANSPORT_CONNECT_COMPLETE) {
NOTREACHED();
return;
}
@@ -325,10 +369,10 @@ void TransportConnectJob::DoIPv6FallbackTransportConnect() {
DCHECK(!fallback_transport_socket_.get());
DCHECK(!fallback_addresses_.get());
- fallback_addresses_.reset(new AddressList(addresses_));
+ fallback_addresses_.reset(new AddressList(helper_.addresses()));
MakeAddressListStartWithIPv4(fallback_addresses_.get());
fallback_transport_socket_ =
- client_socket_factory_->CreateTransportClientSocket(
+ helper_.client_socket_factory()->CreateTransportClientSocket(
*fallback_addresses_, net_log().net_log(), net_log().source());
fallback_connect_start_time_ = base::TimeTicks::Now();
int rv = fallback_transport_socket_->Connect(
@@ -341,7 +385,8 @@ void TransportConnectJob::DoIPv6FallbackTransportConnect() {
void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) {
// This should only happen when we're waiting for the main connect to succeed.
- if (next_state_ != STATE_TRANSPORT_CONNECT_COMPLETE) {
+ if (helper_.next_state() !=
+ TransportConnectJobHelper::STATE_TRANSPORT_CONNECT_COMPLETE) {
NOTREACHED();
return;
}
@@ -352,30 +397,11 @@ void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) {
if (result == OK) {
DCHECK(!fallback_connect_start_time_.is_null());
- DCHECK(!connect_timing_.dns_start.is_null());
- base::TimeTicks now = base::TimeTicks::Now();
- base::TimeDelta total_duration = now - connect_timing_.dns_start;
- UMA_HISTOGRAM_CUSTOM_TIMES(
- "Net.DNS_Resolution_And_TCP_Connection_Latency2",
- total_duration,
- base::TimeDelta::FromMilliseconds(1),
- base::TimeDelta::FromMinutes(10),
- 100);
-
- base::TimeDelta connect_duration = now - fallback_connect_start_time_;
- UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency",
- connect_duration,
- base::TimeDelta::FromMilliseconds(1),
- base::TimeDelta::FromMinutes(10),
- 100);
-
- UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_Wins_Race",
- connect_duration,
- base::TimeDelta::FromMilliseconds(1),
- base::TimeDelta::FromMinutes(10),
- 100);
+ connect_timing_.connect_start = fallback_connect_start_time_;
+ helper_.HistogramDuration(
+ TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_WINS_RACE);
SetSocket(fallback_transport_socket_.Pass());
- next_state_ = STATE_NONE;
+ helper_.set_next_state(TransportConnectJobHelper::STATE_NONE);
transport_socket_.reset();
} else {
// Be a bit paranoid and kill off the fallback members to prevent reuse.
@@ -386,12 +412,11 @@ void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) {
}
int TransportConnectJob::ConnectInternal() {
- next_state_ = STATE_RESOLVE_HOST;
- return DoLoop(OK);
+ return helper_.DoConnectInternal(this);
}
scoped_ptr<ConnectJob>
- TransportClientSocketPool::TransportConnectJobFactory::NewConnectJob(
+TransportClientSocketPool::TransportConnectJobFactory::NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
ConnectJob::Delegate* delegate) const {
@@ -439,6 +464,15 @@ int TransportClientSocketPool::RequestSocket(
const scoped_refptr<TransportSocketParams>* casted_params =
static_cast<const scoped_refptr<TransportSocketParams>*>(params);
+ NetLogTcpClientSocketPoolRequestedSocket(net_log, casted_params);
+
+ return base_.RequestSocket(group_name, *casted_params, priority, handle,
+ callback, net_log);
+}
+
+void TransportClientSocketPool::NetLogTcpClientSocketPoolRequestedSocket(
+ const BoundNetLog& net_log,
+ const scoped_refptr<TransportSocketParams>* casted_params) {
if (net_log.IsLogging()) {
// TODO(eroman): Split out the host and port parameters.
net_log.AddEvent(
@@ -446,9 +480,6 @@ int TransportClientSocketPool::RequestSocket(
CreateNetLogHostPortPairCallback(
&casted_params->get()->destination().host_port_pair()));
}
-
- return base_.RequestSocket(group_name, *casted_params, priority, handle,
- callback, net_log);
}
void TransportClientSocketPool::RequestSockets(
diff --git a/chromium/net/socket/transport_client_socket_pool.h b/chromium/net/socket/transport_client_socket_pool.h
index 1c22bf29ec3..15cef5c02ac 100644
--- a/chromium/net/socket/transport_client_socket_pool.h
+++ b/chromium/net/socket/transport_client_socket_pool.h
@@ -29,14 +29,30 @@ OnHostResolutionCallback;
class NET_EXPORT_PRIVATE TransportSocketParams
: public base::RefCounted<TransportSocketParams> {
public:
+ // CombineConnectAndWrite currently translates to using TCP FastOpen.
+ // TCP FastOpen should not be used if the first write to the socket may
+ // be non-idempotent, as the underlying socket could retransmit the data
+ // on failure of the first transmission.
+ // NOTE: Currently, COMBINE_CONNECT_AND_WRITE_DESIRED is used if the data in
+ // the write is known to be idempotent, and COMBINE_CONNECT_AND_WRITE_DEFAULT
+ // is used as a default for other cases (including non-idempotent writes).
+ enum CombineConnectAndWritePolicy {
+ COMBINE_CONNECT_AND_WRITE_DEFAULT, // Default policy, implemented in
+ // TransportSocketParams constructor.
+ COMBINE_CONNECT_AND_WRITE_DESIRED, // Combine if supported by socket.
+ COMBINE_CONNECT_AND_WRITE_PROHIBITED // Do not combine.
+ };
+
// |host_resolution_callback| will be invoked after the the hostname is
// resolved. If |host_resolution_callback| does not return OK, then the
- // connection will be aborted with that value.
+ // connection will be aborted with that value. |combine_connect_and_write|
+ // defines the policy for use of TCP FastOpen on this socket.
TransportSocketParams(
const HostPortPair& host_port_pair,
bool disable_resolver_cache,
bool ignore_limits,
- const OnHostResolutionCallback& host_resolution_callback);
+ const OnHostResolutionCallback& host_resolution_callback,
+ CombineConnectAndWritePolicy combine_connect_and_write);
const HostResolver::RequestInfo& destination() const { return destination_; }
bool ignore_limits() const { return ignore_limits_; }
@@ -44,6 +60,10 @@ class NET_EXPORT_PRIVATE TransportSocketParams
return host_resolution_callback_;
}
+ CombineConnectAndWritePolicy combine_connect_and_write() const {
+ return combine_connect_and_write_;
+ }
+
private:
friend class base::RefCounted<TransportSocketParams>;
~TransportSocketParams();
@@ -51,10 +71,81 @@ class NET_EXPORT_PRIVATE TransportSocketParams
HostResolver::RequestInfo destination_;
bool ignore_limits_;
const OnHostResolutionCallback host_resolution_callback_;
+ CombineConnectAndWritePolicy combine_connect_and_write_;
DISALLOW_COPY_AND_ASSIGN(TransportSocketParams);
};
+// Common data and logic shared between TransportConnectJob and
+// WebSocketTransportConnectJob.
+class NET_EXPORT_PRIVATE TransportConnectJobHelper {
+ public:
+ enum State {
+ STATE_RESOLVE_HOST,
+ STATE_RESOLVE_HOST_COMPLETE,
+ STATE_TRANSPORT_CONNECT,
+ STATE_TRANSPORT_CONNECT_COMPLETE,
+ STATE_NONE,
+ };
+
+ // For recording the connection time in the appropriate bucket.
+ enum ConnectionLatencyHistogram {
+ CONNECTION_LATENCY_UNKNOWN,
+ CONNECTION_LATENCY_IPV4_WINS_RACE,
+ CONNECTION_LATENCY_IPV4_NO_RACE,
+ CONNECTION_LATENCY_IPV6_RACEABLE,
+ CONNECTION_LATENCY_IPV6_SOLO,
+ };
+
+ TransportConnectJobHelper(const scoped_refptr<TransportSocketParams>& params,
+ ClientSocketFactory* client_socket_factory,
+ HostResolver* host_resolver,
+ LoadTimingInfo::ConnectTiming* connect_timing);
+ ~TransportConnectJobHelper();
+
+ ClientSocketFactory* client_socket_factory() {
+ return client_socket_factory_;
+ }
+
+ const AddressList& addresses() const { return addresses_; }
+ State next_state() const { return next_state_; }
+ void set_next_state(State next_state) { next_state_ = next_state; }
+ CompletionCallback on_io_complete() const { return on_io_complete_; }
+ const TransportSocketParams* params() { return params_.get(); }
+
+ int DoResolveHost(RequestPriority priority, const BoundNetLog& net_log);
+ int DoResolveHostComplete(int result, const BoundNetLog& net_log);
+
+ template <class T>
+ int DoConnectInternal(T* job);
+
+ template <class T>
+ void SetOnIOComplete(T* job);
+
+ template <class T>
+ void OnIOComplete(T* job, int result);
+
+ // Record the histograms Net.DNS_Resolution_And_TCP_Connection_Latency2 and
+ // Net.TCP_Connection_Latency and return the connect duration.
+ base::TimeDelta HistogramDuration(ConnectionLatencyHistogram race_result);
+
+ static const int kIPv6FallbackTimerInMs;
+
+ private:
+ template <class T>
+ int DoLoop(T* job, int result);
+
+ scoped_refptr<TransportSocketParams> params_;
+ ClientSocketFactory* const client_socket_factory_;
+ SingleRequestHostResolver resolver_;
+ AddressList addresses_;
+ State next_state_;
+ CompletionCallback on_io_complete_;
+ LoadTimingInfo::ConnectTiming* connect_timing_;
+
+ DISALLOW_COPY_AND_ASSIGN(TransportConnectJobHelper);
+};
+
// TransportConnectJob handles the host resolution necessary for socket creation
// and the transport (likely TCP) connect. TransportConnectJob also has fallback
// logic for IPv6 connect() timeouts (which may happen due to networks / routers
@@ -73,36 +164,23 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob {
HostResolver* host_resolver,
Delegate* delegate,
NetLog* net_log);
- virtual ~TransportConnectJob();
+ ~TransportConnectJob() override;
// ConnectJob methods.
- virtual LoadState GetLoadState() const OVERRIDE;
+ LoadState GetLoadState() const override;
// Rolls |addrlist| forward until the first IPv4 address, if any.
// WARNING: this method should only be used to implement the prefer-IPv4 hack.
static void MakeAddressListStartWithIPv4(AddressList* addrlist);
- static const int kIPv6FallbackTimerInMs;
-
private:
- enum State {
- STATE_RESOLVE_HOST,
- STATE_RESOLVE_HOST_COMPLETE,
- STATE_TRANSPORT_CONNECT,
- STATE_TRANSPORT_CONNECT_COMPLETE,
- STATE_NONE,
- };
-
enum ConnectInterval {
CONNECT_INTERVAL_LE_10MS,
CONNECT_INTERVAL_LE_20MS,
CONNECT_INTERVAL_GT_20MS,
};
- void OnIOComplete(int result);
-
- // Runs the state transition loop.
- int DoLoop(int result);
+ friend class TransportConnectJobHelper;
int DoResolveHost();
int DoResolveHostComplete(int result);
@@ -116,13 +194,9 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob {
// Begins the host resolution and the TCP connect. Returns OK on success
// and ERR_IO_PENDING if it cannot immediately service the request.
// Otherwise, it returns a net error code.
- virtual int ConnectInternal() OVERRIDE;
+ int ConnectInternal() override;
- scoped_refptr<TransportSocketParams> params_;
- ClientSocketFactory* const client_socket_factory_;
- SingleRequestHostResolver resolver_;
- AddressList addresses_;
- State next_state_;
+ TransportConnectJobHelper helper_;
scoped_ptr<StreamSocket> transport_socket_;
@@ -149,43 +223,47 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool {
ClientSocketFactory* client_socket_factory,
NetLog* net_log);
- virtual ~TransportClientSocketPool();
+ ~TransportClientSocketPool() override;
// ClientSocketPool implementation.
- virtual int RequestSocket(const std::string& group_name,
- const void* resolve_info,
- RequestPriority priority,
- ClientSocketHandle* handle,
- const CompletionCallback& callback,
- const BoundNetLog& net_log) OVERRIDE;
- virtual void RequestSockets(const std::string& group_name,
- const void* params,
- int num_sockets,
- const BoundNetLog& net_log) OVERRIDE;
- virtual void CancelRequest(const std::string& group_name,
- ClientSocketHandle* handle) OVERRIDE;
- virtual void ReleaseSocket(const std::string& group_name,
- scoped_ptr<StreamSocket> socket,
- int id) OVERRIDE;
- virtual void FlushWithError(int error) OVERRIDE;
- virtual void CloseIdleSockets() OVERRIDE;
- virtual int IdleSocketCount() const OVERRIDE;
- virtual int IdleSocketCountInGroup(
- const std::string& group_name) const OVERRIDE;
- virtual LoadState GetLoadState(
- const std::string& group_name,
- const ClientSocketHandle* handle) const OVERRIDE;
- virtual base::DictionaryValue* GetInfoAsValue(
+ int RequestSocket(const std::string& group_name,
+ const void* resolve_info,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) override;
+ void RequestSockets(const std::string& group_name,
+ const void* params,
+ int num_sockets,
+ const BoundNetLog& net_log) override;
+ void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) override;
+ void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) override;
+ void FlushWithError(int error) override;
+ void CloseIdleSockets() override;
+ int IdleSocketCount() const override;
+ int IdleSocketCountInGroup(const std::string& group_name) const override;
+ LoadState GetLoadState(const std::string& group_name,
+ const ClientSocketHandle* handle) const override;
+ base::DictionaryValue* GetInfoAsValue(
const std::string& name,
const std::string& type,
- bool include_nested_pools) const OVERRIDE;
- virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
- virtual ClientSocketPoolHistograms* histograms() const OVERRIDE;
+ bool include_nested_pools) const override;
+ base::TimeDelta ConnectionTimeout() const override;
+ ClientSocketPoolHistograms* histograms() const override;
// HigherLayeredPool implementation.
- virtual bool IsStalled() const OVERRIDE;
- virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
- virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
+ bool IsStalled() const override;
+ void AddHigherLayeredPool(HigherLayeredPool* higher_pool) override;
+ void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) override;
+
+ protected:
+ // Methods shared with WebSocketTransportClientSocketPool
+ void NetLogTcpClientSocketPoolRequestedSocket(
+ const BoundNetLog& net_log,
+ const scoped_refptr<TransportSocketParams>* casted_params);
private:
typedef ClientSocketPoolBase<TransportSocketParams> PoolBase;
@@ -200,16 +278,16 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool {
host_resolver_(host_resolver),
net_log_(net_log) {}
- virtual ~TransportConnectJobFactory() {}
+ ~TransportConnectJobFactory() override {}
// ClientSocketPoolBase::ConnectJobFactory methods.
- virtual scoped_ptr<ConnectJob> NewConnectJob(
+ scoped_ptr<ConnectJob> NewConnectJob(
const std::string& group_name,
const PoolBase::Request& request,
- ConnectJob::Delegate* delegate) const OVERRIDE;
+ ConnectJob::Delegate* delegate) const override;
- virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
+ base::TimeDelta ConnectionTimeout() const override;
private:
ClientSocketFactory* const client_socket_factory_;
@@ -224,6 +302,61 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool {
DISALLOW_COPY_AND_ASSIGN(TransportClientSocketPool);
};
+template <class T>
+int TransportConnectJobHelper::DoConnectInternal(T* job) {
+ next_state_ = STATE_RESOLVE_HOST;
+ return this->DoLoop(job, OK);
+}
+
+template <class T>
+void TransportConnectJobHelper::SetOnIOComplete(T* job) {
+ // These usages of base::Unretained() are safe because IO callbacks are
+ // guaranteed not to be called after the object is destroyed.
+ on_io_complete_ = base::Bind(&TransportConnectJobHelper::OnIOComplete<T>,
+ base::Unretained(this),
+ base::Unretained(job));
+}
+
+template <class T>
+void TransportConnectJobHelper::OnIOComplete(T* job, int result) {
+ result = this->DoLoop(job, result);
+ if (result != ERR_IO_PENDING)
+ job->NotifyDelegateOfCompletion(result); // Deletes |job| and |this|
+}
+
+template <class T>
+int TransportConnectJobHelper::DoLoop(T* job, int result) {
+ DCHECK_NE(next_state_, STATE_NONE);
+
+ int rv = result;
+ do {
+ State state = next_state_;
+ next_state_ = STATE_NONE;
+ switch (state) {
+ case STATE_RESOLVE_HOST:
+ DCHECK_EQ(OK, rv);
+ rv = job->DoResolveHost();
+ break;
+ case STATE_RESOLVE_HOST_COMPLETE:
+ rv = job->DoResolveHostComplete(rv);
+ break;
+ case STATE_TRANSPORT_CONNECT:
+ DCHECK_EQ(OK, rv);
+ rv = job->DoTransportConnect();
+ break;
+ case STATE_TRANSPORT_CONNECT_COMPLETE:
+ rv = job->DoTransportConnectComplete(rv);
+ break;
+ default:
+ NOTREACHED();
+ rv = ERR_FAILED;
+ break;
+ }
+ } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
+
+ return rv;
+}
+
} // namespace net
#endif // NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
diff --git a/chromium/net/socket/transport_client_socket_pool_test_util.cc b/chromium/net/socket/transport_client_socket_pool_test_util.cc
new file mode 100644
index 00000000000..82ed8e6a78e
--- /dev/null
+++ b/chromium/net/socket/transport_client_socket_pool_test_util.cc
@@ -0,0 +1,424 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/transport_client_socket_pool_test_util.h"
+
+#include <string>
+
+#include "base/logging.h"
+#include "base/memory/weak_ptr.h"
+#include "base/run_loop.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/load_timing_info.h"
+#include "net/base/load_timing_info_test_util.h"
+#include "net/base/net_util.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/ssl_client_socket.h"
+#include "net/udp/datagram_client_socket.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+IPAddressNumber ParseIP(const std::string& ip) {
+ IPAddressNumber number;
+ CHECK(ParseIPLiteralToNumber(ip, &number));
+ return number;
+}
+
+// A StreamSocket which connects synchronously and successfully.
+class MockConnectClientSocket : public StreamSocket {
+ public:
+ MockConnectClientSocket(const AddressList& addrlist, net::NetLog* net_log)
+ : connected_(false),
+ addrlist_(addrlist),
+ net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)),
+ use_tcp_fastopen_(false) {}
+
+ // StreamSocket implementation.
+ int Connect(const CompletionCallback& callback) override {
+ connected_ = true;
+ return OK;
+ }
+ void Disconnect() override { connected_ = false; }
+ bool IsConnected() const override { return connected_; }
+ bool IsConnectedAndIdle() const override { return connected_; }
+
+ int GetPeerAddress(IPEndPoint* address) const override {
+ *address = addrlist_.front();
+ return OK;
+ }
+ int GetLocalAddress(IPEndPoint* address) const override {
+ if (!connected_)
+ return ERR_SOCKET_NOT_CONNECTED;
+ if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
+ SetIPv4Address(address);
+ else
+ SetIPv6Address(address);
+ return OK;
+ }
+ const BoundNetLog& NetLog() const override { return net_log_; }
+
+ void SetSubresourceSpeculation() override {}
+ void SetOmniboxSpeculation() override {}
+ bool WasEverUsed() const override { return false; }
+ void EnableTCPFastOpenIfSupported() override { use_tcp_fastopen_ = true; }
+ bool UsingTCPFastOpen() const override { return use_tcp_fastopen_; }
+ bool WasNpnNegotiated() const override { return false; }
+ NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
+ bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
+
+ // Socket implementation.
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override {
+ return ERR_FAILED;
+ }
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override {
+ return ERR_FAILED;
+ }
+ int SetReceiveBufferSize(int32 size) override { return OK; }
+ int SetSendBufferSize(int32 size) override { return OK; }
+
+ private:
+ bool connected_;
+ const AddressList addrlist_;
+ BoundNetLog net_log_;
+ bool use_tcp_fastopen_;
+
+ DISALLOW_COPY_AND_ASSIGN(MockConnectClientSocket);
+};
+
+class MockFailingClientSocket : public StreamSocket {
+ public:
+ MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log)
+ : addrlist_(addrlist),
+ net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)),
+ use_tcp_fastopen_(false) {}
+
+ // StreamSocket implementation.
+ int Connect(const CompletionCallback& callback) override {
+ return ERR_CONNECTION_FAILED;
+ }
+
+ void Disconnect() override {}
+
+ bool IsConnected() const override { return false; }
+ bool IsConnectedAndIdle() const override { return false; }
+ int GetPeerAddress(IPEndPoint* address) const override {
+ return ERR_UNEXPECTED;
+ }
+ int GetLocalAddress(IPEndPoint* address) const override {
+ return ERR_UNEXPECTED;
+ }
+ const BoundNetLog& NetLog() const override { return net_log_; }
+
+ void SetSubresourceSpeculation() override {}
+ void SetOmniboxSpeculation() override {}
+ bool WasEverUsed() const override { return false; }
+ void EnableTCPFastOpenIfSupported() override { use_tcp_fastopen_ = true; }
+ bool UsingTCPFastOpen() const override { return use_tcp_fastopen_; }
+ bool WasNpnNegotiated() const override { return false; }
+ NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
+ bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
+
+ // Socket implementation.
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override {
+ return ERR_FAILED;
+ }
+
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override {
+ return ERR_FAILED;
+ }
+ int SetReceiveBufferSize(int32 size) override { return OK; }
+ int SetSendBufferSize(int32 size) override { return OK; }
+
+ private:
+ const AddressList addrlist_;
+ BoundNetLog net_log_;
+ bool use_tcp_fastopen_;
+
+ DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket);
+};
+
+class MockTriggerableClientSocket : public StreamSocket {
+ public:
+ // |should_connect| indicates whether the socket should successfully complete
+ // or fail.
+ MockTriggerableClientSocket(const AddressList& addrlist,
+ bool should_connect,
+ net::NetLog* net_log)
+ : should_connect_(should_connect),
+ is_connected_(false),
+ addrlist_(addrlist),
+ net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)),
+ use_tcp_fastopen_(false),
+ weak_factory_(this) {}
+
+ // Call this method to get a closure which will trigger the connect callback
+ // when called. The closure can be called even after the socket is deleted; it
+ // will safely do nothing.
+ base::Closure GetConnectCallback() {
+ return base::Bind(&MockTriggerableClientSocket::DoCallback,
+ weak_factory_.GetWeakPtr());
+ }
+
+ static scoped_ptr<StreamSocket> MakeMockPendingClientSocket(
+ const AddressList& addrlist,
+ bool should_connect,
+ net::NetLog* net_log) {
+ scoped_ptr<MockTriggerableClientSocket> socket(
+ new MockTriggerableClientSocket(addrlist, should_connect, net_log));
+ base::MessageLoop::current()->PostTask(FROM_HERE,
+ socket->GetConnectCallback());
+ return socket.Pass();
+ }
+
+ static scoped_ptr<StreamSocket> MakeMockDelayedClientSocket(
+ const AddressList& addrlist,
+ bool should_connect,
+ const base::TimeDelta& delay,
+ net::NetLog* net_log) {
+ scoped_ptr<MockTriggerableClientSocket> socket(
+ new MockTriggerableClientSocket(addrlist, should_connect, net_log));
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE, socket->GetConnectCallback(), delay);
+ return socket.Pass();
+ }
+
+ static scoped_ptr<StreamSocket> MakeMockStalledClientSocket(
+ const AddressList& addrlist,
+ net::NetLog* net_log) {
+ scoped_ptr<MockTriggerableClientSocket> socket(
+ new MockTriggerableClientSocket(addrlist, true, net_log));
+ return socket.Pass();
+ }
+
+ // StreamSocket implementation.
+ int Connect(const CompletionCallback& callback) override {
+ DCHECK(callback_.is_null());
+ callback_ = callback;
+ return ERR_IO_PENDING;
+ }
+
+ void Disconnect() override {}
+
+ bool IsConnected() const override { return is_connected_; }
+ bool IsConnectedAndIdle() const override { return is_connected_; }
+ int GetPeerAddress(IPEndPoint* address) const override {
+ *address = addrlist_.front();
+ return OK;
+ }
+ int GetLocalAddress(IPEndPoint* address) const override {
+ if (!is_connected_)
+ return ERR_SOCKET_NOT_CONNECTED;
+ if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
+ SetIPv4Address(address);
+ else
+ SetIPv6Address(address);
+ return OK;
+ }
+ const BoundNetLog& NetLog() const override { return net_log_; }
+
+ void SetSubresourceSpeculation() override {}
+ void SetOmniboxSpeculation() override {}
+ bool WasEverUsed() const override { return false; }
+ void EnableTCPFastOpenIfSupported() override { use_tcp_fastopen_ = true; }
+ bool UsingTCPFastOpen() const override { return use_tcp_fastopen_; }
+ bool WasNpnNegotiated() const override { return false; }
+ NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
+ bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
+
+ // Socket implementation.
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override {
+ return ERR_FAILED;
+ }
+
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override {
+ return ERR_FAILED;
+ }
+ int SetReceiveBufferSize(int32 size) override { return OK; }
+ int SetSendBufferSize(int32 size) override { return OK; }
+
+ private:
+ void DoCallback() {
+ is_connected_ = should_connect_;
+ callback_.Run(is_connected_ ? OK : ERR_CONNECTION_FAILED);
+ }
+
+ bool should_connect_;
+ bool is_connected_;
+ const AddressList addrlist_;
+ BoundNetLog net_log_;
+ CompletionCallback callback_;
+ bool use_tcp_fastopen_;
+
+ base::WeakPtrFactory<MockTriggerableClientSocket> weak_factory_;
+
+ DISALLOW_COPY_AND_ASSIGN(MockTriggerableClientSocket);
+};
+
+} // namespace
+
+void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) {
+ LoadTimingInfo load_timing_info;
+ // Only pass true in as |is_reused|, as in general, HttpStream types should
+ // have stricter concepts of reuse than socket pools.
+ EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info));
+
+ EXPECT_TRUE(load_timing_info.socket_reused);
+ EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
+
+ ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing);
+ ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
+}
+
+void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) {
+ EXPECT_FALSE(handle.is_reused());
+
+ LoadTimingInfo load_timing_info;
+ EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));
+
+ EXPECT_FALSE(load_timing_info.socket_reused);
+ EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
+
+ ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
+ CONNECT_TIMING_HAS_DNS_TIMES);
+ ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
+
+ TestLoadTimingInfoConnectedReused(handle);
+}
+
+void SetIPv4Address(IPEndPoint* address) {
+ *address = IPEndPoint(ParseIP("1.1.1.1"), 80);
+}
+
+void SetIPv6Address(IPEndPoint* address) {
+ *address = IPEndPoint(ParseIP("1:abcd::3:4:ff"), 80);
+}
+
+MockTransportClientSocketFactory::MockTransportClientSocketFactory(
+ NetLog* net_log)
+ : net_log_(net_log),
+ allocation_count_(0),
+ client_socket_type_(MOCK_CLIENT_SOCKET),
+ client_socket_types_(NULL),
+ client_socket_index_(0),
+ client_socket_index_max_(0),
+ delay_(base::TimeDelta::FromMilliseconds(
+ ClientSocketPool::kMaxConnectRetryIntervalMs)) {}
+
+MockTransportClientSocketFactory::~MockTransportClientSocketFactory() {}
+
+scoped_ptr<DatagramClientSocket>
+MockTransportClientSocketFactory::CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ NetLog* net_log,
+ const NetLog::Source& source) {
+ NOTREACHED();
+ return scoped_ptr<DatagramClientSocket>();
+}
+
+scoped_ptr<StreamSocket>
+MockTransportClientSocketFactory::CreateTransportClientSocket(
+ const AddressList& addresses,
+ NetLog* /* net_log */,
+ const NetLog::Source& /* source */) {
+ allocation_count_++;
+
+ ClientSocketType type = client_socket_type_;
+ if (client_socket_types_ && client_socket_index_ < client_socket_index_max_) {
+ type = client_socket_types_[client_socket_index_++];
+ }
+
+ switch (type) {
+ case MOCK_CLIENT_SOCKET:
+ return scoped_ptr<StreamSocket>(
+ new MockConnectClientSocket(addresses, net_log_));
+ case MOCK_FAILING_CLIENT_SOCKET:
+ return scoped_ptr<StreamSocket>(
+ new MockFailingClientSocket(addresses, net_log_));
+ case MOCK_PENDING_CLIENT_SOCKET:
+ return MockTriggerableClientSocket::MakeMockPendingClientSocket(
+ addresses, true, net_log_);
+ case MOCK_PENDING_FAILING_CLIENT_SOCKET:
+ return MockTriggerableClientSocket::MakeMockPendingClientSocket(
+ addresses, false, net_log_);
+ case MOCK_DELAYED_CLIENT_SOCKET:
+ return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
+ addresses, true, delay_, net_log_);
+ case MOCK_DELAYED_FAILING_CLIENT_SOCKET:
+ return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
+ addresses, false, delay_, net_log_);
+ case MOCK_STALLED_CLIENT_SOCKET:
+ return MockTriggerableClientSocket::MakeMockStalledClientSocket(addresses,
+ net_log_);
+ case MOCK_TRIGGERABLE_CLIENT_SOCKET: {
+ scoped_ptr<MockTriggerableClientSocket> rv(
+ new MockTriggerableClientSocket(addresses, true, net_log_));
+ triggerable_sockets_.push(rv->GetConnectCallback());
+ // run_loop_quit_closure_ behaves like a condition variable. It will
+ // wake up WaitForTriggerableSocketCreation() if it is sleeping. We
+ // don't need to worry about atomicity because this code is
+ // single-threaded.
+ if (!run_loop_quit_closure_.is_null())
+ run_loop_quit_closure_.Run();
+ return rv.Pass();
+ }
+ default:
+ NOTREACHED();
+ return scoped_ptr<StreamSocket>(
+ new MockConnectClientSocket(addresses, net_log_));
+ }
+}
+
+scoped_ptr<SSLClientSocket>
+MockTransportClientSocketFactory::CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ const SSLClientSocketContext& context) {
+ NOTIMPLEMENTED();
+ return scoped_ptr<SSLClientSocket>();
+}
+
+void MockTransportClientSocketFactory::ClearSSLSessionCache() {
+ NOTIMPLEMENTED();
+}
+
+void MockTransportClientSocketFactory::set_client_socket_types(
+ ClientSocketType* type_list,
+ int num_types) {
+ DCHECK_GT(num_types, 0);
+ client_socket_types_ = type_list;
+ client_socket_index_ = 0;
+ client_socket_index_max_ = num_types;
+}
+
+base::Closure
+MockTransportClientSocketFactory::WaitForTriggerableSocketCreation() {
+ while (triggerable_sockets_.empty()) {
+ base::RunLoop run_loop;
+ run_loop_quit_closure_ = run_loop.QuitClosure();
+ run_loop.Run();
+ run_loop_quit_closure_.Reset();
+ }
+ base::Closure trigger = triggerable_sockets_.front();
+ triggerable_sockets_.pop();
+ return trigger;
+}
+
+} // namespace net
diff --git a/chromium/net/socket/transport_client_socket_pool_test_util.h b/chromium/net/socket/transport_client_socket_pool_test_util.h
new file mode 100644
index 00000000000..b375353f06f
--- /dev/null
+++ b/chromium/net/socket/transport_client_socket_pool_test_util.h
@@ -0,0 +1,127 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Test methods and classes common to transport_client_socket_pool_unittest.cc
+// and websocket_transport_client_socket_pool_unittest.cc. If you find you need
+// to use these for another purpose, consider moving them to socket_test_util.h.
+
+#ifndef NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_TEST_UTIL_H_
+#define NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_TEST_UTIL_H_
+
+#include <queue>
+
+#include "base/callback.h"
+#include "base/compiler_specific.h"
+#include "base/macros.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/time/time.h"
+#include "net/base/address_list.h"
+#include "net/base/net_log.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/stream_socket.h"
+
+namespace net {
+
+class ClientSocketHandle;
+class IPEndPoint;
+
+// Make sure |handle| sets load times correctly when it has been assigned a
+// reused socket. Uses gtest expectations.
+void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle);
+
+// Make sure |handle| sets load times correctly when it has been assigned a
+// fresh socket. Also runs TestLoadTimingInfoConnectedReused, since the owner
+// of a connection where |is_reused| is false may consider the connection
+// reused. Uses gtest expectations.
+void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle);
+
+// Set |address| to 1.1.1.1:80
+void SetIPv4Address(IPEndPoint* address);
+
+// Set |address| to [1:abcd::3:4:ff]:80
+void SetIPv6Address(IPEndPoint* address);
+
+// A ClientSocketFactory that produces sockets with the specified connection
+// behaviours.
+class MockTransportClientSocketFactory : public ClientSocketFactory {
+ public:
+ enum ClientSocketType {
+ // Connects successfully, synchronously.
+ MOCK_CLIENT_SOCKET,
+ // Fails to connect, synchronously.
+ MOCK_FAILING_CLIENT_SOCKET,
+ // Connects successfully, asynchronously.
+ MOCK_PENDING_CLIENT_SOCKET,
+ // Fails to connect, asynchronously.
+ MOCK_PENDING_FAILING_CLIENT_SOCKET,
+ // A delayed socket will pause before connecting through the message loop.
+ MOCK_DELAYED_CLIENT_SOCKET,
+ // A delayed socket that fails.
+ MOCK_DELAYED_FAILING_CLIENT_SOCKET,
+ // A stalled socket that never connects at all.
+ MOCK_STALLED_CLIENT_SOCKET,
+ // A socket that can be triggered to connect explicitly, asynchronously.
+ MOCK_TRIGGERABLE_CLIENT_SOCKET,
+ };
+
+ explicit MockTransportClientSocketFactory(NetLog* net_log);
+ ~MockTransportClientSocketFactory() override;
+
+ scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
+ DatagramSocket::BindType bind_type,
+ const RandIntCallback& rand_int_cb,
+ NetLog* net_log,
+ const NetLog::Source& source) override;
+
+ scoped_ptr<StreamSocket> CreateTransportClientSocket(
+ const AddressList& addresses,
+ NetLog* /* net_log */,
+ const NetLog::Source& /* source */) override;
+
+ scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
+ scoped_ptr<ClientSocketHandle> transport_socket,
+ const HostPortPair& host_and_port,
+ const SSLConfig& ssl_config,
+ const SSLClientSocketContext& context) override;
+
+ void ClearSSLSessionCache() override;
+
+ int allocation_count() const { return allocation_count_; }
+
+ // Set the default ClientSocketType.
+ void set_default_client_socket_type(ClientSocketType type) {
+ client_socket_type_ = type;
+ }
+
+ // Set a list of ClientSocketTypes to be used.
+ void set_client_socket_types(ClientSocketType* type_list, int num_types);
+
+ void set_delay(base::TimeDelta delay) { delay_ = delay; }
+
+ // If one or more MOCK_TRIGGERABLE_CLIENT_SOCKETs has already been created,
+ // then returns a Closure that can be called to cause the first
+ // not-yet-connected one to connect. If no MOCK_TRIGGERABLE_CLIENT_SOCKETs
+ // have been created yet, wait for one to be created before returning the
+ // Closure. This method should be called the same number of times as
+ // MOCK_TRIGGERABLE_CLIENT_SOCKETs are created in the test.
+ base::Closure WaitForTriggerableSocketCreation();
+
+ private:
+ NetLog* net_log_;
+ int allocation_count_;
+ ClientSocketType client_socket_type_;
+ ClientSocketType* client_socket_types_;
+ int client_socket_index_;
+ int client_socket_index_max_;
+ base::TimeDelta delay_;
+ std::queue<base::Closure> triggerable_sockets_;
+ base::Closure run_loop_quit_closure_;
+
+ DISALLOW_COPY_AND_ASSIGN(MockTransportClientSocketFactory);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_TEST_UTIL_H_
diff --git a/chromium/net/socket/transport_client_socket_pool_unittest.cc b/chromium/net/socket/transport_client_socket_pool_unittest.cc
index 425bb8cc421..c0687ef5a43 100644
--- a/chromium/net/socket/transport_client_socket_pool_unittest.cc
+++ b/chromium/net/socket/transport_client_socket_pool_unittest.cc
@@ -7,8 +7,6 @@
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/callback.h"
-#include "base/compiler_specific.h"
-#include "base/logging.h"
#include "base/message_loop/message_loop.h"
#include "base/threading/platform_thread.h"
#include "net/base/capturing_net_log.h"
@@ -19,12 +17,11 @@
#include "net/base/net_util.h"
#include "net/base/test_completion_callback.h"
#include "net/dns/mock_host_resolver.h"
-#include "net/socket/client_socket_factory.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/client_socket_pool_histograms.h"
#include "net/socket/socket_test_util.h"
-#include "net/socket/ssl_client_socket.h"
#include "net/socket/stream_socket.h"
+#include "net/socket/transport_client_socket_pool_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace net {
@@ -35,411 +32,7 @@ namespace {
const int kMaxSockets = 32;
const int kMaxSocketsPerGroup = 6;
-const net::RequestPriority kDefaultPriority = LOW;
-
-// Make sure |handle| sets load times correctly when it has been assigned a
-// reused socket.
-void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) {
- LoadTimingInfo load_timing_info;
- // Only pass true in as |is_reused|, as in general, HttpStream types should
- // have stricter concepts of reuse than socket pools.
- EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info));
-
- EXPECT_TRUE(load_timing_info.socket_reused);
- EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
-
- ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing);
- ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
-}
-
-// Make sure |handle| sets load times correctly when it has been assigned a
-// fresh socket. Also runs TestLoadTimingInfoConnectedReused, since the owner
-// of a connection where |is_reused| is false may consider the connection
-// reused.
-void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) {
- EXPECT_FALSE(handle.is_reused());
-
- LoadTimingInfo load_timing_info;
- EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));
-
- EXPECT_FALSE(load_timing_info.socket_reused);
- EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
-
- ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
- CONNECT_TIMING_HAS_DNS_TIMES);
- ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
-
- TestLoadTimingInfoConnectedReused(handle);
-}
-
-void SetIPv4Address(IPEndPoint* address) {
- IPAddressNumber number;
- CHECK(ParseIPLiteralToNumber("1.1.1.1", &number));
- *address = IPEndPoint(number, 80);
-}
-
-void SetIPv6Address(IPEndPoint* address) {
- IPAddressNumber number;
- CHECK(ParseIPLiteralToNumber("1:abcd::3:4:ff", &number));
- *address = IPEndPoint(number, 80);
-}
-
-class MockClientSocket : public StreamSocket {
- public:
- MockClientSocket(const AddressList& addrlist, net::NetLog* net_log)
- : connected_(false),
- addrlist_(addrlist),
- net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
- }
-
- // StreamSocket implementation.
- virtual int Connect(const CompletionCallback& callback) OVERRIDE {
- connected_ = true;
- return OK;
- }
- virtual void Disconnect() OVERRIDE {
- connected_ = false;
- }
- virtual bool IsConnected() const OVERRIDE {
- return connected_;
- }
- virtual bool IsConnectedAndIdle() const OVERRIDE {
- return connected_;
- }
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
- return ERR_UNEXPECTED;
- }
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
- if (!connected_)
- return ERR_SOCKET_NOT_CONNECTED;
- if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
- SetIPv4Address(address);
- else
- SetIPv6Address(address);
- return OK;
- }
- virtual const BoundNetLog& NetLog() const OVERRIDE {
- return net_log_;
- }
-
- virtual void SetSubresourceSpeculation() OVERRIDE {}
- virtual void SetOmniboxSpeculation() OVERRIDE {}
- virtual bool WasEverUsed() const OVERRIDE { return false; }
- virtual bool UsingTCPFastOpen() const OVERRIDE { return false; }
- virtual bool WasNpnNegotiated() const OVERRIDE {
- return false;
- }
- virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
- return kProtoUnknown;
- }
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
- return false;
- }
-
- // Socket implementation.
- virtual int Read(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE {
- return ERR_FAILED;
- }
- virtual int Write(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE {
- return ERR_FAILED;
- }
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; }
- virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; }
-
- private:
- bool connected_;
- const AddressList addrlist_;
- BoundNetLog net_log_;
-
- DISALLOW_COPY_AND_ASSIGN(MockClientSocket);
-};
-
-class MockFailingClientSocket : public StreamSocket {
- public:
- MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log)
- : addrlist_(addrlist),
- net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
- }
-
- // StreamSocket implementation.
- virtual int Connect(const CompletionCallback& callback) OVERRIDE {
- return ERR_CONNECTION_FAILED;
- }
-
- virtual void Disconnect() OVERRIDE {}
-
- virtual bool IsConnected() const OVERRIDE {
- return false;
- }
- virtual bool IsConnectedAndIdle() const OVERRIDE {
- return false;
- }
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
- return ERR_UNEXPECTED;
- }
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
- return ERR_UNEXPECTED;
- }
- virtual const BoundNetLog& NetLog() const OVERRIDE {
- return net_log_;
- }
-
- virtual void SetSubresourceSpeculation() OVERRIDE {}
- virtual void SetOmniboxSpeculation() OVERRIDE {}
- virtual bool WasEverUsed() const OVERRIDE { return false; }
- virtual bool UsingTCPFastOpen() const OVERRIDE { return false; }
- virtual bool WasNpnNegotiated() const OVERRIDE {
- return false;
- }
- virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
- return kProtoUnknown;
- }
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
- return false;
- }
-
- // Socket implementation.
- virtual int Read(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE {
- return ERR_FAILED;
- }
-
- virtual int Write(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE {
- return ERR_FAILED;
- }
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; }
- virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; }
-
- private:
- const AddressList addrlist_;
- BoundNetLog net_log_;
-
- DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket);
-};
-
-class MockPendingClientSocket : public StreamSocket {
- public:
- // |should_connect| indicates whether the socket should successfully complete
- // or fail.
- // |should_stall| indicates that this socket should never connect.
- // |delay_ms| is the delay, in milliseconds, before simulating a connect.
- MockPendingClientSocket(
- const AddressList& addrlist,
- bool should_connect,
- bool should_stall,
- base::TimeDelta delay,
- net::NetLog* net_log)
- : should_connect_(should_connect),
- should_stall_(should_stall),
- delay_(delay),
- is_connected_(false),
- addrlist_(addrlist),
- net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)),
- weak_factory_(this) {
- }
-
- // StreamSocket implementation.
- virtual int Connect(const CompletionCallback& callback) OVERRIDE {
- base::MessageLoop::current()->PostDelayedTask(
- FROM_HERE,
- base::Bind(&MockPendingClientSocket::DoCallback,
- weak_factory_.GetWeakPtr(), callback),
- delay_);
- return ERR_IO_PENDING;
- }
-
- virtual void Disconnect() OVERRIDE {}
-
- virtual bool IsConnected() const OVERRIDE {
- return is_connected_;
- }
- virtual bool IsConnectedAndIdle() const OVERRIDE {
- return is_connected_;
- }
- virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
- return ERR_UNEXPECTED;
- }
- virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
- if (!is_connected_)
- return ERR_SOCKET_NOT_CONNECTED;
- if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
- SetIPv4Address(address);
- else
- SetIPv6Address(address);
- return OK;
- }
- virtual const BoundNetLog& NetLog() const OVERRIDE {
- return net_log_;
- }
-
- virtual void SetSubresourceSpeculation() OVERRIDE {}
- virtual void SetOmniboxSpeculation() OVERRIDE {}
- virtual bool WasEverUsed() const OVERRIDE { return false; }
- virtual bool UsingTCPFastOpen() const OVERRIDE { return false; }
- virtual bool WasNpnNegotiated() const OVERRIDE {
- return false;
- }
- virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
- return kProtoUnknown;
- }
- virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
- return false;
- }
-
- // Socket implementation.
- virtual int Read(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE {
- return ERR_FAILED;
- }
-
- virtual int Write(IOBuffer* buf, int buf_len,
- const CompletionCallback& callback) OVERRIDE {
- return ERR_FAILED;
- }
- virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; }
- virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; }
-
- private:
- void DoCallback(const CompletionCallback& callback) {
- if (should_stall_)
- return;
-
- if (should_connect_) {
- is_connected_ = true;
- callback.Run(OK);
- } else {
- is_connected_ = false;
- callback.Run(ERR_CONNECTION_FAILED);
- }
- }
-
- bool should_connect_;
- bool should_stall_;
- base::TimeDelta delay_;
- bool is_connected_;
- const AddressList addrlist_;
- BoundNetLog net_log_;
-
- base::WeakPtrFactory<MockPendingClientSocket> weak_factory_;
-
- DISALLOW_COPY_AND_ASSIGN(MockPendingClientSocket);
-};
-
-class MockClientSocketFactory : public ClientSocketFactory {
- public:
- enum ClientSocketType {
- MOCK_CLIENT_SOCKET,
- MOCK_FAILING_CLIENT_SOCKET,
- MOCK_PENDING_CLIENT_SOCKET,
- MOCK_PENDING_FAILING_CLIENT_SOCKET,
- // A delayed socket will pause before connecting through the message loop.
- MOCK_DELAYED_CLIENT_SOCKET,
- // A stalled socket that never connects at all.
- MOCK_STALLED_CLIENT_SOCKET,
- };
-
- explicit MockClientSocketFactory(NetLog* net_log)
- : net_log_(net_log), allocation_count_(0),
- client_socket_type_(MOCK_CLIENT_SOCKET), client_socket_types_(NULL),
- client_socket_index_(0), client_socket_index_max_(0),
- delay_(base::TimeDelta::FromMilliseconds(
- ClientSocketPool::kMaxConnectRetryIntervalMs)) {}
-
- virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket(
- DatagramSocket::BindType bind_type,
- const RandIntCallback& rand_int_cb,
- NetLog* net_log,
- const NetLog::Source& source) OVERRIDE {
- NOTREACHED();
- return scoped_ptr<DatagramClientSocket>();
- }
-
- virtual scoped_ptr<StreamSocket> CreateTransportClientSocket(
- const AddressList& addresses,
- NetLog* /* net_log */,
- const NetLog::Source& /* source */) OVERRIDE {
- allocation_count_++;
-
- ClientSocketType type = client_socket_type_;
- if (client_socket_types_ &&
- client_socket_index_ < client_socket_index_max_) {
- type = client_socket_types_[client_socket_index_++];
- }
-
- switch (type) {
- case MOCK_CLIENT_SOCKET:
- return scoped_ptr<StreamSocket>(
- new MockClientSocket(addresses, net_log_));
- case MOCK_FAILING_CLIENT_SOCKET:
- return scoped_ptr<StreamSocket>(
- new MockFailingClientSocket(addresses, net_log_));
- case MOCK_PENDING_CLIENT_SOCKET:
- return scoped_ptr<StreamSocket>(
- new MockPendingClientSocket(
- addresses, true, false, base::TimeDelta(), net_log_));
- case MOCK_PENDING_FAILING_CLIENT_SOCKET:
- return scoped_ptr<StreamSocket>(
- new MockPendingClientSocket(
- addresses, false, false, base::TimeDelta(), net_log_));
- case MOCK_DELAYED_CLIENT_SOCKET:
- return scoped_ptr<StreamSocket>(
- new MockPendingClientSocket(
- addresses, true, false, delay_, net_log_));
- case MOCK_STALLED_CLIENT_SOCKET:
- return scoped_ptr<StreamSocket>(
- new MockPendingClientSocket(
- addresses, true, true, base::TimeDelta(), net_log_));
- default:
- NOTREACHED();
- return scoped_ptr<StreamSocket>(
- new MockClientSocket(addresses, net_log_));
- }
- }
-
- virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
- scoped_ptr<ClientSocketHandle> transport_socket,
- const HostPortPair& host_and_port,
- const SSLConfig& ssl_config,
- const SSLClientSocketContext& context) OVERRIDE {
- NOTIMPLEMENTED();
- return scoped_ptr<SSLClientSocket>();
- }
-
- virtual void ClearSSLSessionCache() OVERRIDE {
- NOTIMPLEMENTED();
- }
-
- int allocation_count() const { return allocation_count_; }
-
- // Set the default ClientSocketType.
- void set_client_socket_type(ClientSocketType type) {
- client_socket_type_ = type;
- }
-
- // Set a list of ClientSocketTypes to be used.
- void set_client_socket_types(ClientSocketType* type_list, int num_types) {
- DCHECK_GT(num_types, 0);
- client_socket_types_ = type_list;
- client_socket_index_ = 0;
- client_socket_index_max_ = num_types;
- }
-
- void set_delay(base::TimeDelta delay) { delay_ = delay; }
-
- private:
- NetLog* net_log_;
- int allocation_count_;
- ClientSocketType client_socket_type_;
- ClientSocketType* client_socket_types_;
- int client_socket_index_;
- int client_socket_index_max_;
- base::TimeDelta delay_;
-
- DISALLOW_COPY_AND_ASSIGN(MockClientSocketFactory);
-};
+const RequestPriority kDefaultPriority = LOW;
class TransportClientSocketPoolTest : public testing::Test {
protected:
@@ -447,9 +40,12 @@ class TransportClientSocketPoolTest : public testing::Test {
: connect_backup_jobs_enabled_(
ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(true)),
params_(
- new TransportSocketParams(HostPortPair("www.google.com", 80),
- false, false,
- OnHostResolutionCallback())),
+ new TransportSocketParams(
+ HostPortPair("www.google.com", 80),
+ false,
+ false,
+ OnHostResolutionCallback(),
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)),
histograms_(new ClientSocketPoolHistograms("TCPUnitTest")),
host_resolver_(new MockHostResolver),
client_socket_factory_(&net_log_),
@@ -461,15 +57,22 @@ class TransportClientSocketPoolTest : public testing::Test {
NULL) {
}
- virtual ~TransportClientSocketPoolTest() {
+ ~TransportClientSocketPoolTest() override {
internal::ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(
connect_backup_jobs_enabled_);
}
+ scoped_refptr<TransportSocketParams> CreateParamsForTCPFastOpen() {
+ return new TransportSocketParams(HostPortPair("www.google.com", 80),
+ false, false, OnHostResolutionCallback(),
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DESIRED);
+ }
+
int StartRequest(const std::string& group_name, RequestPriority priority) {
scoped_refptr<TransportSocketParams> params(new TransportSocketParams(
HostPortPair("www.google.com", 80), false, false,
- OnHostResolutionCallback()));
+ OnHostResolutionCallback(),
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT));
return test_base_.StartRequestUsingPool(
&pool_, group_name, priority, params);
}
@@ -494,10 +97,11 @@ class TransportClientSocketPoolTest : public testing::Test {
scoped_refptr<TransportSocketParams> params_;
scoped_ptr<ClientSocketPoolHistograms> histograms_;
scoped_ptr<MockHostResolver> host_resolver_;
- MockClientSocketFactory client_socket_factory_;
+ MockTransportClientSocketFactory client_socket_factory_;
TransportClientSocketPool pool_;
ClientSocketPoolTest test_base_;
+ private:
DISALLOW_COPY_AND_ASSIGN(TransportClientSocketPoolTest);
};
@@ -607,8 +211,8 @@ TEST_F(TransportClientSocketPoolTest, InitHostResolutionFailure) {
ClientSocketHandle handle;
HostPortPair host_port_pair("unresolvable.host.name", 80);
scoped_refptr<TransportSocketParams> dest(new TransportSocketParams(
- host_port_pair, false, false,
- OnHostResolutionCallback()));
+ host_port_pair, false, false, OnHostResolutionCallback(),
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT));
EXPECT_EQ(ERR_IO_PENDING,
handle.Init("a", dest, kDefaultPriority, callback.callback(),
&pool_, BoundNetLog()));
@@ -616,8 +220,8 @@ TEST_F(TransportClientSocketPoolTest, InitHostResolutionFailure) {
}
TEST_F(TransportClientSocketPoolTest, InitConnectionFailure) {
- client_socket_factory_.set_client_socket_type(
- MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET);
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET);
TestCompletionCallback callback;
ClientSocketHandle handle;
EXPECT_EQ(ERR_IO_PENDING,
@@ -760,8 +364,8 @@ TEST_F(TransportClientSocketPoolTest, TwoRequestsCancelOne) {
}
TEST_F(TransportClientSocketPoolTest, ConnectCancelConnect) {
- client_socket_factory_.set_client_socket_type(
- MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET);
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET);
ClientSocketHandle handle;
TestCompletionCallback callback;
EXPECT_EQ(ERR_IO_PENDING,
@@ -861,7 +465,7 @@ class RequestSocketCallback : public TestCompletionCallbackBase {
base::Unretained(this))) {
}
- virtual ~RequestSocketCallback() {}
+ ~RequestSocketCallback() override {}
const CompletionCallback& callback() const { return callback_; }
@@ -883,7 +487,8 @@ class RequestSocketCallback : public TestCompletionCallbackBase {
within_callback_ = true;
scoped_refptr<TransportSocketParams> dest(new TransportSocketParams(
HostPortPair("www.google.com", 80), false, false,
- OnHostResolutionCallback()));
+ OnHostResolutionCallback(),
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT));
int rv = handle_->Init("a", dest, LOWEST, callback(), pool_,
BoundNetLog());
EXPECT_EQ(OK, rv);
@@ -903,7 +508,8 @@ TEST_F(TransportClientSocketPoolTest, RequestTwice) {
RequestSocketCallback callback(&handle, &pool_);
scoped_refptr<TransportSocketParams> dest(new TransportSocketParams(
HostPortPair("www.google.com", 80), false, false,
- OnHostResolutionCallback()));
+ OnHostResolutionCallback(),
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT));
int rv = handle.Init("a", dest, LOWEST, callback.callback(), &pool_,
BoundNetLog());
ASSERT_EQ(ERR_IO_PENDING, rv);
@@ -920,8 +526,8 @@ TEST_F(TransportClientSocketPoolTest, RequestTwice) {
// Make sure that pending requests get serviced after active requests get
// cancelled.
TEST_F(TransportClientSocketPoolTest, CancelActiveRequestWithPendingRequests) {
- client_socket_factory_.set_client_socket_type(
- MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET);
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET);
// Queue up all the requests
EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
@@ -950,8 +556,8 @@ TEST_F(TransportClientSocketPoolTest, CancelActiveRequestWithPendingRequests) {
// Make sure that pending requests get serviced after active requests fail.
TEST_F(TransportClientSocketPoolTest, FailingActiveRequestWithPendingRequests) {
- client_socket_factory_.set_client_socket_type(
- MockClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET);
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET);
const int kNumRequests = 2 * kMaxSocketsPerGroup + 1;
ASSERT_LE(kNumRequests, kMaxSockets); // Otherwise the test will hang.
@@ -1022,24 +628,24 @@ TEST_F(TransportClientSocketPoolTest, ResetIdleSocketsOnIPAddressChange) {
TEST_F(TransportClientSocketPoolTest, BackupSocketConnect) {
// Case 1 tests the first socket stalling, and the backup connecting.
- MockClientSocketFactory::ClientSocketType case1_types[] = {
+ MockTransportClientSocketFactory::ClientSocketType case1_types[] = {
// The first socket will not connect.
- MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET,
+ MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET,
// The second socket will connect more quickly.
- MockClientSocketFactory::MOCK_CLIENT_SOCKET
+ MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET
};
// Case 2 tests the first socket being slow, so that we start the
// second connect, but the second connect stalls, and we still
// complete the first.
- MockClientSocketFactory::ClientSocketType case2_types[] = {
+ MockTransportClientSocketFactory::ClientSocketType case2_types[] = {
// The first socket will connect, although delayed.
- MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET,
+ MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET,
// The second socket will not connect.
- MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET
+ MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET
};
- MockClientSocketFactory::ClientSocketType* cases[2] = {
+ MockTransportClientSocketFactory::ClientSocketType* cases[2] = {
case1_types,
case2_types
};
@@ -1083,8 +689,8 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketConnect) {
// Test the case where a socket took long enough to start the creation
// of the backup socket, but then we cancelled the request after that.
TEST_F(TransportClientSocketPoolTest, BackupSocketCancel) {
- client_socket_factory_.set_client_socket_type(
- MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET);
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET);
enum { CANCEL_BEFORE_WAIT, CANCEL_AFTER_WAIT };
@@ -1126,11 +732,11 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketCancel) {
// of the backup socket and never completes, and then the backup
// connection fails.
TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterStall) {
- MockClientSocketFactory::ClientSocketType case_types[] = {
+ MockTransportClientSocketFactory::ClientSocketType case_types[] = {
// The first socket will not connect.
- MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET,
+ MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET,
// The second socket will fail immediately.
- MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET
+ MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET
};
client_socket_factory_.set_client_socket_types(case_types, 2);
@@ -1173,11 +779,11 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterStall) {
// of the backup socket and eventually completes, but the backup socket
// fails.
TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterDelay) {
- MockClientSocketFactory::ClientSocketType case_types[] = {
+ MockTransportClientSocketFactory::ClientSocketType case_types[] = {
// The first socket will connect, although delayed.
- MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET,
+ MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET,
// The second socket will not connect.
- MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET
+ MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET
};
client_socket_factory_.set_client_socket_types(case_types, 2);
@@ -1228,11 +834,11 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv4FinishesFirst) {
&client_socket_factory_,
NULL);
- MockClientSocketFactory::ClientSocketType case_types[] = {
+ MockTransportClientSocketFactory::ClientSocketType case_types[] = {
// This is the IPv6 socket.
- MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET,
+ MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET,
// This is the IPv4 socket.
- MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET
+ MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET
};
client_socket_factory_.set_client_socket_types(case_types, 2);
@@ -1271,16 +877,16 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv6FinishesFirst) {
&client_socket_factory_,
NULL);
- MockClientSocketFactory::ClientSocketType case_types[] = {
+ MockTransportClientSocketFactory::ClientSocketType case_types[] = {
// This is the IPv6 socket.
- MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET,
+ MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET,
// This is the IPv4 socket.
- MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET
+ MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET
};
client_socket_factory_.set_client_socket_types(case_types, 2);
client_socket_factory_.set_delay(base::TimeDelta::FromMilliseconds(
- TransportConnectJob::kIPv6FallbackTimerInMs + 50));
+ TransportConnectJobHelper::kIPv6FallbackTimerInMs + 50));
// Resolve an AddressList with a IPv6 address first and then a IPv4 address.
host_resolver_->rules()
@@ -1313,8 +919,8 @@ TEST_F(TransportClientSocketPoolTest, IPv6NoIPv4AddressesToFallbackTo) {
&client_socket_factory_,
NULL);
- client_socket_factory_.set_client_socket_type(
- MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET);
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET);
// Resolve an AddressList with only IPv6 addresses.
host_resolver_->rules()
@@ -1347,8 +953,8 @@ TEST_F(TransportClientSocketPoolTest, IPv4HasNoFallback) {
&client_socket_factory_,
NULL);
- client_socket_factory_.set_client_socket_type(
- MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET);
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET);
// Resolve an AddressList with only IPv4 addresses.
host_resolver_->rules()->AddIPLiteralRule("*", "1.1.1.1", std::string());
@@ -1370,6 +976,137 @@ TEST_F(TransportClientSocketPoolTest, IPv4HasNoFallback) {
EXPECT_EQ(1, client_socket_factory_.allocation_count());
}
+// Test that if TCP FastOpen is enabled, it is set on the socket
+// when we have only an IPv4 address.
+TEST_F(TransportClientSocketPoolTest, TCPFastOpenOnIPv4WithNoFallback) {
+ // Create a pool without backup jobs.
+ ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false);
+ TransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET);
+ // Resolve an AddressList with only IPv4 addresses.
+ host_resolver_->rules()->AddIPLiteralRule("*", "1.1.1.1", std::string());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ // Enable TCP FastOpen in TransportSocketParams.
+ scoped_refptr<TransportSocketParams> params = CreateParamsForTCPFastOpen();
+ handle.Init("a", params, LOW, callback.callback(), &pool, BoundNetLog());
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.socket()->UsingTCPFastOpen());
+}
+
+// Test that if TCP FastOpen is enabled, it is set on the socket
+// when we have only IPv6 addresses.
+TEST_F(TransportClientSocketPoolTest, TCPFastOpenOnIPv6WithNoFallback) {
+ // Create a pool without backup jobs.
+ ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false);
+ TransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET);
+ // Resolve an AddressList with only IPv6 addresses.
+ host_resolver_->rules()
+ ->AddIPLiteralRule("*", "2:abcd::3:4:ff,3:abcd::3:4:ff", std::string());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ // Enable TCP FastOpen in TransportSocketParams.
+ scoped_refptr<TransportSocketParams> params = CreateParamsForTCPFastOpen();
+ handle.Init("a", params, LOW, callback.callback(), &pool, BoundNetLog());
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.socket()->UsingTCPFastOpen());
+}
+
+// Test that if TCP FastOpen is enabled, it does not do anything when there
+// is a IPv6 address with fallback to an IPv4 address. This test tests the case
+// when the IPv6 connect fails and the IPv4 one succeeds.
+TEST_F(TransportClientSocketPoolTest,
+ NoTCPFastOpenOnIPv6FailureWithIPv4Fallback) {
+ // Create a pool without backup jobs.
+ ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false);
+ TransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+
+ MockTransportClientSocketFactory::ClientSocketType case_types[] = {
+ // This is the IPv6 socket.
+ MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET,
+ // This is the IPv4 socket.
+ MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET
+ };
+ client_socket_factory_.set_client_socket_types(case_types, 2);
+ // Resolve an AddressList with a IPv6 address first and then a IPv4 address.
+ host_resolver_->rules()
+ ->AddIPLiteralRule("*", "2:abcd::3:4:ff,2.2.2.2", std::string());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ // Enable TCP FastOpen in TransportSocketParams.
+ scoped_refptr<TransportSocketParams> params = CreateParamsForTCPFastOpen();
+ handle.Init("a", params, LOW, callback.callback(), &pool, BoundNetLog());
+ EXPECT_EQ(OK, callback.WaitForResult());
+ // Verify that the socket used is connected to the fallback IPv4 address.
+ IPEndPoint endpoint;
+ handle.socket()->GetLocalAddress(&endpoint);
+ EXPECT_EQ(kIPv4AddressSize, endpoint.address().size());
+ EXPECT_EQ(2, client_socket_factory_.allocation_count());
+ // Verify that TCP FastOpen was not turned on for the socket.
+ EXPECT_FALSE(handle.socket()->UsingTCPFastOpen());
+}
+
+// Test that if TCP FastOpen is enabled, it does not do anything when there
+// is a IPv6 address with fallback to an IPv4 address. This test tests the case
+// when the IPv6 connect succeeds.
+TEST_F(TransportClientSocketPoolTest,
+ NoTCPFastOpenOnIPv6SuccessWithIPv4Fallback) {
+ // Create a pool without backup jobs.
+ ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false);
+ TransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+
+ MockTransportClientSocketFactory::ClientSocketType case_types[] = {
+ // This is the IPv6 socket.
+ MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET,
+ // This is the IPv4 socket.
+ MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET
+ };
+ client_socket_factory_.set_client_socket_types(case_types, 2);
+ // Resolve an AddressList with a IPv6 address first and then a IPv4 address.
+ host_resolver_->rules()
+ ->AddIPLiteralRule("*", "2:abcd::3:4:ff,2.2.2.2", std::string());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ // Enable TCP FastOpen in TransportSocketParams.
+ scoped_refptr<TransportSocketParams> params = CreateParamsForTCPFastOpen();
+ handle.Init("a", params, LOW, callback.callback(), &pool, BoundNetLog());
+ EXPECT_EQ(OK, callback.WaitForResult());
+ // Verify that the socket used is connected to the IPv6 address.
+ IPEndPoint endpoint;
+ handle.socket()->GetLocalAddress(&endpoint);
+ EXPECT_EQ(kIPv6AddressSize, endpoint.address().size());
+ EXPECT_EQ(1, client_socket_factory_.allocation_count());
+ // Verify that TCP FastOpen was not turned on for the socket.
+ EXPECT_FALSE(handle.socket()->UsingTCPFastOpen());
+}
+
} // namespace
} // namespace net
diff --git a/chromium/net/socket/transport_client_socket_unittest.cc b/chromium/net/socket/transport_client_socket_unittest.cc
index 5548b27b995..d01cbad6dc8 100644
--- a/chromium/net/socket/transport_client_socket_unittest.cc
+++ b/chromium/net/socket/transport_client_socket_unittest.cc
@@ -47,22 +47,22 @@ class TransportClientSocketTest
}
// Implement StreamListenSocket::Delegate methods
- virtual void DidAccept(StreamListenSocket* server,
- scoped_ptr<StreamListenSocket> connection) OVERRIDE {
+ void DidAccept(StreamListenSocket* server,
+ scoped_ptr<StreamListenSocket> connection) override {
connected_sock_.reset(
static_cast<TCPListenSocket*>(connection.release()));
}
- virtual void DidRead(StreamListenSocket*, const char* str, int len) OVERRIDE {
+ void DidRead(StreamListenSocket*, const char* str, int len) override {
// TODO(dkegel): this might not be long enough to tickle some bugs.
connected_sock_->Send(kServerReply, arraysize(kServerReply) - 1,
false /* Don't append line feed */);
if (close_server_socket_on_next_send_)
CloseServerSocket();
}
- virtual void DidClose(StreamListenSocket* sock) OVERRIDE {}
+ void DidClose(StreamListenSocket* sock) override {}
// Testcase hooks
- virtual void SetUp();
+ void SetUp() override;
void CloseServerSocket() {
// delete the connected_sock_, which will close it.
diff --git a/chromium/net/socket/unix_domain_client_socket_posix.cc b/chromium/net/socket/unix_domain_client_socket_posix.cc
new file mode 100644
index 00000000000..5adbca9979e
--- /dev/null
+++ b/chromium/net/socket/unix_domain_client_socket_posix.cc
@@ -0,0 +1,171 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/unix_domain_client_socket_posix.h"
+
+#include <sys/socket.h>
+#include <sys/un.h>
+
+#include "base/logging.h"
+#include "base/posix/eintr_wrapper.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+#include "net/socket/socket_libevent.h"
+
+namespace net {
+
+UnixDomainClientSocket::UnixDomainClientSocket(const std::string& socket_path,
+ bool use_abstract_namespace)
+ : socket_path_(socket_path),
+ use_abstract_namespace_(use_abstract_namespace) {
+}
+
+UnixDomainClientSocket::UnixDomainClientSocket(
+ scoped_ptr<SocketLibevent> socket)
+ : use_abstract_namespace_(false),
+ socket_(socket.Pass()) {
+}
+
+UnixDomainClientSocket::~UnixDomainClientSocket() {
+ Disconnect();
+}
+
+// static
+bool UnixDomainClientSocket::FillAddress(const std::string& socket_path,
+ bool use_abstract_namespace,
+ SockaddrStorage* address) {
+ struct sockaddr_un* socket_addr =
+ reinterpret_cast<struct sockaddr_un*>(address->addr);
+ size_t path_max = address->addr_len - offsetof(struct sockaddr_un, sun_path);
+ // Non abstract namespace pathname should be null-terminated. Abstract
+ // namespace pathname must start with '\0'. So, the size is always greater
+ // than socket_path size by 1.
+ size_t path_size = socket_path.size() + 1;
+ if (path_size > path_max)
+ return false;
+
+ memset(socket_addr, 0, address->addr_len);
+ socket_addr->sun_family = AF_UNIX;
+ address->addr_len = path_size + offsetof(struct sockaddr_un, sun_path);
+ if (!use_abstract_namespace) {
+ memcpy(socket_addr->sun_path, socket_path.c_str(), socket_path.size());
+ return true;
+ }
+
+#if defined(OS_ANDROID) || defined(OS_LINUX)
+ // Convert the path given into abstract socket name. It must start with
+ // the '\0' character, so we are adding it. |addr_len| must specify the
+ // length of the structure exactly, as potentially the socket name may
+ // have '\0' characters embedded (although we don't support this).
+ // Note that addr.sun_path is already zero initialized.
+ memcpy(socket_addr->sun_path + 1, socket_path.c_str(), socket_path.size());
+ return true;
+#else
+ return false;
+#endif
+}
+
+int UnixDomainClientSocket::Connect(const CompletionCallback& callback) {
+ DCHECK(!socket_);
+
+ if (socket_path_.empty())
+ return ERR_ADDRESS_INVALID;
+
+ SockaddrStorage address;
+ if (!FillAddress(socket_path_, use_abstract_namespace_, &address))
+ return ERR_ADDRESS_INVALID;
+
+ socket_.reset(new SocketLibevent);
+ int rv = socket_->Open(AF_UNIX);
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ if (rv != OK)
+ return rv;
+
+ return socket_->Connect(address, callback);
+}
+
+void UnixDomainClientSocket::Disconnect() {
+ socket_.reset();
+}
+
+bool UnixDomainClientSocket::IsConnected() const {
+ return socket_ && socket_->IsConnected();
+}
+
+bool UnixDomainClientSocket::IsConnectedAndIdle() const {
+ return socket_ && socket_->IsConnectedAndIdle();
+}
+
+int UnixDomainClientSocket::GetPeerAddress(IPEndPoint* address) const {
+ NOTIMPLEMENTED();
+ return ERR_NOT_IMPLEMENTED;
+}
+
+int UnixDomainClientSocket::GetLocalAddress(IPEndPoint* address) const {
+ NOTIMPLEMENTED();
+ return ERR_NOT_IMPLEMENTED;
+}
+
+const BoundNetLog& UnixDomainClientSocket::NetLog() const {
+ return net_log_;
+}
+
+void UnixDomainClientSocket::SetSubresourceSpeculation() {
+}
+
+void UnixDomainClientSocket::SetOmniboxSpeculation() {
+}
+
+bool UnixDomainClientSocket::WasEverUsed() const {
+ return true; // We don't care.
+}
+
+bool UnixDomainClientSocket::UsingTCPFastOpen() const {
+ return false;
+}
+
+bool UnixDomainClientSocket::WasNpnNegotiated() const {
+ return false;
+}
+
+NextProto UnixDomainClientSocket::GetNegotiatedProtocol() const {
+ return kProtoUnknown;
+}
+
+bool UnixDomainClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
+ return false;
+}
+
+int UnixDomainClientSocket::Read(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(socket_);
+ return socket_->Read(buf, buf_len, callback);
+}
+
+int UnixDomainClientSocket::Write(IOBuffer* buf, int buf_len,
+ const CompletionCallback& callback) {
+ DCHECK(socket_);
+ return socket_->Write(buf, buf_len, callback);
+}
+
+int UnixDomainClientSocket::SetReceiveBufferSize(int32 size) {
+ NOTIMPLEMENTED();
+ return ERR_NOT_IMPLEMENTED;
+}
+
+int UnixDomainClientSocket::SetSendBufferSize(int32 size) {
+ NOTIMPLEMENTED();
+ return ERR_NOT_IMPLEMENTED;
+}
+
+SocketDescriptor UnixDomainClientSocket::ReleaseConnectedSocket() {
+ DCHECK(socket_);
+ DCHECK(socket_->IsConnected());
+
+ SocketDescriptor socket_fd = socket_->ReleaseConnectedSocket();
+ socket_.reset();
+ return socket_fd;
+}
+
+} // namespace net
diff --git a/chromium/net/socket/unix_domain_client_socket_posix.h b/chromium/net/socket/unix_domain_client_socket_posix.h
new file mode 100644
index 00000000000..2a8bdb625c9
--- /dev/null
+++ b/chromium/net/socket/unix_domain_client_socket_posix.h
@@ -0,0 +1,87 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_UNIX_DOMAIN_CLIENT_SOCKET_POSIX_H_
+#define NET_SOCKET_UNIX_DOMAIN_CLIENT_SOCKET_POSIX_H_
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/macros.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/completion_callback.h"
+#include "net/base/net_export.h"
+#include "net/base/net_log.h"
+#include "net/socket/socket_descriptor.h"
+#include "net/socket/stream_socket.h"
+
+namespace net {
+
+class SocketLibevent;
+struct SockaddrStorage;
+
+// A client socket that uses unix domain socket as the transport layer.
+class NET_EXPORT UnixDomainClientSocket : public StreamSocket {
+ public:
+ // Builds a client socket with |socket_path|. The caller should call Connect()
+ // to connect to a server socket.
+ UnixDomainClientSocket(const std::string& socket_path,
+ bool use_abstract_namespace);
+ // Builds a client socket with socket libevent which is already connected.
+ // UnixDomainServerSocket uses this after it accepts a connection.
+ explicit UnixDomainClientSocket(scoped_ptr<SocketLibevent> socket);
+
+ ~UnixDomainClientSocket() override;
+
+ // Fills |address| with |socket_path| and its length. For Android or Linux
+ // platform, this supports abstract namespaces.
+ static bool FillAddress(const std::string& socket_path,
+ bool use_abstract_namespace,
+ SockaddrStorage* address);
+
+ // StreamSocket implementation.
+ int Connect(const CompletionCallback& callback) override;
+ void Disconnect() override;
+ bool IsConnected() const override;
+ bool IsConnectedAndIdle() const override;
+ int GetPeerAddress(IPEndPoint* address) const override;
+ int GetLocalAddress(IPEndPoint* address) const override;
+ const BoundNetLog& NetLog() const override;
+ void SetSubresourceSpeculation() override;
+ void SetOmniboxSpeculation() override;
+ bool WasEverUsed() const override;
+ bool UsingTCPFastOpen() const override;
+ bool WasNpnNegotiated() const override;
+ NextProto GetNegotiatedProtocol() const override;
+ bool GetSSLInfo(SSLInfo* ssl_info) override;
+
+ // Socket implementation.
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override;
+ int SetReceiveBufferSize(int32 size) override;
+ int SetSendBufferSize(int32 size) override;
+
+ // Releases ownership of underlying SocketDescriptor to caller.
+ // Internal state is reset so that this object can be used again.
+ // Socket must be connected in order to release it.
+ SocketDescriptor ReleaseConnectedSocket();
+
+ private:
+ const std::string socket_path_;
+ const bool use_abstract_namespace_;
+ scoped_ptr<SocketLibevent> socket_;
+ // This net log is just to comply StreamSocket::NetLog(). It throws away
+ // everything.
+ BoundNetLog net_log_;
+
+ DISALLOW_COPY_AND_ASSIGN(UnixDomainClientSocket);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_UNIX_DOMAIN_CLIENT_SOCKET_POSIX_H_
diff --git a/chromium/net/socket/unix_domain_client_socket_posix_unittest.cc b/chromium/net/socket/unix_domain_client_socket_posix_unittest.cc
new file mode 100644
index 00000000000..651cd72dd5b
--- /dev/null
+++ b/chromium/net/socket/unix_domain_client_socket_posix_unittest.cc
@@ -0,0 +1,446 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/unix_domain_client_socket_posix.h"
+
+#include <unistd.h>
+
+#include "base/bind.h"
+#include "base/files/file_path.h"
+#include "base/files/scoped_temp_dir.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/posix/eintr_wrapper.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_errors.h"
+#include "net/base/test_completion_callback.h"
+#include "net/socket/socket_libevent.h"
+#include "net/socket/unix_domain_server_socket_posix.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+namespace {
+
+const char kSocketFilename[] = "socket_for_testing";
+
+bool UserCanConnectCallback(
+ bool allow_user, const UnixDomainServerSocket::Credentials& credentials) {
+ // Here peers are running in same process.
+#if defined(OS_LINUX) || defined(OS_ANDROID)
+ EXPECT_EQ(getpid(), credentials.process_id);
+#endif
+ EXPECT_EQ(getuid(), credentials.user_id);
+ EXPECT_EQ(getgid(), credentials.group_id);
+ return allow_user;
+}
+
+UnixDomainServerSocket::AuthCallback CreateAuthCallback(bool allow_user) {
+ return base::Bind(&UserCanConnectCallback, allow_user);
+}
+
+// Connects socket synchronously.
+int ConnectSynchronously(StreamSocket* socket) {
+ TestCompletionCallback connect_callback;
+ int rv = socket->Connect(connect_callback.callback());
+ if (rv == ERR_IO_PENDING)
+ rv = connect_callback.WaitForResult();
+ return rv;
+}
+
+// Reads data from |socket| until it fills |buf| at least up to |min_data_len|.
+// Returns length of data read, or a net error.
+int ReadSynchronously(StreamSocket* socket,
+ IOBuffer* buf,
+ int buf_len,
+ int min_data_len) {
+ DCHECK_LE(min_data_len, buf_len);
+ scoped_refptr<DrainableIOBuffer> read_buf(
+ new DrainableIOBuffer(buf, buf_len));
+ TestCompletionCallback read_callback;
+ // Iterate reading several times (but not infinite) until it reads at least
+ // |min_data_len| bytes into |buf|.
+ for (int retry_count = 10;
+ retry_count > 0 && (read_buf->BytesConsumed() < min_data_len ||
+ // Try at least once when min_data_len == 0.
+ min_data_len == 0);
+ --retry_count) {
+ int rv = socket->Read(
+ read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
+ EXPECT_GE(read_buf->BytesRemaining(), rv);
+ if (rv == ERR_IO_PENDING) {
+ // If |min_data_len| is 0, returns ERR_IO_PENDING to distinguish the case
+ // when some data has been read.
+ if (min_data_len == 0) {
+ // No data has been read because of for-loop condition.
+ DCHECK_EQ(0, read_buf->BytesConsumed());
+ return ERR_IO_PENDING;
+ }
+ rv = read_callback.WaitForResult();
+ }
+ EXPECT_NE(ERR_IO_PENDING, rv);
+ if (rv < 0)
+ return rv;
+ read_buf->DidConsume(rv);
+ }
+ EXPECT_LE(0, read_buf->BytesRemaining());
+ return read_buf->BytesConsumed();
+}
+
+// Writes data to |socket| until it completes writing |buf| up to |buf_len|.
+// Returns length of data written, or a net error.
+int WriteSynchronously(StreamSocket* socket,
+ IOBuffer* buf,
+ int buf_len) {
+ scoped_refptr<DrainableIOBuffer> write_buf(
+ new DrainableIOBuffer(buf, buf_len));
+ TestCompletionCallback write_callback;
+ // Iterate writing several times (but not infinite) until it writes buf fully.
+ for (int retry_count = 10;
+ retry_count > 0 && write_buf->BytesRemaining() > 0;
+ --retry_count) {
+ int rv = socket->Write(write_buf.get(),
+ write_buf->BytesRemaining(),
+ write_callback.callback());
+ EXPECT_GE(write_buf->BytesRemaining(), rv);
+ if (rv == ERR_IO_PENDING)
+ rv = write_callback.WaitForResult();
+ EXPECT_NE(ERR_IO_PENDING, rv);
+ if (rv < 0)
+ return rv;
+ write_buf->DidConsume(rv);
+ }
+ EXPECT_LE(0, write_buf->BytesRemaining());
+ return write_buf->BytesConsumed();
+}
+
+class UnixDomainClientSocketTest : public testing::Test {
+ protected:
+ UnixDomainClientSocketTest() {
+ EXPECT_TRUE(temp_dir_.CreateUniqueTempDir());
+ socket_path_ = temp_dir_.path().Append(kSocketFilename).value();
+ }
+
+ base::ScopedTempDir temp_dir_;
+ std::string socket_path_;
+};
+
+TEST_F(UnixDomainClientSocketTest, Connect) {
+ const bool kUseAbstractNamespace = false;
+
+ UnixDomainServerSocket server_socket(CreateAuthCallback(true),
+ kUseAbstractNamespace);
+ EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
+
+ scoped_ptr<StreamSocket> accepted_socket;
+ TestCompletionCallback accept_callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ server_socket.Accept(&accepted_socket, accept_callback.callback()));
+ EXPECT_FALSE(accepted_socket);
+
+ UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
+ EXPECT_FALSE(client_socket.IsConnected());
+
+ EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
+ EXPECT_TRUE(client_socket.IsConnected());
+ // Server has not yet been notified of the connection.
+ EXPECT_FALSE(accepted_socket);
+
+ EXPECT_EQ(OK, accept_callback.WaitForResult());
+ EXPECT_TRUE(accepted_socket);
+ EXPECT_TRUE(accepted_socket->IsConnected());
+}
+
+TEST_F(UnixDomainClientSocketTest, ConnectWithSocketDescriptor) {
+ const bool kUseAbstractNamespace = false;
+
+ UnixDomainServerSocket server_socket(CreateAuthCallback(true),
+ kUseAbstractNamespace);
+ EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
+
+ SocketDescriptor accepted_socket_fd = kInvalidSocket;
+ TestCompletionCallback accept_callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ server_socket.AcceptSocketDescriptor(&accepted_socket_fd,
+ accept_callback.callback()));
+ EXPECT_EQ(kInvalidSocket, accepted_socket_fd);
+
+ UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
+ EXPECT_FALSE(client_socket.IsConnected());
+
+ EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
+ EXPECT_TRUE(client_socket.IsConnected());
+ // Server has not yet been notified of the connection.
+ EXPECT_EQ(kInvalidSocket, accepted_socket_fd);
+
+ EXPECT_EQ(OK, accept_callback.WaitForResult());
+ EXPECT_NE(kInvalidSocket, accepted_socket_fd);
+
+ SocketDescriptor client_socket_fd = client_socket.ReleaseConnectedSocket();
+ EXPECT_NE(kInvalidSocket, client_socket_fd);
+
+ // Now, re-wrap client_socket_fd in a UnixDomainClientSocket and try a read
+ // to be sure it hasn't gotten accidentally closed.
+ SockaddrStorage addr;
+ ASSERT_TRUE(UnixDomainClientSocket::FillAddress(socket_path_, false, &addr));
+ scoped_ptr<SocketLibevent> adopter(new SocketLibevent);
+ adopter->AdoptConnectedSocket(client_socket_fd, addr);
+ UnixDomainClientSocket rewrapped_socket(adopter.Pass());
+ EXPECT_TRUE(rewrapped_socket.IsConnected());
+
+ // Try to read data.
+ const int kReadDataSize = 10;
+ scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadDataSize));
+ TestCompletionCallback read_callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ rewrapped_socket.Read(
+ read_buffer.get(), kReadDataSize, read_callback.callback()));
+
+ EXPECT_EQ(0, IGNORE_EINTR(close(accepted_socket_fd)));
+}
+
+TEST_F(UnixDomainClientSocketTest, ConnectWithAbstractNamespace) {
+ const bool kUseAbstractNamespace = true;
+
+ UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
+ EXPECT_FALSE(client_socket.IsConnected());
+
+#if defined(OS_ANDROID) || defined(OS_LINUX)
+ UnixDomainServerSocket server_socket(CreateAuthCallback(true),
+ kUseAbstractNamespace);
+ EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
+
+ scoped_ptr<StreamSocket> accepted_socket;
+ TestCompletionCallback accept_callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ server_socket.Accept(&accepted_socket, accept_callback.callback()));
+ EXPECT_FALSE(accepted_socket);
+
+ EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
+ EXPECT_TRUE(client_socket.IsConnected());
+ // Server has not yet beend notified of the connection.
+ EXPECT_FALSE(accepted_socket);
+
+ EXPECT_EQ(OK, accept_callback.WaitForResult());
+ EXPECT_TRUE(accepted_socket);
+ EXPECT_TRUE(accepted_socket->IsConnected());
+#else
+ EXPECT_EQ(ERR_ADDRESS_INVALID, ConnectSynchronously(&client_socket));
+#endif
+}
+
+TEST_F(UnixDomainClientSocketTest, ConnectToNonExistentSocket) {
+ const bool kUseAbstractNamespace = false;
+
+ UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
+ EXPECT_FALSE(client_socket.IsConnected());
+ EXPECT_EQ(ERR_FILE_NOT_FOUND, ConnectSynchronously(&client_socket));
+}
+
+TEST_F(UnixDomainClientSocketTest,
+ ConnectToNonExistentSocketWithAbstractNamespace) {
+ const bool kUseAbstractNamespace = true;
+
+ UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
+ EXPECT_FALSE(client_socket.IsConnected());
+
+ TestCompletionCallback connect_callback;
+#if defined(OS_ANDROID) || defined(OS_LINUX)
+ EXPECT_EQ(ERR_CONNECTION_REFUSED, ConnectSynchronously(&client_socket));
+#else
+ EXPECT_EQ(ERR_ADDRESS_INVALID, ConnectSynchronously(&client_socket));
+#endif
+}
+
+TEST_F(UnixDomainClientSocketTest, DisconnectFromClient) {
+ UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
+ EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
+ scoped_ptr<StreamSocket> accepted_socket;
+ TestCompletionCallback accept_callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ server_socket.Accept(&accepted_socket, accept_callback.callback()));
+ UnixDomainClientSocket client_socket(socket_path_, false);
+ EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
+
+ EXPECT_EQ(OK, accept_callback.WaitForResult());
+ EXPECT_TRUE(accepted_socket->IsConnected());
+ EXPECT_TRUE(client_socket.IsConnected());
+
+ // Try to read data.
+ const int kReadDataSize = 10;
+ scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadDataSize));
+ TestCompletionCallback read_callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ accepted_socket->Read(
+ read_buffer.get(), kReadDataSize, read_callback.callback()));
+
+ // Disconnect from client side.
+ client_socket.Disconnect();
+ EXPECT_FALSE(client_socket.IsConnected());
+ EXPECT_FALSE(accepted_socket->IsConnected());
+
+ // Connection closed by peer.
+ EXPECT_EQ(0 /* EOF */, read_callback.WaitForResult());
+ // Note that read callback won't be called when the connection is closed
+ // locally before the peer closes it. SocketLibevent just clears callbacks.
+}
+
+TEST_F(UnixDomainClientSocketTest, DisconnectFromServer) {
+ UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
+ EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
+ scoped_ptr<StreamSocket> accepted_socket;
+ TestCompletionCallback accept_callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ server_socket.Accept(&accepted_socket, accept_callback.callback()));
+ UnixDomainClientSocket client_socket(socket_path_, false);
+ EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
+
+ EXPECT_EQ(OK, accept_callback.WaitForResult());
+ EXPECT_TRUE(accepted_socket->IsConnected());
+ EXPECT_TRUE(client_socket.IsConnected());
+
+ // Try to read data.
+ const int kReadDataSize = 10;
+ scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadDataSize));
+ TestCompletionCallback read_callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ client_socket.Read(
+ read_buffer.get(), kReadDataSize, read_callback.callback()));
+
+ // Disconnect from server side.
+ accepted_socket->Disconnect();
+ EXPECT_FALSE(accepted_socket->IsConnected());
+ EXPECT_FALSE(client_socket.IsConnected());
+
+ // Connection closed by peer.
+ EXPECT_EQ(0 /* EOF */, read_callback.WaitForResult());
+ // Note that read callback won't be called when the connection is closed
+ // locally before the peer closes it. SocketLibevent just clears callbacks.
+}
+
+TEST_F(UnixDomainClientSocketTest, ReadAfterWrite) {
+ UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
+ EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
+ scoped_ptr<StreamSocket> accepted_socket;
+ TestCompletionCallback accept_callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ server_socket.Accept(&accepted_socket, accept_callback.callback()));
+ UnixDomainClientSocket client_socket(socket_path_, false);
+ EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
+
+ EXPECT_EQ(OK, accept_callback.WaitForResult());
+ EXPECT_TRUE(accepted_socket->IsConnected());
+ EXPECT_TRUE(client_socket.IsConnected());
+
+ // Send data from client to server.
+ const int kWriteDataSize = 10;
+ scoped_refptr<IOBuffer> write_buffer(
+ new StringIOBuffer(std::string(kWriteDataSize, 'd')));
+ EXPECT_EQ(
+ kWriteDataSize,
+ WriteSynchronously(&client_socket, write_buffer.get(), kWriteDataSize));
+
+ // The buffer is bigger than write data size.
+ const int kReadBufferSize = kWriteDataSize * 2;
+ scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadBufferSize));
+ EXPECT_EQ(kWriteDataSize,
+ ReadSynchronously(accepted_socket.get(),
+ read_buffer.get(),
+ kReadBufferSize,
+ kWriteDataSize));
+ EXPECT_EQ(std::string(write_buffer->data(), kWriteDataSize),
+ std::string(read_buffer->data(), kWriteDataSize));
+
+ // Send data from server and client.
+ EXPECT_EQ(kWriteDataSize,
+ WriteSynchronously(
+ accepted_socket.get(), write_buffer.get(), kWriteDataSize));
+
+ // Read multiple times.
+ const int kSmallReadBufferSize = kWriteDataSize / 3;
+ EXPECT_EQ(kSmallReadBufferSize,
+ ReadSynchronously(&client_socket,
+ read_buffer.get(),
+ kSmallReadBufferSize,
+ kSmallReadBufferSize));
+ EXPECT_EQ(std::string(write_buffer->data(), kSmallReadBufferSize),
+ std::string(read_buffer->data(), kSmallReadBufferSize));
+
+ EXPECT_EQ(kWriteDataSize - kSmallReadBufferSize,
+ ReadSynchronously(&client_socket,
+ read_buffer.get(),
+ kReadBufferSize,
+ kWriteDataSize - kSmallReadBufferSize));
+ EXPECT_EQ(std::string(write_buffer->data() + kSmallReadBufferSize,
+ kWriteDataSize - kSmallReadBufferSize),
+ std::string(read_buffer->data(),
+ kWriteDataSize - kSmallReadBufferSize));
+
+ // No more data.
+ EXPECT_EQ(
+ ERR_IO_PENDING,
+ ReadSynchronously(&client_socket, read_buffer.get(), kReadBufferSize, 0));
+
+ // Disconnect from server side after read-write.
+ accepted_socket->Disconnect();
+ EXPECT_FALSE(accepted_socket->IsConnected());
+ EXPECT_FALSE(client_socket.IsConnected());
+}
+
+TEST_F(UnixDomainClientSocketTest, ReadBeforeWrite) {
+ UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
+ EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
+ scoped_ptr<StreamSocket> accepted_socket;
+ TestCompletionCallback accept_callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ server_socket.Accept(&accepted_socket, accept_callback.callback()));
+ UnixDomainClientSocket client_socket(socket_path_, false);
+ EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
+
+ EXPECT_EQ(OK, accept_callback.WaitForResult());
+ EXPECT_TRUE(accepted_socket->IsConnected());
+ EXPECT_TRUE(client_socket.IsConnected());
+
+ // Wait for data from client.
+ const int kWriteDataSize = 10;
+ const int kReadBufferSize = kWriteDataSize * 2;
+ const int kSmallReadBufferSize = kWriteDataSize / 3;
+ // Read smaller than write data size first.
+ scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadBufferSize));
+ TestCompletionCallback read_callback;
+ EXPECT_EQ(
+ ERR_IO_PENDING,
+ accepted_socket->Read(
+ read_buffer.get(), kSmallReadBufferSize, read_callback.callback()));
+
+ scoped_refptr<IOBuffer> write_buffer(
+ new StringIOBuffer(std::string(kWriteDataSize, 'd')));
+ EXPECT_EQ(
+ kWriteDataSize,
+ WriteSynchronously(&client_socket, write_buffer.get(), kWriteDataSize));
+
+ // First read completed.
+ int rv = read_callback.WaitForResult();
+ EXPECT_LT(0, rv);
+ EXPECT_LE(rv, kSmallReadBufferSize);
+
+ // Read remaining data.
+ const int kExpectedRemainingDataSize = kWriteDataSize - rv;
+ EXPECT_LE(0, kExpectedRemainingDataSize);
+ EXPECT_EQ(kExpectedRemainingDataSize,
+ ReadSynchronously(accepted_socket.get(),
+ read_buffer.get(),
+ kReadBufferSize,
+ kExpectedRemainingDataSize));
+ // No more data.
+ EXPECT_EQ(ERR_IO_PENDING,
+ ReadSynchronously(
+ accepted_socket.get(), read_buffer.get(), kReadBufferSize, 0));
+
+ // Disconnect from server side after read-write.
+ accepted_socket->Disconnect();
+ EXPECT_FALSE(accepted_socket->IsConnected());
+ EXPECT_FALSE(client_socket.IsConnected());
+}
+
+} // namespace
+} // namespace net
diff --git a/chromium/net/socket/unix_domain_listen_socket_posix.cc b/chromium/net/socket/unix_domain_listen_socket_posix.cc
new file mode 100644
index 00000000000..3e46439c8b5
--- /dev/null
+++ b/chromium/net/socket/unix_domain_listen_socket_posix.cc
@@ -0,0 +1,167 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/unix_domain_listen_socket_posix.h"
+
+#include <errno.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <cstring>
+#include <string>
+
+#include "base/bind.h"
+#include "base/callback.h"
+#include "base/posix/eintr_wrapper.h"
+#include "base/threading/platform_thread.h"
+#include "build/build_config.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+#include "net/socket/socket_descriptor.h"
+#include "net/socket/unix_domain_client_socket_posix.h"
+
+namespace net {
+namespace deprecated {
+
+namespace {
+
+int CreateAndBind(const std::string& socket_path,
+ bool use_abstract_namespace,
+ SocketDescriptor* socket_fd) {
+ DCHECK(socket_fd);
+
+ SockaddrStorage address;
+ if (!UnixDomainClientSocket::FillAddress(socket_path,
+ use_abstract_namespace,
+ &address)) {
+ return ERR_ADDRESS_INVALID;
+ }
+
+ SocketDescriptor fd = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0);
+ if (fd == kInvalidSocket)
+ return errno ? MapSystemError(errno) : ERR_UNEXPECTED;
+
+ if (bind(fd, address.addr, address.addr_len) < 0) {
+ int rv = MapSystemError(errno);
+ close(fd);
+ PLOG(ERROR) << "Could not bind unix domain socket to " << socket_path
+ << (use_abstract_namespace ? " (with abstract namespace)" : "");
+ return rv;
+ }
+
+ *socket_fd = fd;
+ return OK;
+}
+
+} // namespace
+
+// static
+scoped_ptr<UnixDomainListenSocket>
+UnixDomainListenSocket::CreateAndListenInternal(
+ const std::string& path,
+ const std::string& fallback_path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback,
+ bool use_abstract_namespace) {
+ SocketDescriptor socket_fd = kInvalidSocket;
+ int rv = CreateAndBind(path, use_abstract_namespace, &socket_fd);
+ if (rv != OK && !fallback_path.empty())
+ rv = CreateAndBind(fallback_path, use_abstract_namespace, &socket_fd);
+ if (rv != OK)
+ return scoped_ptr<UnixDomainListenSocket>();
+ scoped_ptr<UnixDomainListenSocket> sock(
+ new UnixDomainListenSocket(socket_fd, del, auth_callback));
+ sock->Listen();
+ return sock.Pass();
+}
+
+// static
+scoped_ptr<UnixDomainListenSocket> UnixDomainListenSocket::CreateAndListen(
+ const std::string& path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback) {
+ return CreateAndListenInternal(path, "", del, auth_callback, false);
+}
+
+#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
+// static
+scoped_ptr<UnixDomainListenSocket>
+UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
+ const std::string& path,
+ const std::string& fallback_path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback) {
+ return
+ CreateAndListenInternal(path, fallback_path, del, auth_callback, true);
+}
+#endif
+
+UnixDomainListenSocket::UnixDomainListenSocket(
+ SocketDescriptor s,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback)
+ : StreamListenSocket(s, del),
+ auth_callback_(auth_callback) {}
+
+UnixDomainListenSocket::~UnixDomainListenSocket() {}
+
+void UnixDomainListenSocket::Accept() {
+ SocketDescriptor conn = StreamListenSocket::AcceptSocket();
+ if (conn == kInvalidSocket)
+ return;
+ UnixDomainServerSocket::Credentials credentials;
+ if (!UnixDomainServerSocket::GetPeerCredentials(conn, &credentials) ||
+ !auth_callback_.Run(credentials)) {
+ if (IGNORE_EINTR(close(conn)) < 0)
+ LOG(ERROR) << "close() error";
+ return;
+ }
+ scoped_ptr<UnixDomainListenSocket> sock(
+ new UnixDomainListenSocket(conn, socket_delegate_, auth_callback_));
+ // It's up to the delegate to AddRef if it wants to keep it around.
+ sock->WatchSocket(WAITING_READ);
+ socket_delegate_->DidAccept(this, sock.Pass());
+}
+
+UnixDomainListenSocketFactory::UnixDomainListenSocketFactory(
+ const std::string& path,
+ const UnixDomainListenSocket::AuthCallback& auth_callback)
+ : path_(path),
+ auth_callback_(auth_callback) {}
+
+UnixDomainListenSocketFactory::~UnixDomainListenSocketFactory() {}
+
+scoped_ptr<StreamListenSocket> UnixDomainListenSocketFactory::CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const {
+ return UnixDomainListenSocket::CreateAndListen(
+ path_, delegate, auth_callback_).Pass();
+}
+
+#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
+
+UnixDomainListenSocketWithAbstractNamespaceFactory::
+UnixDomainListenSocketWithAbstractNamespaceFactory(
+ const std::string& path,
+ const std::string& fallback_path,
+ const UnixDomainListenSocket::AuthCallback& auth_callback)
+ : UnixDomainListenSocketFactory(path, auth_callback),
+ fallback_path_(fallback_path) {}
+
+UnixDomainListenSocketWithAbstractNamespaceFactory::
+~UnixDomainListenSocketWithAbstractNamespaceFactory() {}
+
+scoped_ptr<StreamListenSocket>
+UnixDomainListenSocketWithAbstractNamespaceFactory::CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const {
+ return UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
+ path_, fallback_path_, delegate, auth_callback_);
+}
+
+#endif
+
+} // namespace deprecated
+} // namespace net
diff --git a/chromium/net/socket/unix_domain_listen_socket_posix.h b/chromium/net/socket/unix_domain_listen_socket_posix.h
new file mode 100644
index 00000000000..82ec342edaa
--- /dev/null
+++ b/chromium/net/socket/unix_domain_listen_socket_posix.h
@@ -0,0 +1,122 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_UNIX_DOMAIN_LISTEN_SOCKET_POSIX_H_
+#define NET_SOCKET_UNIX_DOMAIN_LISTEN_SOCKET_POSIX_H_
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/callback_forward.h"
+#include "base/compiler_specific.h"
+#include "base/macros.h"
+#include "build/build_config.h"
+#include "net/base/net_export.h"
+#include "net/socket/stream_listen_socket.h"
+#include "net/socket/unix_domain_server_socket_posix.h"
+
+#if defined(OS_ANDROID) || defined(OS_LINUX)
+// Feature only supported on Linux currently. This lets the Unix Domain Socket
+// not be backed by the file system.
+#define SOCKET_ABSTRACT_NAMESPACE_SUPPORTED
+#endif
+
+namespace net {
+namespace deprecated {
+
+// Unix Domain Socket Implementation. Supports abstract namespaces on Linux.
+class NET_EXPORT UnixDomainListenSocket : public StreamListenSocket {
+ public:
+ typedef UnixDomainServerSocket::AuthCallback AuthCallback;
+
+ ~UnixDomainListenSocket() override;
+
+ // Note that the returned UnixDomainListenSocket instance does not take
+ // ownership of |del|.
+ static scoped_ptr<UnixDomainListenSocket> CreateAndListen(
+ const std::string& path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback);
+
+#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
+ // Same as above except that the created socket uses the abstract namespace
+ // which is a Linux-only feature. If |fallback_path| is not empty,
+ // make the second attempt with the provided fallback name.
+ static scoped_ptr<UnixDomainListenSocket>
+ CreateAndListenWithAbstractNamespace(
+ const std::string& path,
+ const std::string& fallback_path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback);
+#endif
+
+ private:
+ UnixDomainListenSocket(SocketDescriptor s,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback);
+
+ static scoped_ptr<UnixDomainListenSocket> CreateAndListenInternal(
+ const std::string& path,
+ const std::string& fallback_path,
+ StreamListenSocket::Delegate* del,
+ const AuthCallback& auth_callback,
+ bool use_abstract_namespace);
+
+ // StreamListenSocket:
+ void Accept() override;
+
+ AuthCallback auth_callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(UnixDomainListenSocket);
+};
+
+// Factory that can be used to instantiate UnixDomainListenSocket.
+class NET_EXPORT UnixDomainListenSocketFactory
+ : public StreamListenSocketFactory {
+ public:
+ // Note that this class does not take ownership of the provided delegate.
+ UnixDomainListenSocketFactory(
+ const std::string& path,
+ const UnixDomainListenSocket::AuthCallback& auth_callback);
+ ~UnixDomainListenSocketFactory() override;
+
+ // StreamListenSocketFactory:
+ scoped_ptr<StreamListenSocket> CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const override;
+
+ protected:
+ const std::string path_;
+ const UnixDomainListenSocket::AuthCallback auth_callback_;
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(UnixDomainListenSocketFactory);
+};
+
+#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
+// Use this factory to instantiate UnixDomainListenSocket using the abstract
+// namespace feature (only supported on Linux).
+class NET_EXPORT UnixDomainListenSocketWithAbstractNamespaceFactory
+ : public UnixDomainListenSocketFactory {
+ public:
+ UnixDomainListenSocketWithAbstractNamespaceFactory(
+ const std::string& path,
+ const std::string& fallback_path,
+ const UnixDomainListenSocket::AuthCallback& auth_callback);
+ ~UnixDomainListenSocketWithAbstractNamespaceFactory() override;
+
+ // UnixDomainListenSocketFactory:
+ scoped_ptr<StreamListenSocket> CreateAndListen(
+ StreamListenSocket::Delegate* delegate) const override;
+
+ private:
+ std::string fallback_path_;
+
+ DISALLOW_COPY_AND_ASSIGN(UnixDomainListenSocketWithAbstractNamespaceFactory);
+};
+#endif
+
+} // namespace deprecated
+} // namespace net
+
+#endif // NET_SOCKET_UNIX_DOMAIN_LISTEN_SOCKET_POSIX_H_
diff --git a/chromium/net/socket/unix_domain_socket_posix_unittest.cc b/chromium/net/socket/unix_domain_listen_socket_posix_unittest.cc
index b1857e62e0e..aaf362310c3 100644
--- a/chromium/net/socket/unix_domain_socket_posix_unittest.cc
+++ b/chromium/net/socket/unix_domain_listen_socket_posix_unittest.cc
@@ -1,7 +1,9 @@
-// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
+#include "net/socket/unix_domain_listen_socket_posix.h"
+
#include <errno.h>
#include <fcntl.h>
#include <poll.h>
@@ -19,8 +21,9 @@
#include "base/bind.h"
#include "base/callback.h"
#include "base/compiler_specific.h"
-#include "base/file_util.h"
#include "base/files/file_path.h"
+#include "base/files/file_util.h"
+#include "base/files/scoped_temp_dir.h"
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
#include "base/message_loop/message_loop.h"
@@ -30,16 +33,16 @@
#include "base/threading/platform_thread.h"
#include "base/threading/thread.h"
#include "net/socket/socket_descriptor.h"
-#include "net/socket/unix_domain_socket_posix.h"
#include "testing/gtest/include/gtest/gtest.h"
using std::queue;
using std::string;
namespace net {
+namespace deprecated {
namespace {
-const char kSocketFilename[] = "unix_domain_socket_for_testing";
+const char kSocketFilename[] = "socket_for_testing";
const char kInvalidSocketPath[] = "/invalid/path";
const char kMsg[] = "hello";
@@ -52,16 +55,6 @@ enum EventType {
EVENT_READ,
};
-string MakeSocketPath(const string& socket_file_name) {
- base::FilePath temp_dir;
- base::GetTempDir(&temp_dir);
- return temp_dir.Append(socket_file_name).value();
-}
-
-string MakeSocketPath() {
- return MakeSocketPath(kSocketFilename);
-}
-
class EventManager : public base::RefCounted<EventManager> {
public:
EventManager() : condition_(&mutex_) {}
@@ -101,16 +94,16 @@ class TestListenSocketDelegate : public StreamListenSocket::Delegate {
const scoped_refptr<EventManager>& event_manager)
: event_manager_(event_manager) {}
- virtual void DidAccept(StreamListenSocket* server,
- scoped_ptr<StreamListenSocket> connection) OVERRIDE {
+ void DidAccept(StreamListenSocket* server,
+ scoped_ptr<StreamListenSocket> connection) override {
LOG(ERROR) << __PRETTY_FUNCTION__;
connection_ = connection.Pass();
Notify(EVENT_ACCEPT);
}
- virtual void DidRead(StreamListenSocket* connection,
- const char* data,
- int len) OVERRIDE {
+ void DidRead(StreamListenSocket* connection,
+ const char* data,
+ int len) override {
{
base::AutoLock lock(mutex_);
DCHECK(len);
@@ -119,9 +112,7 @@ class TestListenSocketDelegate : public StreamListenSocket::Delegate {
Notify(EVENT_READ);
}
- virtual void DidClose(StreamListenSocket* sock) OVERRIDE {
- Notify(EVENT_CLOSE);
- }
+ void DidClose(StreamListenSocket* sock) override { Notify(EVENT_CLOSE); }
void OnListenCompleted() {
Notify(EVENT_LISTEN);
@@ -145,47 +136,52 @@ class TestListenSocketDelegate : public StreamListenSocket::Delegate {
bool UserCanConnectCallback(
bool allow_user, const scoped_refptr<EventManager>& event_manager,
- uid_t, gid_t) {
+ const UnixDomainServerSocket::Credentials&) {
event_manager->Notify(
allow_user ? EVENT_AUTH_GRANTED : EVENT_AUTH_DENIED);
return allow_user;
}
-class UnixDomainSocketTestHelper : public testing::Test {
+class UnixDomainListenSocketTestHelper : public testing::Test {
public:
void CreateAndListen() {
- socket_ = UnixDomainSocket::CreateAndListen(
+ socket_ = UnixDomainListenSocket::CreateAndListen(
file_path_.value(), socket_delegate_.get(), MakeAuthCallback());
socket_delegate_->OnListenCompleted();
}
protected:
- UnixDomainSocketTestHelper(const string& path, bool allow_user)
- : file_path_(path),
- allow_user_(allow_user) {}
+ UnixDomainListenSocketTestHelper(const string& path_str, bool allow_user)
+ : allow_user_(allow_user) {
+ file_path_ = base::FilePath(path_str);
+ if (!file_path_.IsAbsolute()) {
+ EXPECT_TRUE(temp_dir_.CreateUniqueTempDir());
+ file_path_ = GetTempSocketPath(file_path_.value());
+ }
+ // Beware that if path_str is an absolute path, this class doesn't delete
+ // the file. It must be an invalid path and cannot be created by unittests.
+ }
+
+ base::FilePath GetTempSocketPath(const std::string socket_name) {
+ DCHECK(temp_dir_.IsValid());
+ return temp_dir_.path().Append(socket_name);
+ }
- virtual void SetUp() OVERRIDE {
+ void SetUp() override {
event_manager_ = new EventManager();
socket_delegate_.reset(new TestListenSocketDelegate(event_manager_));
- DeleteSocketFile();
}
- virtual void TearDown() OVERRIDE {
- DeleteSocketFile();
+ void TearDown() override {
socket_.reset();
socket_delegate_.reset();
event_manager_ = NULL;
}
- UnixDomainSocket::AuthCallback MakeAuthCallback() {
+ UnixDomainListenSocket::AuthCallback MakeAuthCallback() {
return base::Bind(&UserCanConnectCallback, allow_user_, event_manager_);
}
- void DeleteSocketFile() {
- ASSERT_FALSE(file_path_.empty());
- base::DeleteFile(file_path_, false /* not recursive */);
- }
-
SocketDescriptor CreateClientSocket() {
const SocketDescriptor sock = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0);
if (sock < 0) {
@@ -199,7 +195,8 @@ class UnixDomainSocketTestHelper : public testing::Test {
strncpy(addr.sun_path, file_path_.value().c_str(), sizeof(addr.sun_path));
addr_len = sizeof(sockaddr_un);
if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) {
- LOG(ERROR) << "connect() error";
+ LOG(ERROR) << "connect() error: " << strerror(errno)
+ << ": path=" << file_path_.value();
return kInvalidSocket;
}
return sock;
@@ -212,43 +209,48 @@ class UnixDomainSocketTestHelper : public testing::Test {
thread->StartWithOptions(options);
thread->message_loop()->PostTask(
FROM_HERE,
- base::Bind(&UnixDomainSocketTestHelper::CreateAndListen,
+ base::Bind(&UnixDomainListenSocketTestHelper::CreateAndListen,
base::Unretained(this)));
return thread.Pass();
}
- const base::FilePath file_path_;
+ base::ScopedTempDir temp_dir_;
+ base::FilePath file_path_;
const bool allow_user_;
scoped_refptr<EventManager> event_manager_;
scoped_ptr<TestListenSocketDelegate> socket_delegate_;
- scoped_ptr<UnixDomainSocket> socket_;
+ scoped_ptr<UnixDomainListenSocket> socket_;
};
-class UnixDomainSocketTest : public UnixDomainSocketTestHelper {
+class UnixDomainListenSocketTest : public UnixDomainListenSocketTestHelper {
protected:
- UnixDomainSocketTest()
- : UnixDomainSocketTestHelper(MakeSocketPath(), true /* allow user */) {}
+ UnixDomainListenSocketTest()
+ : UnixDomainListenSocketTestHelper(kSocketFilename,
+ true /* allow user */) {}
};
-class UnixDomainSocketTestWithInvalidPath : public UnixDomainSocketTestHelper {
+class UnixDomainListenSocketTestWithInvalidPath
+ : public UnixDomainListenSocketTestHelper {
protected:
- UnixDomainSocketTestWithInvalidPath()
- : UnixDomainSocketTestHelper(kInvalidSocketPath, true) {}
+ UnixDomainListenSocketTestWithInvalidPath()
+ : UnixDomainListenSocketTestHelper(kInvalidSocketPath, true) {}
};
-class UnixDomainSocketTestWithForbiddenUser
- : public UnixDomainSocketTestHelper {
+class UnixDomainListenSocketTestWithForbiddenUser
+ : public UnixDomainListenSocketTestHelper {
protected:
- UnixDomainSocketTestWithForbiddenUser()
- : UnixDomainSocketTestHelper(MakeSocketPath(), false /* forbid user */) {}
+ UnixDomainListenSocketTestWithForbiddenUser()
+ : UnixDomainListenSocketTestHelper(kSocketFilename,
+ false /* forbid user */) {}
};
-TEST_F(UnixDomainSocketTest, CreateAndListen) {
+TEST_F(UnixDomainListenSocketTest, CreateAndListen) {
CreateAndListen();
EXPECT_FALSE(socket_.get() == NULL);
}
-TEST_F(UnixDomainSocketTestWithInvalidPath, CreateAndListenWithInvalidPath) {
+TEST_F(UnixDomainListenSocketTestWithInvalidPath,
+ CreateAndListenWithInvalidPath) {
CreateAndListen();
EXPECT_TRUE(socket_.get() == NULL);
}
@@ -256,35 +258,35 @@ TEST_F(UnixDomainSocketTestWithInvalidPath, CreateAndListenWithInvalidPath) {
#ifdef SOCKET_ABSTRACT_NAMESPACE_SUPPORTED
// Test with an invalid path to make sure that the socket is not backed by a
// file.
-TEST_F(UnixDomainSocketTestWithInvalidPath,
+TEST_F(UnixDomainListenSocketTestWithInvalidPath,
CreateAndListenWithAbstractNamespace) {
- socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace(
+ socket_ = UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
EXPECT_FALSE(socket_.get() == NULL);
}
-TEST_F(UnixDomainSocketTest, TestFallbackName) {
- scoped_ptr<UnixDomainSocket> existing_socket =
- UnixDomainSocket::CreateAndListenWithAbstractNamespace(
+TEST_F(UnixDomainListenSocketTest, TestFallbackName) {
+ scoped_ptr<UnixDomainListenSocket> existing_socket =
+ UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
EXPECT_FALSE(existing_socket.get() == NULL);
// First, try to bind socket with the same name with no fallback name.
socket_ =
- UnixDomainSocket::CreateAndListenWithAbstractNamespace(
+ UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
EXPECT_TRUE(socket_.get() == NULL);
// Now with a fallback name.
- const char kFallbackSocketName[] = "unix_domain_socket_for_testing_2";
- socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace(
+ const char kFallbackSocketName[] = "socket_for_testing_2";
+ socket_ = UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
file_path_.value(),
- MakeSocketPath(kFallbackSocketName),
+ GetTempSocketPath(kFallbackSocketName).value(),
socket_delegate_.get(),
MakeAuthCallback());
EXPECT_FALSE(socket_.get() == NULL);
}
#endif
-TEST_F(UnixDomainSocketTest, TestWithClient) {
+TEST_F(UnixDomainListenSocketTest, TestWithClient) {
const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
EventType event = event_manager_->WaitForEvent();
ASSERT_EQ(EVENT_LISTEN, event);
@@ -311,7 +313,7 @@ TEST_F(UnixDomainSocketTest, TestWithClient) {
ASSERT_EQ(EVENT_CLOSE, event);
}
-TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) {
+TEST_F(UnixDomainListenSocketTestWithForbiddenUser, TestWithForbiddenUser) {
const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
EventType event = event_manager_->WaitForEvent();
ASSERT_EQ(EVENT_LISTEN, event);
@@ -335,4 +337,5 @@ TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) {
}
} // namespace
+} // namespace deprecated
} // namespace net
diff --git a/chromium/net/socket/unix_domain_server_socket_posix.cc b/chromium/net/socket/unix_domain_server_socket_posix.cc
new file mode 100644
index 00000000000..4d6328310ff
--- /dev/null
+++ b/chromium/net/socket/unix_domain_server_socket_posix.cc
@@ -0,0 +1,186 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/unix_domain_server_socket_posix.h"
+
+#include <errno.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include "base/logging.h"
+#include "net/base/net_errors.h"
+#include "net/socket/socket_libevent.h"
+#include "net/socket/unix_domain_client_socket_posix.h"
+
+namespace net {
+
+namespace {
+
+// Intended for use as SetterCallbacks in Accept() helper methods.
+void SetStreamSocket(scoped_ptr<StreamSocket>* socket,
+ scoped_ptr<SocketLibevent> accepted_socket) {
+ socket->reset(new UnixDomainClientSocket(accepted_socket.Pass()));
+}
+
+void SetSocketDescriptor(SocketDescriptor* socket,
+ scoped_ptr<SocketLibevent> accepted_socket) {
+ *socket = accepted_socket->ReleaseConnectedSocket();
+}
+
+} // anonymous namespace
+
+UnixDomainServerSocket::UnixDomainServerSocket(
+ const AuthCallback& auth_callback,
+ bool use_abstract_namespace)
+ : auth_callback_(auth_callback),
+ use_abstract_namespace_(use_abstract_namespace) {
+ DCHECK(!auth_callback_.is_null());
+}
+
+UnixDomainServerSocket::~UnixDomainServerSocket() {
+}
+
+// static
+bool UnixDomainServerSocket::GetPeerCredentials(SocketDescriptor socket,
+ Credentials* credentials) {
+#if defined(OS_LINUX) || defined(OS_ANDROID)
+ struct ucred user_cred;
+ socklen_t len = sizeof(user_cred);
+ if (getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &user_cred, &len) < 0)
+ return false;
+ credentials->process_id = user_cred.pid;
+ credentials->user_id = user_cred.uid;
+ credentials->group_id = user_cred.gid;
+ return true;
+#else
+ return getpeereid(
+ socket, &credentials->user_id, &credentials->group_id) == 0;
+#endif
+}
+
+int UnixDomainServerSocket::Listen(const IPEndPoint& address, int backlog) {
+ NOTIMPLEMENTED();
+ return ERR_NOT_IMPLEMENTED;
+}
+
+int UnixDomainServerSocket::ListenWithAddressAndPort(
+ const std::string& unix_domain_path,
+ int port_unused,
+ int backlog) {
+ DCHECK(!listen_socket_);
+
+ SockaddrStorage address;
+ if (!UnixDomainClientSocket::FillAddress(unix_domain_path,
+ use_abstract_namespace_,
+ &address)) {
+ return ERR_ADDRESS_INVALID;
+ }
+
+ scoped_ptr<SocketLibevent> socket(new SocketLibevent);
+ int rv = socket->Open(AF_UNIX);
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ if (rv != OK)
+ return rv;
+
+ rv = socket->Bind(address);
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ if (rv != OK) {
+ PLOG(ERROR)
+ << "Could not bind unix domain socket to " << unix_domain_path
+ << (use_abstract_namespace_ ? " (with abstract namespace)" : "");
+ return rv;
+ }
+
+ rv = socket->Listen(backlog);
+ DCHECK_NE(ERR_IO_PENDING, rv);
+ if (rv != OK)
+ return rv;
+
+ listen_socket_.swap(socket);
+ return rv;
+}
+
+int UnixDomainServerSocket::GetLocalAddress(IPEndPoint* address) const {
+ NOTIMPLEMENTED();
+ return ERR_NOT_IMPLEMENTED;
+}
+
+int UnixDomainServerSocket::Accept(scoped_ptr<StreamSocket>* socket,
+ const CompletionCallback& callback) {
+ DCHECK(socket);
+
+ SetterCallback setter_callback = base::Bind(&SetStreamSocket, socket);
+ return DoAccept(setter_callback, callback);
+}
+
+int UnixDomainServerSocket::AcceptSocketDescriptor(
+ SocketDescriptor* socket,
+ const CompletionCallback& callback) {
+ DCHECK(socket);
+
+ SetterCallback setter_callback = base::Bind(&SetSocketDescriptor, socket);
+ return DoAccept(setter_callback, callback);
+}
+
+int UnixDomainServerSocket::DoAccept(const SetterCallback& setter_callback,
+ const CompletionCallback& callback) {
+ DCHECK(!setter_callback.is_null());
+ DCHECK(!callback.is_null());
+ DCHECK(listen_socket_);
+ DCHECK(!accept_socket_);
+
+ while (true) {
+ int rv = listen_socket_->Accept(
+ &accept_socket_,
+ base::Bind(&UnixDomainServerSocket::AcceptCompleted,
+ base::Unretained(this),
+ setter_callback,
+ callback));
+ if (rv != OK)
+ return rv;
+ if (AuthenticateAndGetStreamSocket(setter_callback))
+ return OK;
+ // Accept another socket because authentication error should be transparent
+ // to the caller.
+ }
+}
+
+void UnixDomainServerSocket::AcceptCompleted(
+ const SetterCallback& setter_callback,
+ const CompletionCallback& callback,
+ int rv) {
+ if (rv != OK) {
+ callback.Run(rv);
+ return;
+ }
+
+ if (AuthenticateAndGetStreamSocket(setter_callback)) {
+ callback.Run(OK);
+ return;
+ }
+
+ // Accept another socket because authentication error should be transparent
+ // to the caller.
+ rv = DoAccept(setter_callback, callback);
+ if (rv != ERR_IO_PENDING)
+ callback.Run(rv);
+}
+
+bool UnixDomainServerSocket::AuthenticateAndGetStreamSocket(
+ const SetterCallback& setter_callback) {
+ DCHECK(accept_socket_);
+
+ Credentials credentials;
+ if (!GetPeerCredentials(accept_socket_->socket_fd(), &credentials) ||
+ !auth_callback_.Run(credentials)) {
+ accept_socket_.reset();
+ return false;
+ }
+
+ setter_callback.Run(accept_socket_.Pass());
+ return true;
+}
+
+} // namespace net
diff --git a/chromium/net/socket/unix_domain_server_socket_posix.h b/chromium/net/socket/unix_domain_server_socket_posix.h
new file mode 100644
index 00000000000..0a26eb3d375
--- /dev/null
+++ b/chromium/net/socket/unix_domain_server_socket_posix.h
@@ -0,0 +1,91 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_UNIX_DOMAIN_SERVER_SOCKET_POSIX_H_
+#define NET_SOCKET_UNIX_DOMAIN_SERVER_SOCKET_POSIX_H_
+
+#include <sys/types.h>
+
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/callback.h"
+#include "base/macros.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/net_export.h"
+#include "net/socket/server_socket.h"
+#include "net/socket/socket_descriptor.h"
+
+namespace net {
+
+class SocketLibevent;
+
+// Unix Domain Server Socket Implementation. Supports abstract namespaces on
+// Linux and Android.
+class NET_EXPORT UnixDomainServerSocket : public ServerSocket {
+ public:
+ // Credentials of a peer process connected to the socket.
+ struct NET_EXPORT Credentials {
+#if defined(OS_LINUX) || defined(OS_ANDROID)
+ // Linux/Android API provides more information about the connected peer
+ // than Windows/OS X. It's useful for permission-based authorization on
+ // Android.
+ pid_t process_id;
+#endif
+ uid_t user_id;
+ gid_t group_id;
+ };
+
+ // Callback that returns whether the already connected client, identified by
+ // its credentials, is allowed to keep the connection open. Note that
+ // the socket is closed immediately in case the callback returns false.
+ typedef base::Callback<bool (const Credentials&)> AuthCallback;
+
+ UnixDomainServerSocket(const AuthCallback& auth_callack,
+ bool use_abstract_namespace);
+ ~UnixDomainServerSocket() override;
+
+ // Gets credentials of peer to check permissions.
+ static bool GetPeerCredentials(SocketDescriptor socket_fd,
+ Credentials* credentials);
+
+ // ServerSocket implementation.
+ int Listen(const IPEndPoint& address, int backlog) override;
+ int ListenWithAddressAndPort(const std::string& unix_domain_path,
+ int port_unused,
+ int backlog) override;
+ int GetLocalAddress(IPEndPoint* address) const override;
+ int Accept(scoped_ptr<StreamSocket>* socket,
+ const CompletionCallback& callback) override;
+
+ // Accepts an incoming connection on |listen_socket_|, but passes back
+ // a raw SocketDescriptor instead of a StreamSocket.
+ int AcceptSocketDescriptor(SocketDescriptor* socket_descriptor,
+ const CompletionCallback& callback);
+
+ private:
+ // A callback to wrap the setting of the out-parameter to Accept().
+ // This allows the internal machinery of that call to be implemented in
+ // a manner that's agnostic to the caller's desired output.
+ typedef base::Callback<void(scoped_ptr<SocketLibevent>)> SetterCallback;
+
+ int DoAccept(const SetterCallback& setter_callback,
+ const CompletionCallback& callback);
+ void AcceptCompleted(const SetterCallback& setter_callback,
+ const CompletionCallback& callback,
+ int rv);
+ bool AuthenticateAndGetStreamSocket(const SetterCallback& setter_callback);
+
+ scoped_ptr<SocketLibevent> listen_socket_;
+ const AuthCallback auth_callback_;
+ const bool use_abstract_namespace_;
+
+ scoped_ptr<SocketLibevent> accept_socket_;
+
+ DISALLOW_COPY_AND_ASSIGN(UnixDomainServerSocket);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_
diff --git a/chromium/net/socket/unix_domain_server_socket_posix_unittest.cc b/chromium/net/socket/unix_domain_server_socket_posix_unittest.cc
new file mode 100644
index 00000000000..bdf1efa29c4
--- /dev/null
+++ b/chromium/net/socket/unix_domain_server_socket_posix_unittest.cc
@@ -0,0 +1,125 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/unix_domain_server_socket_posix.h"
+
+#include <vector>
+
+#include "base/bind.h"
+#include "base/files/file_path.h"
+#include "base/files/scoped_temp_dir.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/run_loop.h"
+#include "base/stl_util.h"
+#include "net/base/io_buffer.h"
+#include "net/base/net_errors.h"
+#include "net/base/test_completion_callback.h"
+#include "net/socket/unix_domain_client_socket_posix.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+namespace {
+
+const char kSocketFilename[] = "socket_for_testing";
+const char kInvalidSocketPath[] = "/invalid/path";
+
+bool UserCanConnectCallback(bool allow_user,
+ const UnixDomainServerSocket::Credentials& credentials) {
+ // Here peers are running in same process.
+#if defined(OS_LINUX) || defined(OS_ANDROID)
+ EXPECT_EQ(getpid(), credentials.process_id);
+#endif
+ EXPECT_EQ(getuid(), credentials.user_id);
+ EXPECT_EQ(getgid(), credentials.group_id);
+ return allow_user;
+}
+
+UnixDomainServerSocket::AuthCallback CreateAuthCallback(bool allow_user) {
+ return base::Bind(&UserCanConnectCallback, allow_user);
+}
+
+class UnixDomainServerSocketTest : public testing::Test {
+ protected:
+ UnixDomainServerSocketTest() {
+ EXPECT_TRUE(temp_dir_.CreateUniqueTempDir());
+ socket_path_ = temp_dir_.path().Append(kSocketFilename).value();
+ }
+
+ base::ScopedTempDir temp_dir_;
+ std::string socket_path_;
+};
+
+TEST_F(UnixDomainServerSocketTest, ListenWithInvalidPath) {
+ const bool kUseAbstractNamespace = false;
+ UnixDomainServerSocket server_socket(CreateAuthCallback(true),
+ kUseAbstractNamespace);
+ EXPECT_EQ(ERR_FILE_NOT_FOUND,
+ server_socket.ListenWithAddressAndPort(kInvalidSocketPath, 0, 1));
+}
+
+TEST_F(UnixDomainServerSocketTest, ListenWithInvalidPathWithAbstractNamespace) {
+ const bool kUseAbstractNamespace = true;
+ UnixDomainServerSocket server_socket(CreateAuthCallback(true),
+ kUseAbstractNamespace);
+#if defined(OS_ANDROID) || defined(OS_LINUX)
+ EXPECT_EQ(OK,
+ server_socket.ListenWithAddressAndPort(kInvalidSocketPath, 0, 1));
+#else
+ EXPECT_EQ(ERR_ADDRESS_INVALID,
+ server_socket.ListenWithAddressAndPort(kInvalidSocketPath, 0, 1));
+#endif
+}
+
+TEST_F(UnixDomainServerSocketTest, ListenAgainAfterFailureWithInvalidPath) {
+ const bool kUseAbstractNamespace = false;
+ UnixDomainServerSocket server_socket(CreateAuthCallback(true),
+ kUseAbstractNamespace);
+ EXPECT_EQ(ERR_FILE_NOT_FOUND,
+ server_socket.ListenWithAddressAndPort(kInvalidSocketPath, 0, 1));
+ EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
+}
+
+TEST_F(UnixDomainServerSocketTest, AcceptWithForbiddenUser) {
+ const bool kUseAbstractNamespace = false;
+
+ UnixDomainServerSocket server_socket(CreateAuthCallback(false),
+ kUseAbstractNamespace);
+ EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
+
+ scoped_ptr<StreamSocket> accepted_socket;
+ TestCompletionCallback accept_callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ server_socket.Accept(&accepted_socket, accept_callback.callback()));
+ EXPECT_FALSE(accepted_socket);
+
+ UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
+ EXPECT_FALSE(client_socket.IsConnected());
+
+ // Connect() will return OK before the server rejects the connection.
+ TestCompletionCallback connect_callback;
+ int rv = connect_callback.GetResult(
+ client_socket.Connect(connect_callback.callback()));
+ ASSERT_EQ(OK, rv);
+
+ // Try to read from the socket.
+ const int read_buffer_size = 10;
+ scoped_refptr<IOBuffer> read_buffer(new IOBuffer(read_buffer_size));
+ TestCompletionCallback read_callback;
+ rv = read_callback.GetResult(client_socket.Read(
+ read_buffer.get(), read_buffer_size, read_callback.callback()));
+
+ // The server should have disconnected gracefully, without sending any data.
+ ASSERT_EQ(0, rv);
+ EXPECT_FALSE(client_socket.IsConnected());
+
+ // The server socket should not have called |accept_callback| or modified
+ // |accepted_socket|.
+ EXPECT_FALSE(accept_callback.have_result());
+ EXPECT_FALSE(accepted_socket);
+}
+
+// Normal cases including read/write are tested by UnixDomainClientSocketTest.
+
+} // namespace
+} // namespace net
diff --git a/chromium/net/socket/unix_domain_socket_posix.cc b/chromium/net/socket/unix_domain_socket_posix.cc
deleted file mode 100644
index 3141f7166b2..00000000000
--- a/chromium/net/socket/unix_domain_socket_posix.cc
+++ /dev/null
@@ -1,196 +0,0 @@
-// Copyright (c) 2012 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "net/socket/unix_domain_socket_posix.h"
-
-#include <cstring>
-#include <string>
-
-#include <errno.h>
-#include <sys/socket.h>
-#include <sys/stat.h>
-#include <sys/types.h>
-#include <sys/un.h>
-#include <unistd.h>
-
-#include "base/bind.h"
-#include "base/callback.h"
-#include "base/posix/eintr_wrapper.h"
-#include "base/threading/platform_thread.h"
-#include "build/build_config.h"
-#include "net/base/net_errors.h"
-#include "net/base/net_util.h"
-#include "net/socket/socket_descriptor.h"
-
-namespace net {
-
-namespace {
-
-bool NoAuthenticationCallback(uid_t, gid_t) {
- return true;
-}
-
-bool GetPeerIds(int socket, uid_t* user_id, gid_t* group_id) {
-#if defined(OS_LINUX) || defined(OS_ANDROID)
- struct ucred user_cred;
- socklen_t len = sizeof(user_cred);
- if (getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &user_cred, &len) == -1)
- return false;
- *user_id = user_cred.uid;
- *group_id = user_cred.gid;
-#else
- if (getpeereid(socket, user_id, group_id) == -1)
- return false;
-#endif
- return true;
-}
-
-} // namespace
-
-// static
-UnixDomainSocket::AuthCallback UnixDomainSocket::NoAuthentication() {
- return base::Bind(NoAuthenticationCallback);
-}
-
-// static
-scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListenInternal(
- const std::string& path,
- const std::string& fallback_path,
- StreamListenSocket::Delegate* del,
- const AuthCallback& auth_callback,
- bool use_abstract_namespace) {
- SocketDescriptor s = CreateAndBind(path, use_abstract_namespace);
- if (s == kInvalidSocket && !fallback_path.empty())
- s = CreateAndBind(fallback_path, use_abstract_namespace);
- if (s == kInvalidSocket)
- return scoped_ptr<UnixDomainSocket>();
- scoped_ptr<UnixDomainSocket> sock(
- new UnixDomainSocket(s, del, auth_callback));
- sock->Listen();
- return sock.Pass();
-}
-
-// static
-scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen(
- const std::string& path,
- StreamListenSocket::Delegate* del,
- const AuthCallback& auth_callback) {
- return CreateAndListenInternal(path, "", del, auth_callback, false);
-}
-
-#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
-// static
-scoped_ptr<UnixDomainSocket>
-UnixDomainSocket::CreateAndListenWithAbstractNamespace(
- const std::string& path,
- const std::string& fallback_path,
- StreamListenSocket::Delegate* del,
- const AuthCallback& auth_callback) {
- return
- CreateAndListenInternal(path, fallback_path, del, auth_callback, true);
-}
-#endif
-
-UnixDomainSocket::UnixDomainSocket(
- SocketDescriptor s,
- StreamListenSocket::Delegate* del,
- const AuthCallback& auth_callback)
- : StreamListenSocket(s, del),
- auth_callback_(auth_callback) {}
-
-UnixDomainSocket::~UnixDomainSocket() {}
-
-// static
-SocketDescriptor UnixDomainSocket::CreateAndBind(const std::string& path,
- bool use_abstract_namespace) {
- sockaddr_un addr;
- static const size_t kPathMax = sizeof(addr.sun_path);
- if (use_abstract_namespace + path.size() + 1 /* '\0' */ > kPathMax)
- return kInvalidSocket;
- const SocketDescriptor s = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0);
- if (s == kInvalidSocket)
- return kInvalidSocket;
- memset(&addr, 0, sizeof(addr));
- addr.sun_family = AF_UNIX;
- socklen_t addr_len;
- if (use_abstract_namespace) {
- // Convert the path given into abstract socket name. It must start with
- // the '\0' character, so we are adding it. |addr_len| must specify the
- // length of the structure exactly, as potentially the socket name may
- // have '\0' characters embedded (although we don't support this).
- // Note that addr.sun_path is already zero initialized.
- memcpy(addr.sun_path + 1, path.c_str(), path.size());
- addr_len = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
- } else {
- memcpy(addr.sun_path, path.c_str(), path.size());
- addr_len = sizeof(sockaddr_un);
- }
- if (bind(s, reinterpret_cast<sockaddr*>(&addr), addr_len)) {
- LOG(ERROR) << "Could not bind unix domain socket to " << path;
- if (use_abstract_namespace)
- LOG(ERROR) << " (with abstract namespace enabled)";
- if (IGNORE_EINTR(close(s)) < 0)
- LOG(ERROR) << "close() error";
- return kInvalidSocket;
- }
- return s;
-}
-
-void UnixDomainSocket::Accept() {
- SocketDescriptor conn = StreamListenSocket::AcceptSocket();
- if (conn == kInvalidSocket)
- return;
- uid_t user_id;
- gid_t group_id;
- if (!GetPeerIds(conn, &user_id, &group_id) ||
- !auth_callback_.Run(user_id, group_id)) {
- if (IGNORE_EINTR(close(conn)) < 0)
- LOG(ERROR) << "close() error";
- return;
- }
- scoped_ptr<UnixDomainSocket> sock(
- new UnixDomainSocket(conn, socket_delegate_, auth_callback_));
- // It's up to the delegate to AddRef if it wants to keep it around.
- sock->WatchSocket(WAITING_READ);
- socket_delegate_->DidAccept(this, sock.PassAs<StreamListenSocket>());
-}
-
-UnixDomainSocketFactory::UnixDomainSocketFactory(
- const std::string& path,
- const UnixDomainSocket::AuthCallback& auth_callback)
- : path_(path),
- auth_callback_(auth_callback) {}
-
-UnixDomainSocketFactory::~UnixDomainSocketFactory() {}
-
-scoped_ptr<StreamListenSocket> UnixDomainSocketFactory::CreateAndListen(
- StreamListenSocket::Delegate* delegate) const {
- return UnixDomainSocket::CreateAndListen(
- path_, delegate, auth_callback_).PassAs<StreamListenSocket>();
-}
-
-#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
-
-UnixDomainSocketWithAbstractNamespaceFactory::
-UnixDomainSocketWithAbstractNamespaceFactory(
- const std::string& path,
- const std::string& fallback_path,
- const UnixDomainSocket::AuthCallback& auth_callback)
- : UnixDomainSocketFactory(path, auth_callback),
- fallback_path_(fallback_path) {}
-
-UnixDomainSocketWithAbstractNamespaceFactory::
-~UnixDomainSocketWithAbstractNamespaceFactory() {}
-
-scoped_ptr<StreamListenSocket>
-UnixDomainSocketWithAbstractNamespaceFactory::CreateAndListen(
- StreamListenSocket::Delegate* delegate) const {
- return UnixDomainSocket::CreateAndListenWithAbstractNamespace(
- path_, fallback_path_, delegate, auth_callback_)
- .PassAs<StreamListenSocket>();
-}
-
-#endif
-
-} // namespace net
diff --git a/chromium/net/socket/unix_domain_socket_posix.h b/chromium/net/socket/unix_domain_socket_posix.h
deleted file mode 100644
index 98d0c11a648..00000000000
--- a/chromium/net/socket/unix_domain_socket_posix.h
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright (c) 2012 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_
-#define NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_
-
-#include <string>
-
-#include "base/basictypes.h"
-#include "base/callback_forward.h"
-#include "base/compiler_specific.h"
-#include "build/build_config.h"
-#include "net/base/net_export.h"
-#include "net/socket/stream_listen_socket.h"
-
-#if defined(OS_ANDROID) || defined(OS_LINUX)
-// Feature only supported on Linux currently. This lets the Unix Domain Socket
-// not be backed by the file system.
-#define SOCKET_ABSTRACT_NAMESPACE_SUPPORTED
-#endif
-
-namespace net {
-
-// Unix Domain Socket Implementation. Supports abstract namespaces on Linux.
-class NET_EXPORT UnixDomainSocket : public StreamListenSocket {
- public:
- virtual ~UnixDomainSocket();
-
- // Callback that returns whether the already connected client, identified by
- // its process |user_id| and |group_id|, is allowed to keep the connection
- // open. Note that the socket is closed immediately in case the callback
- // returns false.
- typedef base::Callback<bool (uid_t user_id, gid_t group_id)> AuthCallback;
-
- // Returns an authentication callback that always grants access for
- // convenience in case you don't want to use authentication.
- static AuthCallback NoAuthentication();
-
- // Note that the returned UnixDomainSocket instance does not take ownership of
- // |del|.
- static scoped_ptr<UnixDomainSocket> CreateAndListen(
- const std::string& path,
- StreamListenSocket::Delegate* del,
- const AuthCallback& auth_callback);
-
-#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
- // Same as above except that the created socket uses the abstract namespace
- // which is a Linux-only feature. If |fallback_path| is not empty,
- // make the second attempt with the provided fallback name.
- static scoped_ptr<UnixDomainSocket> CreateAndListenWithAbstractNamespace(
- const std::string& path,
- const std::string& fallback_path,
- StreamListenSocket::Delegate* del,
- const AuthCallback& auth_callback);
-#endif
-
- private:
- UnixDomainSocket(SocketDescriptor s,
- StreamListenSocket::Delegate* del,
- const AuthCallback& auth_callback);
-
- static scoped_ptr<UnixDomainSocket> CreateAndListenInternal(
- const std::string& path,
- const std::string& fallback_path,
- StreamListenSocket::Delegate* del,
- const AuthCallback& auth_callback,
- bool use_abstract_namespace);
-
- static SocketDescriptor CreateAndBind(const std::string& path,
- bool use_abstract_namespace);
-
- // StreamListenSocket:
- virtual void Accept() OVERRIDE;
-
- AuthCallback auth_callback_;
-
- DISALLOW_COPY_AND_ASSIGN(UnixDomainSocket);
-};
-
-// Factory that can be used to instantiate UnixDomainSocket.
-class NET_EXPORT UnixDomainSocketFactory : public StreamListenSocketFactory {
- public:
- // Note that this class does not take ownership of the provided delegate.
- UnixDomainSocketFactory(const std::string& path,
- const UnixDomainSocket::AuthCallback& auth_callback);
- virtual ~UnixDomainSocketFactory();
-
- // StreamListenSocketFactory:
- virtual scoped_ptr<StreamListenSocket> CreateAndListen(
- StreamListenSocket::Delegate* delegate) const OVERRIDE;
-
- protected:
- const std::string path_;
- const UnixDomainSocket::AuthCallback auth_callback_;
-
- private:
- DISALLOW_COPY_AND_ASSIGN(UnixDomainSocketFactory);
-};
-
-#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
-// Use this factory to instantiate UnixDomainSocket using the abstract
-// namespace feature (only supported on Linux).
-class NET_EXPORT UnixDomainSocketWithAbstractNamespaceFactory
- : public UnixDomainSocketFactory {
- public:
- UnixDomainSocketWithAbstractNamespaceFactory(
- const std::string& path,
- const std::string& fallback_path,
- const UnixDomainSocket::AuthCallback& auth_callback);
- virtual ~UnixDomainSocketWithAbstractNamespaceFactory();
-
- // UnixDomainSocketFactory:
- virtual scoped_ptr<StreamListenSocket> CreateAndListen(
- StreamListenSocket::Delegate* delegate) const OVERRIDE;
-
- private:
- std::string fallback_path_;
-
- DISALLOW_COPY_AND_ASSIGN(UnixDomainSocketWithAbstractNamespaceFactory);
-};
-#endif
-
-} // namespace net
-
-#endif // NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_
diff --git a/chromium/net/socket/websocket_endpoint_lock_manager.cc b/chromium/net/socket/websocket_endpoint_lock_manager.cc
new file mode 100644
index 00000000000..e578bb2435b
--- /dev/null
+++ b/chromium/net/socket/websocket_endpoint_lock_manager.cc
@@ -0,0 +1,132 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/websocket_endpoint_lock_manager.h"
+
+#include <utility>
+
+#include "base/logging.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+
+namespace net {
+
+WebSocketEndpointLockManager::Waiter::~Waiter() {
+ if (next()) {
+ DCHECK(previous());
+ RemoveFromList();
+ }
+}
+
+WebSocketEndpointLockManager* WebSocketEndpointLockManager::GetInstance() {
+ return Singleton<WebSocketEndpointLockManager>::get();
+}
+
+int WebSocketEndpointLockManager::LockEndpoint(const IPEndPoint& endpoint,
+ Waiter* waiter) {
+ LockInfoMap::value_type insert_value(endpoint, LockInfo());
+ std::pair<LockInfoMap::iterator, bool> rv =
+ lock_info_map_.insert(insert_value);
+ LockInfo& lock_info_in_map = rv.first->second;
+ if (rv.second) {
+ DVLOG(3) << "Locking endpoint " << endpoint.ToString();
+ lock_info_in_map.queue.reset(new LockInfo::WaiterQueue);
+ return OK;
+ }
+ DVLOG(3) << "Waiting for endpoint " << endpoint.ToString();
+ lock_info_in_map.queue->Append(waiter);
+ return ERR_IO_PENDING;
+}
+
+void WebSocketEndpointLockManager::RememberSocket(StreamSocket* socket,
+ const IPEndPoint& endpoint) {
+ LockInfoMap::iterator lock_info_it = lock_info_map_.find(endpoint);
+ CHECK(lock_info_it != lock_info_map_.end());
+ bool inserted =
+ socket_lock_info_map_.insert(SocketLockInfoMap::value_type(
+ socket, lock_info_it)).second;
+ DCHECK(inserted);
+ DCHECK(!lock_info_it->second.socket);
+ lock_info_it->second.socket = socket;
+ DVLOG(3) << "Remembered (StreamSocket*)" << socket << " for "
+ << endpoint.ToString() << " (" << socket_lock_info_map_.size()
+ << " socket(s) remembered)";
+}
+
+void WebSocketEndpointLockManager::UnlockSocket(StreamSocket* socket) {
+ SocketLockInfoMap::iterator socket_it = socket_lock_info_map_.find(socket);
+ if (socket_it == socket_lock_info_map_.end())
+ return;
+
+ LockInfoMap::iterator lock_info_it = socket_it->second;
+
+ DVLOG(3) << "Unlocking (StreamSocket*)" << socket << " for "
+ << lock_info_it->first.ToString() << " ("
+ << socket_lock_info_map_.size() << " socket(s) left)";
+ socket_lock_info_map_.erase(socket_it);
+ DCHECK(socket == lock_info_it->second.socket);
+ lock_info_it->second.socket = NULL;
+ UnlockEndpointByIterator(lock_info_it);
+}
+
+void WebSocketEndpointLockManager::UnlockEndpoint(const IPEndPoint& endpoint) {
+ LockInfoMap::iterator lock_info_it = lock_info_map_.find(endpoint);
+ if (lock_info_it == lock_info_map_.end())
+ return;
+
+ UnlockEndpointByIterator(lock_info_it);
+}
+
+bool WebSocketEndpointLockManager::IsEmpty() const {
+ return lock_info_map_.empty() && socket_lock_info_map_.empty();
+}
+
+WebSocketEndpointLockManager::LockInfo::LockInfo() : socket(NULL) {}
+WebSocketEndpointLockManager::LockInfo::~LockInfo() {
+ DCHECK(!socket);
+}
+
+WebSocketEndpointLockManager::LockInfo::LockInfo(const LockInfo& rhs)
+ : socket(rhs.socket) {
+ DCHECK(!rhs.queue);
+}
+
+WebSocketEndpointLockManager::WebSocketEndpointLockManager() {}
+
+WebSocketEndpointLockManager::~WebSocketEndpointLockManager() {
+ DCHECK(lock_info_map_.empty());
+ DCHECK(socket_lock_info_map_.empty());
+}
+
+void WebSocketEndpointLockManager::UnlockEndpointByIterator(
+ LockInfoMap::iterator lock_info_it) {
+ if (lock_info_it->second.socket)
+ EraseSocket(lock_info_it);
+ LockInfo::WaiterQueue* queue = lock_info_it->second.queue.get();
+ DCHECK(queue);
+ if (queue->empty()) {
+ DVLOG(3) << "Unlocking endpoint " << lock_info_it->first.ToString();
+ lock_info_map_.erase(lock_info_it);
+ return;
+ }
+
+ DVLOG(3) << "Unlocking endpoint " << lock_info_it->first.ToString()
+ << " and activating next waiter";
+ Waiter* next_job = queue->head()->value();
+ next_job->RemoveFromList();
+ // This must be last to minimise the excitement caused by re-entrancy.
+ next_job->GotEndpointLock();
+}
+
+void WebSocketEndpointLockManager::EraseSocket(
+ LockInfoMap::iterator lock_info_it) {
+ DVLOG(3) << "Removing (StreamSocket*)" << lock_info_it->second.socket
+ << " for " << lock_info_it->first.ToString() << " ("
+ << socket_lock_info_map_.size() << " socket(s) left)";
+ size_t erased = socket_lock_info_map_.erase(lock_info_it->second.socket);
+ DCHECK_EQ(1U, erased);
+ lock_info_it->second.socket = NULL;
+}
+
+} // namespace net
diff --git a/chromium/net/socket/websocket_endpoint_lock_manager.h b/chromium/net/socket/websocket_endpoint_lock_manager.h
new file mode 100644
index 00000000000..d5cad508d6a
--- /dev/null
+++ b/chromium/net/socket/websocket_endpoint_lock_manager.h
@@ -0,0 +1,121 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_WEBSOCKET_ENDPOINT_LOCK_MANAGER_H_
+#define NET_SOCKET_WEBSOCKET_ENDPOINT_LOCK_MANAGER_H_
+
+#include <map>
+
+#include "base/containers/linked_list.h"
+#include "base/logging.h"
+#include "base/macros.h"
+#include "base/memory/singleton.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_export.h"
+#include "net/socket/websocket_transport_client_socket_pool.h"
+
+namespace net {
+
+class StreamSocket;
+
+class NET_EXPORT_PRIVATE WebSocketEndpointLockManager {
+ public:
+ class NET_EXPORT_PRIVATE Waiter : public base::LinkNode<Waiter> {
+ public:
+ // If the node is in a list, removes it.
+ virtual ~Waiter();
+
+ virtual void GotEndpointLock() = 0;
+ };
+
+ static WebSocketEndpointLockManager* GetInstance();
+
+ // Returns OK if lock was acquired immediately, ERR_IO_PENDING if not. If the
+ // lock was not acquired, then |waiter->GotEndpointLock()| will be called when
+ // it is. A Waiter automatically removes itself from the list of waiters when
+ // its destructor is called.
+ int LockEndpoint(const IPEndPoint& endpoint, Waiter* waiter);
+
+ // Records the IPEndPoint associated with a particular socket. This is
+ // necessary because TCPClientSocket refuses to return the PeerAddress after
+ // the connection is disconnected. The association will be forgotten when
+ // UnlockSocket() or UnlockEndpoint() is called. The |socket| pointer must not
+ // be deleted between the call to RememberSocket() and the call to
+ // UnlockSocket().
+ void RememberSocket(StreamSocket* socket, const IPEndPoint& endpoint);
+
+ // Releases the lock on the endpoint that was associated with |socket| by
+ // RememberSocket(). If appropriate, triggers the next socket connection.
+ // Should be called exactly once for each |socket| that was passed to
+ // RememberSocket(). Does nothing if UnlockEndpoint() has been called since
+ // the call to RememberSocket().
+ void UnlockSocket(StreamSocket* socket);
+
+ // Releases the lock on |endpoint|. Does nothing if |endpoint| is not locked.
+ // Removes any socket association that was recorded with RememberSocket(). If
+ // appropriate, calls |waiter->GotEndpointLock()|.
+ void UnlockEndpoint(const IPEndPoint& endpoint);
+
+ // Checks that |lock_info_map_| and |socket_lock_info_map_| are empty. For
+ // tests.
+ bool IsEmpty() const;
+
+ private:
+ struct LockInfo {
+ typedef base::LinkedList<Waiter> WaiterQueue;
+
+ LockInfo();
+ ~LockInfo();
+
+ // This object must be copyable to be placed in the map, but it cannot be
+ // copied after |queue| has been assigned to.
+ LockInfo(const LockInfo& rhs);
+
+ // Not used.
+ LockInfo& operator=(const LockInfo& rhs);
+
+ // Must be NULL to copy this object into the map. Must be set to non-NULL
+ // after the object is inserted into the map then point to the same list
+ // until this object is deleted.
+ scoped_ptr<WaiterQueue> queue;
+
+ // This pointer is only used to identify the last instance of StreamSocket
+ // that was passed to RememberSocket() for this endpoint. It should only be
+ // compared with other pointers. It is never dereferenced and not owned. It
+ // is non-NULL if RememberSocket() has been called for this endpoint since
+ // the last call to UnlockSocket() or UnlockEndpoint().
+ StreamSocket* socket;
+ };
+
+ // SocketLockInfoMap requires std::map iterator semantics for LockInfoMap
+ // (ie. that the iterator will remain valid as long as the entry is not
+ // deleted).
+ typedef std::map<IPEndPoint, LockInfo> LockInfoMap;
+ typedef std::map<StreamSocket*, LockInfoMap::iterator> SocketLockInfoMap;
+
+ WebSocketEndpointLockManager();
+ ~WebSocketEndpointLockManager();
+
+ void UnlockEndpointByIterator(LockInfoMap::iterator lock_info_it);
+ void EraseSocket(LockInfoMap::iterator lock_info_it);
+
+ // If an entry is present in the map for a particular endpoint, then that
+ // endpoint is locked. If LockInfo.queue is non-empty, then one or more
+ // Waiters are waiting for the lock.
+ LockInfoMap lock_info_map_;
+
+ // Store sockets remembered by RememberSocket() and not yet unlocked by
+ // UnlockSocket() or UnlockEndpoint(). Every entry in this map always
+ // references a live entry in lock_info_map_, and the LockInfo::socket member
+ // is non-NULL if and only if there is an entry in this map for the socket.
+ SocketLockInfoMap socket_lock_info_map_;
+
+ friend struct DefaultSingletonTraits<WebSocketEndpointLockManager>;
+
+ DISALLOW_COPY_AND_ASSIGN(WebSocketEndpointLockManager);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_WEBSOCKET_ENDPOINT_LOCK_MANAGER_H_
diff --git a/chromium/net/socket/websocket_endpoint_lock_manager_unittest.cc b/chromium/net/socket/websocket_endpoint_lock_manager_unittest.cc
new file mode 100644
index 00000000000..1626aa90201
--- /dev/null
+++ b/chromium/net/socket/websocket_endpoint_lock_manager_unittest.cc
@@ -0,0 +1,224 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/websocket_endpoint_lock_manager.h"
+
+#include "net/base/net_errors.h"
+#include "net/socket/next_proto.h"
+#include "net/socket/socket_test_util.h"
+#include "net/socket/stream_socket.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+// A StreamSocket implementation with no functionality at all.
+// TODO(ricea): If you need to use this in another file, please move it to
+// socket_test_util.h.
+class FakeStreamSocket : public StreamSocket {
+ public:
+ FakeStreamSocket() {}
+
+ // StreamSocket implementation
+ int Connect(const CompletionCallback& callback) override {
+ return ERR_FAILED;
+ }
+
+ void Disconnect() override { return; }
+
+ bool IsConnected() const override { return false; }
+
+ bool IsConnectedAndIdle() const override { return false; }
+
+ int GetPeerAddress(IPEndPoint* address) const override { return ERR_FAILED; }
+
+ int GetLocalAddress(IPEndPoint* address) const override { return ERR_FAILED; }
+
+ const BoundNetLog& NetLog() const override { return bound_net_log_; }
+
+ void SetSubresourceSpeculation() override { return; }
+ void SetOmniboxSpeculation() override { return; }
+
+ bool WasEverUsed() const override { return false; }
+
+ bool UsingTCPFastOpen() const override { return false; }
+
+ bool WasNpnNegotiated() const override { return false; }
+
+ NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
+
+ bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
+
+ // Socket implementation
+ int Read(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override {
+ return ERR_FAILED;
+ }
+
+ int Write(IOBuffer* buf,
+ int buf_len,
+ const CompletionCallback& callback) override {
+ return ERR_FAILED;
+ }
+
+ int SetReceiveBufferSize(int32 size) override { return ERR_FAILED; }
+
+ int SetSendBufferSize(int32 size) override { return ERR_FAILED; }
+
+ private:
+ BoundNetLog bound_net_log_;
+
+ DISALLOW_COPY_AND_ASSIGN(FakeStreamSocket);
+};
+
+class FakeWaiter : public WebSocketEndpointLockManager::Waiter {
+ public:
+ FakeWaiter() : called_(false) {}
+
+ void GotEndpointLock() override {
+ CHECK(!called_);
+ called_ = true;
+ }
+
+ bool called() const { return called_; }
+
+ private:
+ bool called_;
+};
+
+class WebSocketEndpointLockManagerTest : public ::testing::Test {
+ protected:
+ WebSocketEndpointLockManagerTest()
+ : instance_(WebSocketEndpointLockManager::GetInstance()) {}
+ ~WebSocketEndpointLockManagerTest() override {
+ // If this check fails then subsequent tests may fail.
+ CHECK(instance_->IsEmpty());
+ }
+
+ WebSocketEndpointLockManager* instance() const { return instance_; }
+
+ IPEndPoint DummyEndpoint() {
+ IPAddressNumber ip_address_number;
+ CHECK(ParseIPLiteralToNumber("127.0.0.1", &ip_address_number));
+ return IPEndPoint(ip_address_number, 80);
+ }
+
+ void UnlockDummyEndpoint(int times) {
+ for (int i = 0; i < times; ++i) {
+ instance()->UnlockEndpoint(DummyEndpoint());
+ }
+ }
+
+ WebSocketEndpointLockManager* const instance_;
+};
+
+TEST_F(WebSocketEndpointLockManagerTest, GetInstanceWorks) {
+ // All the work is done by the test framework.
+}
+
+TEST_F(WebSocketEndpointLockManagerTest, LockEndpointReturnsOkOnce) {
+ FakeWaiter waiters[2];
+ EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0]));
+ EXPECT_EQ(ERR_IO_PENDING,
+ instance()->LockEndpoint(DummyEndpoint(), &waiters[1]));
+
+ UnlockDummyEndpoint(2);
+}
+
+TEST_F(WebSocketEndpointLockManagerTest, GotEndpointLockNotCalledOnOk) {
+ FakeWaiter waiter;
+ EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiter));
+ EXPECT_FALSE(waiter.called());
+
+ UnlockDummyEndpoint(1);
+}
+
+TEST_F(WebSocketEndpointLockManagerTest, GotEndpointLockNotCalledImmediately) {
+ FakeWaiter waiters[2];
+ EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0]));
+ EXPECT_EQ(ERR_IO_PENDING,
+ instance()->LockEndpoint(DummyEndpoint(), &waiters[1]));
+ EXPECT_FALSE(waiters[1].called());
+
+ UnlockDummyEndpoint(2);
+}
+
+TEST_F(WebSocketEndpointLockManagerTest, GotEndpointLockCalledWhenUnlocked) {
+ FakeWaiter waiters[2];
+ EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0]));
+ EXPECT_EQ(ERR_IO_PENDING,
+ instance()->LockEndpoint(DummyEndpoint(), &waiters[1]));
+ instance()->UnlockEndpoint(DummyEndpoint());
+ EXPECT_TRUE(waiters[1].called());
+
+ UnlockDummyEndpoint(1);
+}
+
+TEST_F(WebSocketEndpointLockManagerTest,
+ EndpointUnlockedIfWaiterAlreadyDeleted) {
+ FakeWaiter first_lock_holder;
+ EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &first_lock_holder));
+
+ {
+ FakeWaiter short_lived_waiter;
+ EXPECT_EQ(ERR_IO_PENDING,
+ instance()->LockEndpoint(DummyEndpoint(), &short_lived_waiter));
+ }
+
+ instance()->UnlockEndpoint(DummyEndpoint());
+
+ FakeWaiter second_lock_holder;
+ EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &second_lock_holder));
+
+ UnlockDummyEndpoint(1);
+}
+
+TEST_F(WebSocketEndpointLockManagerTest, RememberSocketWorks) {
+ FakeWaiter waiters[2];
+ FakeStreamSocket dummy_socket;
+ EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0]));
+ EXPECT_EQ(ERR_IO_PENDING,
+ instance()->LockEndpoint(DummyEndpoint(), &waiters[1]));
+
+ instance()->RememberSocket(&dummy_socket, DummyEndpoint());
+ instance()->UnlockSocket(&dummy_socket);
+ EXPECT_TRUE(waiters[1].called());
+
+ UnlockDummyEndpoint(1);
+}
+
+// UnlockEndpoint() should cause any sockets remembered for this endpoint
+// to be forgotten.
+TEST_F(WebSocketEndpointLockManagerTest, SocketAssociationForgottenOnUnlock) {
+ FakeWaiter waiter;
+ FakeStreamSocket dummy_socket;
+
+ EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiter));
+ instance()->RememberSocket(&dummy_socket, DummyEndpoint());
+ instance()->UnlockEndpoint(DummyEndpoint());
+ EXPECT_TRUE(instance()->IsEmpty());
+}
+
+// When ownership of the endpoint is passed to a new waiter, the new waiter can
+// call RememberSocket() again.
+TEST_F(WebSocketEndpointLockManagerTest, NextWaiterCanCallRememberSocketAgain) {
+ FakeWaiter waiters[2];
+ FakeStreamSocket dummy_sockets[2];
+ EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0]));
+ EXPECT_EQ(ERR_IO_PENDING,
+ instance()->LockEndpoint(DummyEndpoint(), &waiters[1]));
+
+ instance()->RememberSocket(&dummy_sockets[0], DummyEndpoint());
+ instance()->UnlockEndpoint(DummyEndpoint());
+ EXPECT_TRUE(waiters[1].called());
+ instance()->RememberSocket(&dummy_sockets[1], DummyEndpoint());
+
+ UnlockDummyEndpoint(1);
+}
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/socket/websocket_transport_client_socket_pool.cc b/chromium/net/socket/websocket_transport_client_socket_pool.cc
new file mode 100644
index 00000000000..15ec028cb18
--- /dev/null
+++ b/chromium/net/socket/websocket_transport_client_socket_pool.cc
@@ -0,0 +1,651 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/websocket_transport_client_socket_pool.h"
+
+#include <algorithm>
+
+#include "base/compiler_specific.h"
+#include "base/logging.h"
+#include "base/numerics/safe_conversions.h"
+#include "base/strings/string_util.h"
+#include "base/time/time.h"
+#include "base/values.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/client_socket_pool_base.h"
+#include "net/socket/websocket_endpoint_lock_manager.h"
+#include "net/socket/websocket_transport_connect_sub_job.h"
+
+namespace net {
+
+namespace {
+
+using base::TimeDelta;
+
+// TODO(ricea): For now, we implement a global timeout for compatability with
+// TransportConnectJob. Since WebSocketTransportConnectJob controls the address
+// selection process more tightly, it could do something smarter here.
+const int kTransportConnectJobTimeoutInSeconds = 240; // 4 minutes.
+
+} // namespace
+
+WebSocketTransportConnectJob::WebSocketTransportConnectJob(
+ const std::string& group_name,
+ RequestPriority priority,
+ const scoped_refptr<TransportSocketParams>& params,
+ TimeDelta timeout_duration,
+ const CompletionCallback& callback,
+ ClientSocketFactory* client_socket_factory,
+ HostResolver* host_resolver,
+ ClientSocketHandle* handle,
+ Delegate* delegate,
+ NetLog* pool_net_log,
+ const BoundNetLog& request_net_log)
+ : ConnectJob(group_name,
+ timeout_duration,
+ priority,
+ delegate,
+ BoundNetLog::Make(pool_net_log, NetLog::SOURCE_CONNECT_JOB)),
+ helper_(params, client_socket_factory, host_resolver, &connect_timing_),
+ race_result_(TransportConnectJobHelper::CONNECTION_LATENCY_UNKNOWN),
+ handle_(handle),
+ callback_(callback),
+ request_net_log_(request_net_log),
+ had_ipv4_(false),
+ had_ipv6_(false) {
+ helper_.SetOnIOComplete(this);
+}
+
+WebSocketTransportConnectJob::~WebSocketTransportConnectJob() {}
+
+LoadState WebSocketTransportConnectJob::GetLoadState() const {
+ LoadState load_state = LOAD_STATE_RESOLVING_HOST;
+ if (ipv6_job_)
+ load_state = ipv6_job_->GetLoadState();
+ // This method should return LOAD_STATE_CONNECTING in preference to
+ // LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET when possible because "waiting for
+ // available socket" implies that nothing is happening.
+ if (ipv4_job_ && load_state != LOAD_STATE_CONNECTING)
+ load_state = ipv4_job_->GetLoadState();
+ return load_state;
+}
+
+int WebSocketTransportConnectJob::DoResolveHost() {
+ return helper_.DoResolveHost(priority(), net_log());
+}
+
+int WebSocketTransportConnectJob::DoResolveHostComplete(int result) {
+ return helper_.DoResolveHostComplete(result, net_log());
+}
+
+int WebSocketTransportConnectJob::DoTransportConnect() {
+ AddressList ipv4_addresses;
+ AddressList ipv6_addresses;
+ int result = ERR_UNEXPECTED;
+ helper_.set_next_state(
+ TransportConnectJobHelper::STATE_TRANSPORT_CONNECT_COMPLETE);
+
+ for (AddressList::const_iterator it = helper_.addresses().begin();
+ it != helper_.addresses().end();
+ ++it) {
+ switch (it->GetFamily()) {
+ case ADDRESS_FAMILY_IPV4:
+ ipv4_addresses.push_back(*it);
+ break;
+
+ case ADDRESS_FAMILY_IPV6:
+ ipv6_addresses.push_back(*it);
+ break;
+
+ default:
+ DVLOG(1) << "Unexpected ADDRESS_FAMILY: " << it->GetFamily();
+ break;
+ }
+ }
+
+ if (!ipv4_addresses.empty()) {
+ had_ipv4_ = true;
+ ipv4_job_.reset(new WebSocketTransportConnectSubJob(
+ ipv4_addresses, this, SUB_JOB_IPV4));
+ }
+
+ if (!ipv6_addresses.empty()) {
+ had_ipv6_ = true;
+ ipv6_job_.reset(new WebSocketTransportConnectSubJob(
+ ipv6_addresses, this, SUB_JOB_IPV6));
+ result = ipv6_job_->Start();
+ switch (result) {
+ case OK:
+ SetSocket(ipv6_job_->PassSocket());
+ race_result_ =
+ had_ipv4_
+ ? TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_RACEABLE
+ : TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_SOLO;
+ return result;
+
+ case ERR_IO_PENDING:
+ if (ipv4_job_) {
+ // This use of base::Unretained is safe because |fallback_timer_| is
+ // owned by this object.
+ fallback_timer_.Start(
+ FROM_HERE,
+ TimeDelta::FromMilliseconds(
+ TransportConnectJobHelper::kIPv6FallbackTimerInMs),
+ base::Bind(&WebSocketTransportConnectJob::StartIPv4JobAsync,
+ base::Unretained(this)));
+ }
+ return result;
+
+ default:
+ ipv6_job_.reset();
+ }
+ }
+
+ DCHECK(!ipv6_job_);
+ if (ipv4_job_) {
+ result = ipv4_job_->Start();
+ if (result == OK) {
+ SetSocket(ipv4_job_->PassSocket());
+ race_result_ =
+ had_ipv6_
+ ? TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_WINS_RACE
+ : TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_NO_RACE;
+ }
+ }
+
+ return result;
+}
+
+int WebSocketTransportConnectJob::DoTransportConnectComplete(int result) {
+ if (result == OK)
+ helper_.HistogramDuration(race_result_);
+ return result;
+}
+
+void WebSocketTransportConnectJob::OnSubJobComplete(
+ int result,
+ WebSocketTransportConnectSubJob* job) {
+ if (result == OK) {
+ switch (job->type()) {
+ case SUB_JOB_IPV4:
+ race_result_ =
+ had_ipv6_
+ ? TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_WINS_RACE
+ : TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_NO_RACE;
+ break;
+
+ case SUB_JOB_IPV6:
+ race_result_ =
+ had_ipv4_
+ ? TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_RACEABLE
+ : TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_SOLO;
+ break;
+ }
+ SetSocket(job->PassSocket());
+
+ // Make sure all connections are cancelled even if this object fails to be
+ // deleted.
+ ipv4_job_.reset();
+ ipv6_job_.reset();
+ } else {
+ switch (job->type()) {
+ case SUB_JOB_IPV4:
+ ipv4_job_.reset();
+ break;
+
+ case SUB_JOB_IPV6:
+ ipv6_job_.reset();
+ if (ipv4_job_ && !ipv4_job_->started()) {
+ fallback_timer_.Stop();
+ result = ipv4_job_->Start();
+ if (result != ERR_IO_PENDING) {
+ OnSubJobComplete(result, ipv4_job_.get());
+ return;
+ }
+ }
+ break;
+ }
+ if (ipv4_job_ || ipv6_job_)
+ return;
+ }
+ helper_.OnIOComplete(this, result);
+}
+
+void WebSocketTransportConnectJob::StartIPv4JobAsync() {
+ DCHECK(ipv4_job_);
+ int result = ipv4_job_->Start();
+ if (result != ERR_IO_PENDING)
+ OnSubJobComplete(result, ipv4_job_.get());
+}
+
+int WebSocketTransportConnectJob::ConnectInternal() {
+ return helper_.DoConnectInternal(this);
+}
+
+WebSocketTransportClientSocketPool::WebSocketTransportClientSocketPool(
+ int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ HostResolver* host_resolver,
+ ClientSocketFactory* client_socket_factory,
+ NetLog* net_log)
+ : TransportClientSocketPool(max_sockets,
+ max_sockets_per_group,
+ histograms,
+ host_resolver,
+ client_socket_factory,
+ net_log),
+ connect_job_delegate_(this),
+ histograms_(histograms),
+ pool_net_log_(net_log),
+ client_socket_factory_(client_socket_factory),
+ host_resolver_(host_resolver),
+ max_sockets_(max_sockets),
+ handed_out_socket_count_(0),
+ flushing_(false),
+ weak_factory_(this) {}
+
+WebSocketTransportClientSocketPool::~WebSocketTransportClientSocketPool() {
+ // Clean up any pending connect jobs.
+ FlushWithError(ERR_ABORTED);
+ DCHECK(pending_connects_.empty());
+ DCHECK_EQ(0, handed_out_socket_count_);
+ DCHECK(stalled_request_queue_.empty());
+ DCHECK(stalled_request_map_.empty());
+}
+
+// static
+void WebSocketTransportClientSocketPool::UnlockEndpoint(
+ ClientSocketHandle* handle) {
+ DCHECK(handle->is_initialized());
+ DCHECK(handle->socket());
+ IPEndPoint address;
+ if (handle->socket()->GetPeerAddress(&address) == OK)
+ WebSocketEndpointLockManager::GetInstance()->UnlockEndpoint(address);
+}
+
+int WebSocketTransportClientSocketPool::RequestSocket(
+ const std::string& group_name,
+ const void* params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& request_net_log) {
+ DCHECK(params);
+ const scoped_refptr<TransportSocketParams>& casted_params =
+ *static_cast<const scoped_refptr<TransportSocketParams>*>(params);
+
+ NetLogTcpClientSocketPoolRequestedSocket(request_net_log, &casted_params);
+
+ CHECK(!callback.is_null());
+ CHECK(handle);
+
+ request_net_log.BeginEvent(NetLog::TYPE_SOCKET_POOL);
+
+ if (ReachedMaxSocketsLimit() && !casted_params->ignore_limits()) {
+ request_net_log.AddEvent(NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS);
+ // TODO(ricea): Use emplace_back when C++11 becomes allowed.
+ StalledRequest request(
+ casted_params, priority, handle, callback, request_net_log);
+ stalled_request_queue_.push_back(request);
+ StalledRequestQueue::iterator iterator = stalled_request_queue_.end();
+ --iterator;
+ DCHECK_EQ(handle, iterator->handle);
+ // Because StalledRequestQueue is a std::list, its iterators are guaranteed
+ // to remain valid as long as the elements are not removed. As long as
+ // stalled_request_queue_ and stalled_request_map_ are updated in sync, it
+ // is safe to dereference an iterator in stalled_request_map_ to find the
+ // corresponding list element.
+ stalled_request_map_.insert(
+ StalledRequestMap::value_type(handle, iterator));
+ return ERR_IO_PENDING;
+ }
+
+ scoped_ptr<WebSocketTransportConnectJob> connect_job(
+ new WebSocketTransportConnectJob(group_name,
+ priority,
+ casted_params,
+ ConnectionTimeout(),
+ callback,
+ client_socket_factory_,
+ host_resolver_,
+ handle,
+ &connect_job_delegate_,
+ pool_net_log_,
+ request_net_log));
+
+ int rv = connect_job->Connect();
+ // Regardless of the outcome of |connect_job|, it will always be bound to
+ // |handle|, since this pool uses early-binding. So the binding is logged
+ // here, without waiting for the result.
+ request_net_log.AddEvent(
+ NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB,
+ connect_job->net_log().source().ToEventParametersCallback());
+ if (rv == OK) {
+ HandOutSocket(connect_job->PassSocket(),
+ connect_job->connect_timing(),
+ handle,
+ request_net_log);
+ request_net_log.EndEvent(NetLog::TYPE_SOCKET_POOL);
+ } else if (rv == ERR_IO_PENDING) {
+ // TODO(ricea): Implement backup job timer?
+ AddJob(handle, connect_job.Pass());
+ } else {
+ scoped_ptr<StreamSocket> error_socket;
+ connect_job->GetAdditionalErrorState(handle);
+ error_socket = connect_job->PassSocket();
+ if (error_socket) {
+ HandOutSocket(error_socket.Pass(),
+ connect_job->connect_timing(),
+ handle,
+ request_net_log);
+ }
+ }
+
+ if (rv != ERR_IO_PENDING) {
+ request_net_log.EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, rv);
+ }
+
+ return rv;
+}
+
+void WebSocketTransportClientSocketPool::RequestSockets(
+ const std::string& group_name,
+ const void* params,
+ int num_sockets,
+ const BoundNetLog& net_log) {
+ NOTIMPLEMENTED();
+}
+
+void WebSocketTransportClientSocketPool::CancelRequest(
+ const std::string& group_name,
+ ClientSocketHandle* handle) {
+ DCHECK(!handle->is_initialized());
+ if (DeleteStalledRequest(handle))
+ return;
+ scoped_ptr<StreamSocket> socket = handle->PassSocket();
+ if (socket)
+ ReleaseSocket(handle->group_name(), socket.Pass(), handle->id());
+ if (!DeleteJob(handle))
+ pending_callbacks_.erase(handle);
+ if (!ReachedMaxSocketsLimit() && !stalled_request_queue_.empty())
+ ActivateStalledRequest();
+}
+
+void WebSocketTransportClientSocketPool::ReleaseSocket(
+ const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) {
+ WebSocketEndpointLockManager::GetInstance()->UnlockSocket(socket.get());
+ CHECK_GT(handed_out_socket_count_, 0);
+ --handed_out_socket_count_;
+ if (!ReachedMaxSocketsLimit() && !stalled_request_queue_.empty())
+ ActivateStalledRequest();
+}
+
+void WebSocketTransportClientSocketPool::FlushWithError(int error) {
+ // Sockets which are in LOAD_STATE_CONNECTING are in danger of unlocking
+ // sockets waiting for the endpoint lock. If they connected synchronously,
+ // then OnConnectJobComplete(). The |flushing_| flag tells this object to
+ // ignore spurious calls to OnConnectJobComplete(). It is safe to ignore those
+ // calls because this method will delete the jobs and call their callbacks
+ // anyway.
+ flushing_ = true;
+ for (PendingConnectsMap::iterator it = pending_connects_.begin();
+ it != pending_connects_.end();
+ ++it) {
+ InvokeUserCallbackLater(
+ it->second->handle(), it->second->callback(), error);
+ delete it->second, it->second = NULL;
+ }
+ pending_connects_.clear();
+ for (StalledRequestQueue::iterator it = stalled_request_queue_.begin();
+ it != stalled_request_queue_.end();
+ ++it) {
+ InvokeUserCallbackLater(it->handle, it->callback, error);
+ }
+ stalled_request_map_.clear();
+ stalled_request_queue_.clear();
+ flushing_ = false;
+}
+
+void WebSocketTransportClientSocketPool::CloseIdleSockets() {
+ // We have no idle sockets.
+}
+
+int WebSocketTransportClientSocketPool::IdleSocketCount() const {
+ return 0;
+}
+
+int WebSocketTransportClientSocketPool::IdleSocketCountInGroup(
+ const std::string& group_name) const {
+ return 0;
+}
+
+LoadState WebSocketTransportClientSocketPool::GetLoadState(
+ const std::string& group_name,
+ const ClientSocketHandle* handle) const {
+ if (stalled_request_map_.find(handle) != stalled_request_map_.end())
+ return LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET;
+ if (pending_callbacks_.count(handle))
+ return LOAD_STATE_CONNECTING;
+ return LookupConnectJob(handle)->GetLoadState();
+}
+
+base::DictionaryValue* WebSocketTransportClientSocketPool::GetInfoAsValue(
+ const std::string& name,
+ const std::string& type,
+ bool include_nested_pools) const {
+ base::DictionaryValue* dict = new base::DictionaryValue();
+ dict->SetString("name", name);
+ dict->SetString("type", type);
+ dict->SetInteger("handed_out_socket_count", handed_out_socket_count_);
+ dict->SetInteger("connecting_socket_count", pending_connects_.size());
+ dict->SetInteger("idle_socket_count", 0);
+ dict->SetInteger("max_socket_count", max_sockets_);
+ dict->SetInteger("max_sockets_per_group", max_sockets_);
+ dict->SetInteger("pool_generation_number", 0);
+ return dict;
+}
+
+TimeDelta WebSocketTransportClientSocketPool::ConnectionTimeout() const {
+ return TimeDelta::FromSeconds(kTransportConnectJobTimeoutInSeconds);
+}
+
+ClientSocketPoolHistograms* WebSocketTransportClientSocketPool::histograms()
+ const {
+ return histograms_;
+}
+
+bool WebSocketTransportClientSocketPool::IsStalled() const {
+ return !stalled_request_queue_.empty();
+}
+
+void WebSocketTransportClientSocketPool::OnConnectJobComplete(
+ int result,
+ WebSocketTransportConnectJob* job) {
+ DCHECK_NE(ERR_IO_PENDING, result);
+
+ scoped_ptr<StreamSocket> socket = job->PassSocket();
+
+ // See comment in FlushWithError.
+ if (flushing_) {
+ WebSocketEndpointLockManager::GetInstance()->UnlockSocket(socket.get());
+ return;
+ }
+
+ BoundNetLog request_net_log = job->request_net_log();
+ CompletionCallback callback = job->callback();
+ LoadTimingInfo::ConnectTiming connect_timing = job->connect_timing();
+
+ ClientSocketHandle* const handle = job->handle();
+ bool handed_out_socket = false;
+
+ if (result == OK) {
+ DCHECK(socket.get());
+ handed_out_socket = true;
+ HandOutSocket(socket.Pass(), connect_timing, handle, request_net_log);
+ request_net_log.EndEvent(NetLog::TYPE_SOCKET_POOL);
+ } else {
+ // If we got a socket, it must contain error information so pass that
+ // up so that the caller can retrieve it.
+ job->GetAdditionalErrorState(handle);
+ if (socket.get()) {
+ handed_out_socket = true;
+ HandOutSocket(socket.Pass(), connect_timing, handle, request_net_log);
+ }
+ request_net_log.EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, result);
+ }
+ bool delete_succeeded = DeleteJob(handle);
+ DCHECK(delete_succeeded);
+ if (!handed_out_socket && !stalled_request_queue_.empty() &&
+ !ReachedMaxSocketsLimit())
+ ActivateStalledRequest();
+ InvokeUserCallbackLater(handle, callback, result);
+}
+
+void WebSocketTransportClientSocketPool::InvokeUserCallbackLater(
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ int rv) {
+ DCHECK(!pending_callbacks_.count(handle));
+ pending_callbacks_.insert(handle);
+ base::MessageLoop::current()->PostTask(
+ FROM_HERE,
+ base::Bind(&WebSocketTransportClientSocketPool::InvokeUserCallback,
+ weak_factory_.GetWeakPtr(),
+ handle,
+ callback,
+ rv));
+}
+
+void WebSocketTransportClientSocketPool::InvokeUserCallback(
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ int rv) {
+ if (pending_callbacks_.erase(handle))
+ callback.Run(rv);
+}
+
+bool WebSocketTransportClientSocketPool::ReachedMaxSocketsLimit() const {
+ return handed_out_socket_count_ >= max_sockets_ ||
+ base::checked_cast<int>(pending_connects_.size()) >=
+ max_sockets_ - handed_out_socket_count_;
+}
+
+void WebSocketTransportClientSocketPool::HandOutSocket(
+ scoped_ptr<StreamSocket> socket,
+ const LoadTimingInfo::ConnectTiming& connect_timing,
+ ClientSocketHandle* handle,
+ const BoundNetLog& net_log) {
+ DCHECK(socket);
+ handle->SetSocket(socket.Pass());
+ DCHECK_EQ(ClientSocketHandle::UNUSED, handle->reuse_type());
+ DCHECK_EQ(0, handle->idle_time().InMicroseconds());
+ handle->set_pool_id(0);
+ handle->set_connect_timing(connect_timing);
+
+ net_log.AddEvent(
+ NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET,
+ handle->socket()->NetLog().source().ToEventParametersCallback());
+
+ ++handed_out_socket_count_;
+}
+
+void WebSocketTransportClientSocketPool::AddJob(
+ ClientSocketHandle* handle,
+ scoped_ptr<WebSocketTransportConnectJob> connect_job) {
+ bool inserted =
+ pending_connects_.insert(PendingConnectsMap::value_type(
+ handle, connect_job.release())).second;
+ DCHECK(inserted);
+}
+
+bool WebSocketTransportClientSocketPool::DeleteJob(ClientSocketHandle* handle) {
+ PendingConnectsMap::iterator it = pending_connects_.find(handle);
+ if (it == pending_connects_.end())
+ return false;
+ // Deleting a ConnectJob which holds an endpoint lock can lead to a different
+ // ConnectJob proceeding to connect. If the connect proceeds synchronously
+ // (usually because of a failure) then it can trigger that job to be
+ // deleted. |it| remains valid because std::map guarantees that erase() does
+ // not invalid iterators to other entries.
+ delete it->second, it->second = NULL;
+ DCHECK(pending_connects_.find(handle) == it);
+ pending_connects_.erase(it);
+ return true;
+}
+
+const WebSocketTransportConnectJob*
+WebSocketTransportClientSocketPool::LookupConnectJob(
+ const ClientSocketHandle* handle) const {
+ PendingConnectsMap::const_iterator it = pending_connects_.find(handle);
+ CHECK(it != pending_connects_.end());
+ return it->second;
+}
+
+void WebSocketTransportClientSocketPool::ActivateStalledRequest() {
+ DCHECK(!stalled_request_queue_.empty());
+ DCHECK(!ReachedMaxSocketsLimit());
+ // Usually we will only be able to activate one stalled request at a time,
+ // however if all the connects fail synchronously for some reason, we may be
+ // able to clear the whole queue at once.
+ while (!stalled_request_queue_.empty() && !ReachedMaxSocketsLimit()) {
+ StalledRequest request(stalled_request_queue_.front());
+ stalled_request_queue_.pop_front();
+ stalled_request_map_.erase(request.handle);
+ int rv = RequestSocket("ignored",
+ &request.params,
+ request.priority,
+ request.handle,
+ request.callback,
+ request.net_log);
+ // ActivateStalledRequest() never returns synchronously, so it is never
+ // called re-entrantly.
+ if (rv != ERR_IO_PENDING)
+ InvokeUserCallbackLater(request.handle, request.callback, rv);
+ }
+}
+
+bool WebSocketTransportClientSocketPool::DeleteStalledRequest(
+ ClientSocketHandle* handle) {
+ StalledRequestMap::iterator it = stalled_request_map_.find(handle);
+ if (it == stalled_request_map_.end())
+ return false;
+ stalled_request_queue_.erase(it->second);
+ stalled_request_map_.erase(it);
+ return true;
+}
+
+WebSocketTransportClientSocketPool::ConnectJobDelegate::ConnectJobDelegate(
+ WebSocketTransportClientSocketPool* owner)
+ : owner_(owner) {}
+
+WebSocketTransportClientSocketPool::ConnectJobDelegate::~ConnectJobDelegate() {}
+
+void
+WebSocketTransportClientSocketPool::ConnectJobDelegate::OnConnectJobComplete(
+ int result,
+ ConnectJob* job) {
+ owner_->OnConnectJobComplete(result,
+ static_cast<WebSocketTransportConnectJob*>(job));
+}
+
+WebSocketTransportClientSocketPool::StalledRequest::StalledRequest(
+ const scoped_refptr<TransportSocketParams>& params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log)
+ : params(params),
+ priority(priority),
+ handle(handle),
+ callback(callback),
+ net_log(net_log) {}
+
+WebSocketTransportClientSocketPool::StalledRequest::~StalledRequest() {}
+
+} // namespace net
diff --git a/chromium/net/socket/websocket_transport_client_socket_pool.h b/chromium/net/socket/websocket_transport_client_socket_pool.h
new file mode 100644
index 00000000000..f0a94be417f
--- /dev/null
+++ b/chromium/net/socket/websocket_transport_client_socket_pool.h
@@ -0,0 +1,246 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
+#define NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
+
+#include <list>
+#include <map>
+#include <set>
+#include <string>
+
+#include "base/basictypes.h"
+#include "base/memory/ref_counted.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/memory/weak_ptr.h"
+#include "base/time/time.h"
+#include "base/timer/timer.h"
+#include "net/base/net_export.h"
+#include "net/base/net_log.h"
+#include "net/socket/client_socket_pool.h"
+#include "net/socket/client_socket_pool_base.h"
+#include "net/socket/transport_client_socket_pool.h"
+
+namespace net {
+
+class ClientSocketFactory;
+class ClientSocketPoolHistograms;
+class HostResolver;
+class NetLog;
+class WebSocketEndpointLockManager;
+class WebSocketTransportConnectSubJob;
+
+// WebSocketTransportConnectJob handles the host resolution necessary for socket
+// creation and the TCP connect. WebSocketTransportConnectJob also has fallback
+// logic for IPv6 connect() timeouts (which may happen due to networks / routers
+// with broken IPv6 support). Those timeouts take 20s, so rather than make the
+// user wait 20s for the timeout to fire, we use a fallback timer
+// (kIPv6FallbackTimerInMs) and start a connect() to an IPv4 address if the
+// timer fires. Then we race the IPv4 connect(s) against the IPv6 connect(s) and
+// use the socket that completes successfully first or fails last.
+class NET_EXPORT_PRIVATE WebSocketTransportConnectJob : public ConnectJob {
+ public:
+ WebSocketTransportConnectJob(
+ const std::string& group_name,
+ RequestPriority priority,
+ const scoped_refptr<TransportSocketParams>& params,
+ base::TimeDelta timeout_duration,
+ const CompletionCallback& callback,
+ ClientSocketFactory* client_socket_factory,
+ HostResolver* host_resolver,
+ ClientSocketHandle* handle,
+ Delegate* delegate,
+ NetLog* pool_net_log,
+ const BoundNetLog& request_net_log);
+ ~WebSocketTransportConnectJob() override;
+
+ // Unlike normal socket pools, the WebSocketTransportClientPool uses
+ // early-binding of sockets.
+ ClientSocketHandle* handle() const { return handle_; }
+
+ // Stash the callback from RequestSocket() here for convenience.
+ const CompletionCallback& callback() const { return callback_; }
+
+ const BoundNetLog& request_net_log() const { return request_net_log_; }
+
+ // ConnectJob methods.
+ LoadState GetLoadState() const override;
+
+ private:
+ friend class WebSocketTransportConnectSubJob;
+ friend class TransportConnectJobHelper;
+ friend class WebSocketEndpointLockManager;
+
+ // Although it is not strictly necessary, it makes the code simpler if each
+ // subjob knows what type it is.
+ enum SubJobType { SUB_JOB_IPV4, SUB_JOB_IPV6 };
+
+ int DoResolveHost();
+ int DoResolveHostComplete(int result);
+ int DoTransportConnect();
+ int DoTransportConnectComplete(int result);
+
+ // Called back from a SubJob when it completes.
+ void OnSubJobComplete(int result, WebSocketTransportConnectSubJob* job);
+
+ // Called from |fallback_timer_|.
+ void StartIPv4JobAsync();
+
+ // Begins the host resolution and the TCP connect. Returns OK on success
+ // and ERR_IO_PENDING if it cannot immediately service the request.
+ // Otherwise, it returns a net error code.
+ int ConnectInternal() override;
+
+ TransportConnectJobHelper helper_;
+
+ // The addresses are divided into IPv4 and IPv6, which are performed partially
+ // in parallel. If the list of IPv6 addresses is non-empty, then the IPv6 jobs
+ // go first, followed after |kIPv6FallbackTimerInMs| by the IPv4
+ // addresses. First sub-job to establish a connection wins.
+ scoped_ptr<WebSocketTransportConnectSubJob> ipv4_job_;
+ scoped_ptr<WebSocketTransportConnectSubJob> ipv6_job_;
+
+ base::OneShotTimer<WebSocketTransportConnectJob> fallback_timer_;
+ TransportConnectJobHelper::ConnectionLatencyHistogram race_result_;
+ ClientSocketHandle* const handle_;
+ CompletionCallback callback_;
+ BoundNetLog request_net_log_;
+
+ bool had_ipv4_;
+ bool had_ipv6_;
+
+ DISALLOW_COPY_AND_ASSIGN(WebSocketTransportConnectJob);
+};
+
+class NET_EXPORT_PRIVATE WebSocketTransportClientSocketPool
+ : public TransportClientSocketPool {
+ public:
+ WebSocketTransportClientSocketPool(int max_sockets,
+ int max_sockets_per_group,
+ ClientSocketPoolHistograms* histograms,
+ HostResolver* host_resolver,
+ ClientSocketFactory* client_socket_factory,
+ NetLog* net_log);
+
+ ~WebSocketTransportClientSocketPool() override;
+
+ // Allow another connection to be started to the IPEndPoint that this |handle|
+ // is connected to. Used when the WebSocket handshake completes successfully.
+ // This only works if the socket is connected, however the caller does not
+ // need to explicitly check for this. Instead, ensure that dead sockets are
+ // returned to ReleaseSocket() in a timely fashion.
+ static void UnlockEndpoint(ClientSocketHandle* handle);
+
+ // ClientSocketPool implementation.
+ int RequestSocket(const std::string& group_name,
+ const void* resolve_info,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log) override;
+ void RequestSockets(const std::string& group_name,
+ const void* params,
+ int num_sockets,
+ const BoundNetLog& net_log) override;
+ void CancelRequest(const std::string& group_name,
+ ClientSocketHandle* handle) override;
+ void ReleaseSocket(const std::string& group_name,
+ scoped_ptr<StreamSocket> socket,
+ int id) override;
+ void FlushWithError(int error) override;
+ void CloseIdleSockets() override;
+ int IdleSocketCount() const override;
+ int IdleSocketCountInGroup(const std::string& group_name) const override;
+ LoadState GetLoadState(const std::string& group_name,
+ const ClientSocketHandle* handle) const override;
+ base::DictionaryValue* GetInfoAsValue(
+ const std::string& name,
+ const std::string& type,
+ bool include_nested_pools) const override;
+ base::TimeDelta ConnectionTimeout() const override;
+ ClientSocketPoolHistograms* histograms() const override;
+
+ // HigherLayeredPool implementation.
+ bool IsStalled() const override;
+
+ private:
+ class ConnectJobDelegate : public ConnectJob::Delegate {
+ public:
+ explicit ConnectJobDelegate(WebSocketTransportClientSocketPool* owner);
+ ~ConnectJobDelegate() override;
+
+ void OnConnectJobComplete(int result, ConnectJob* job) override;
+
+ private:
+ WebSocketTransportClientSocketPool* owner_;
+
+ DISALLOW_COPY_AND_ASSIGN(ConnectJobDelegate);
+ };
+
+ // Store the arguments from a call to RequestSocket() that has stalled so we
+ // can replay it when there are available socket slots.
+ struct StalledRequest {
+ StalledRequest(const scoped_refptr<TransportSocketParams>& params,
+ RequestPriority priority,
+ ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ const BoundNetLog& net_log);
+ ~StalledRequest();
+ const scoped_refptr<TransportSocketParams> params;
+ const RequestPriority priority;
+ ClientSocketHandle* const handle;
+ const CompletionCallback callback;
+ const BoundNetLog net_log;
+ };
+ friend class ConnectJobDelegate;
+ typedef std::map<const ClientSocketHandle*, WebSocketTransportConnectJob*>
+ PendingConnectsMap;
+ // This is a list so that we can remove requests from the middle, and also
+ // so that iterators are not invalidated unless the corresponding request is
+ // removed.
+ typedef std::list<StalledRequest> StalledRequestQueue;
+ typedef std::map<const ClientSocketHandle*, StalledRequestQueue::iterator>
+ StalledRequestMap;
+
+ void OnConnectJobComplete(int result, WebSocketTransportConnectJob* job);
+ void InvokeUserCallbackLater(ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ int rv);
+ void InvokeUserCallback(ClientSocketHandle* handle,
+ const CompletionCallback& callback,
+ int rv);
+ bool ReachedMaxSocketsLimit() const;
+ void HandOutSocket(scoped_ptr<StreamSocket> socket,
+ const LoadTimingInfo::ConnectTiming& connect_timing,
+ ClientSocketHandle* handle,
+ const BoundNetLog& net_log);
+ void AddJob(ClientSocketHandle* handle,
+ scoped_ptr<WebSocketTransportConnectJob> connect_job);
+ bool DeleteJob(ClientSocketHandle* handle);
+ const WebSocketTransportConnectJob* LookupConnectJob(
+ const ClientSocketHandle* handle) const;
+ void ActivateStalledRequest();
+ bool DeleteStalledRequest(ClientSocketHandle* handle);
+
+ ConnectJobDelegate connect_job_delegate_;
+ std::set<const ClientSocketHandle*> pending_callbacks_;
+ PendingConnectsMap pending_connects_;
+ StalledRequestQueue stalled_request_queue_;
+ StalledRequestMap stalled_request_map_;
+ ClientSocketPoolHistograms* const histograms_;
+ NetLog* const pool_net_log_;
+ ClientSocketFactory* const client_socket_factory_;
+ HostResolver* const host_resolver_;
+ const int max_sockets_;
+ int handed_out_socket_count_;
+ bool flushing_;
+
+ base::WeakPtrFactory<WebSocketTransportClientSocketPool> weak_factory_;
+
+ DISALLOW_COPY_AND_ASSIGN(WebSocketTransportClientSocketPool);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
diff --git a/chromium/net/socket/websocket_transport_client_socket_pool_unittest.cc b/chromium/net/socket/websocket_transport_client_socket_pool_unittest.cc
new file mode 100644
index 00000000000..2189181b9fc
--- /dev/null
+++ b/chromium/net/socket/websocket_transport_client_socket_pool_unittest.cc
@@ -0,0 +1,1143 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/websocket_transport_client_socket_pool.h"
+
+#include <queue>
+#include <vector>
+
+#include "base/bind.h"
+#include "base/bind_helpers.h"
+#include "base/callback.h"
+#include "base/macros.h"
+#include "base/message_loop/message_loop.h"
+#include "base/run_loop.h"
+#include "base/strings/stringprintf.h"
+#include "base/time/time.h"
+#include "net/base/capturing_net_log.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/load_timing_info.h"
+#include "net/base/load_timing_info_test_util.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_util.h"
+#include "net/base/test_completion_callback.h"
+#include "net/dns/mock_host_resolver.h"
+#include "net/socket/client_socket_handle.h"
+#include "net/socket/client_socket_pool_histograms.h"
+#include "net/socket/socket_test_util.h"
+#include "net/socket/stream_socket.h"
+#include "net/socket/transport_client_socket_pool_test_util.h"
+#include "net/socket/websocket_endpoint_lock_manager.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace net {
+
+namespace {
+
+const int kMaxSockets = 32;
+const int kMaxSocketsPerGroup = 6;
+const RequestPriority kDefaultPriority = LOW;
+
+// RunLoop doesn't support this natively but it is easy to emulate.
+void RunLoopForTimePeriod(base::TimeDelta period) {
+ base::RunLoop run_loop;
+ base::Closure quit_closure(run_loop.QuitClosure());
+ base::MessageLoop::current()->PostDelayedTask(
+ FROM_HERE, quit_closure, period);
+ run_loop.Run();
+}
+
+class WebSocketTransportClientSocketPoolTest : public testing::Test {
+ protected:
+ WebSocketTransportClientSocketPoolTest()
+ : params_(new TransportSocketParams(
+ HostPortPair("www.google.com", 80),
+ false,
+ false,
+ OnHostResolutionCallback(),
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)),
+ histograms_(new ClientSocketPoolHistograms("TCPUnitTest")),
+ host_resolver_(new MockHostResolver),
+ client_socket_factory_(&net_log_),
+ pool_(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL) {}
+
+ ~WebSocketTransportClientSocketPoolTest() override {
+ ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE);
+ EXPECT_TRUE(WebSocketEndpointLockManager::GetInstance()->IsEmpty());
+ }
+
+ int StartRequest(const std::string& group_name, RequestPriority priority) {
+ scoped_refptr<TransportSocketParams> params(
+ new TransportSocketParams(
+ HostPortPair("www.google.com", 80),
+ false,
+ false,
+ OnHostResolutionCallback(),
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT));
+ return test_base_.StartRequestUsingPool(
+ &pool_, group_name, priority, params);
+ }
+
+ int GetOrderOfRequest(size_t index) {
+ return test_base_.GetOrderOfRequest(index);
+ }
+
+ bool ReleaseOneConnection(ClientSocketPoolTest::KeepAlive keep_alive) {
+ return test_base_.ReleaseOneConnection(keep_alive);
+ }
+
+ void ReleaseAllConnections(ClientSocketPoolTest::KeepAlive keep_alive) {
+ test_base_.ReleaseAllConnections(keep_alive);
+ }
+
+ TestSocketRequest* request(int i) { return test_base_.request(i); }
+
+ ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); }
+ size_t completion_count() const { return test_base_.completion_count(); }
+
+ CapturingNetLog net_log_;
+ scoped_refptr<TransportSocketParams> params_;
+ scoped_ptr<ClientSocketPoolHistograms> histograms_;
+ scoped_ptr<MockHostResolver> host_resolver_;
+ MockTransportClientSocketFactory client_socket_factory_;
+ WebSocketTransportClientSocketPool pool_;
+ ClientSocketPoolTest test_base_;
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(WebSocketTransportClientSocketPoolTest);
+};
+
+TEST_F(WebSocketTransportClientSocketPoolTest, Basic) {
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv = handle.Init(
+ "a", params_, LOW, callback.callback(), &pool_, BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ TestLoadTimingInfoConnectedNotReused(handle);
+}
+
+// Make sure that WebSocketTransportConnectJob passes on its priority to its
+// HostResolver request on Init.
+TEST_F(WebSocketTransportClientSocketPoolTest, SetResolvePriorityOnInit) {
+ for (int i = MINIMUM_PRIORITY; i <= MAXIMUM_PRIORITY; ++i) {
+ RequestPriority priority = static_cast<RequestPriority>(i);
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a",
+ params_,
+ priority,
+ callback.callback(),
+ &pool_,
+ BoundNetLog()));
+ EXPECT_EQ(priority, host_resolver_->last_request_priority());
+ }
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest, InitHostResolutionFailure) {
+ host_resolver_->rules()->AddSimulatedFailure("unresolvable.host.name");
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ HostPortPair host_port_pair("unresolvable.host.name", 80);
+ scoped_refptr<TransportSocketParams> dest(new TransportSocketParams(
+ host_port_pair, false, false, OnHostResolutionCallback(),
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT));
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a",
+ dest,
+ kDefaultPriority,
+ callback.callback(),
+ &pool_,
+ BoundNetLog()));
+ EXPECT_EQ(ERR_NAME_NOT_RESOLVED, callback.WaitForResult());
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest, InitConnectionFailure) {
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET);
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ &pool_,
+ BoundNetLog()));
+ EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult());
+
+ // Make the host resolutions complete synchronously this time.
+ host_resolver_->set_synchronous_mode(true);
+ EXPECT_EQ(ERR_CONNECTION_FAILED,
+ handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ &pool_,
+ BoundNetLog()));
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest, PendingRequestsFinishFifo) {
+ // First request finishes asynchronously.
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, request(0)->WaitForResult());
+
+ // Make all subsequent host resolutions complete synchronously.
+ host_resolver_->set_synchronous_mode(true);
+
+ // Rest of them wait for the first socket to be released.
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+
+ ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE);
+
+ EXPECT_EQ(6, client_socket_factory_.allocation_count());
+
+ // One initial asynchronous request and then 5 pending requests.
+ EXPECT_EQ(6U, completion_count());
+
+ // The requests finish in FIFO order.
+ EXPECT_EQ(1, GetOrderOfRequest(1));
+ EXPECT_EQ(2, GetOrderOfRequest(2));
+ EXPECT_EQ(3, GetOrderOfRequest(3));
+ EXPECT_EQ(4, GetOrderOfRequest(4));
+ EXPECT_EQ(5, GetOrderOfRequest(5));
+ EXPECT_EQ(6, GetOrderOfRequest(6));
+
+ // Make sure we test order of all requests made.
+ EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(7));
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest, PendingRequests_NoKeepAlive) {
+ // First request finishes asynchronously.
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, request(0)->WaitForResult());
+
+ // Make all subsequent host resolutions complete synchronously.
+ host_resolver_->set_synchronous_mode(true);
+
+ // Rest of them wait for the first socket to be released.
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+
+ ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE);
+
+ // The pending requests should finish successfully.
+ EXPECT_EQ(OK, request(1)->WaitForResult());
+ EXPECT_EQ(OK, request(2)->WaitForResult());
+ EXPECT_EQ(OK, request(3)->WaitForResult());
+ EXPECT_EQ(OK, request(4)->WaitForResult());
+ EXPECT_EQ(OK, request(5)->WaitForResult());
+
+ EXPECT_EQ(static_cast<int>(requests()->size()),
+ client_socket_factory_.allocation_count());
+
+ // First asynchronous request, and then last 5 pending requests.
+ EXPECT_EQ(6U, completion_count());
+}
+
+// This test will start up a RequestSocket() and then immediately Cancel() it.
+// The pending host resolution will eventually complete, and destroy the
+// ClientSocketPool which will crash if the group was not cleared properly.
+TEST_F(WebSocketTransportClientSocketPoolTest, CancelRequestClearGroup) {
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ &pool_,
+ BoundNetLog()));
+ handle.Reset();
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest, TwoRequestsCancelOne) {
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ ClientSocketHandle handle2;
+ TestCompletionCallback callback2;
+
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ &pool_,
+ BoundNetLog()));
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle2.Init("a",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ &pool_,
+ BoundNetLog()));
+
+ handle.Reset();
+
+ EXPECT_EQ(OK, callback2.WaitForResult());
+ handle2.Reset();
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest, ConnectCancelConnect) {
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET);
+ ClientSocketHandle handle;
+ TestCompletionCallback callback;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback.callback(),
+ &pool_,
+ BoundNetLog()));
+
+ handle.Reset();
+
+ TestCompletionCallback callback2;
+ EXPECT_EQ(ERR_IO_PENDING,
+ handle.Init("a",
+ params_,
+ kDefaultPriority,
+ callback2.callback(),
+ &pool_,
+ BoundNetLog()));
+
+ host_resolver_->set_synchronous_mode(true);
+ // At this point, handle has two ConnectingSockets out for it. Due to the
+ // setting the mock resolver into synchronous mode, the host resolution for
+ // both will return in the same loop of the MessageLoop. The client socket
+ // is a pending socket, so the Connect() will asynchronously complete on the
+ // next loop of the MessageLoop. That means that the first
+ // ConnectingSocket will enter OnIOComplete, and then the second one will.
+ // If the first one is not cancelled, it will advance the load state, and
+ // then the second one will crash.
+
+ EXPECT_EQ(OK, callback2.WaitForResult());
+ EXPECT_FALSE(callback.have_result());
+
+ handle.Reset();
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest, CancelRequest) {
+ // First request finishes asynchronously.
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, request(0)->WaitForResult());
+
+ // Make all subsequent host resolutions complete synchronously.
+ host_resolver_->set_synchronous_mode(true);
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+
+ // Cancel a request.
+ const size_t index_to_cancel = 2;
+ EXPECT_FALSE(request(index_to_cancel)->handle()->is_initialized());
+ request(index_to_cancel)->handle()->Reset();
+
+ ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE);
+
+ EXPECT_EQ(5, client_socket_factory_.allocation_count());
+
+ EXPECT_EQ(1, GetOrderOfRequest(1));
+ EXPECT_EQ(2, GetOrderOfRequest(2));
+ EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound,
+ GetOrderOfRequest(3)); // Canceled request.
+ EXPECT_EQ(3, GetOrderOfRequest(4));
+ EXPECT_EQ(4, GetOrderOfRequest(5));
+ EXPECT_EQ(5, GetOrderOfRequest(6));
+
+ // Make sure we test order of all requests made.
+ EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(7));
+}
+
+class RequestSocketCallback : public TestCompletionCallbackBase {
+ public:
+ RequestSocketCallback(ClientSocketHandle* handle,
+ WebSocketTransportClientSocketPool* pool)
+ : handle_(handle),
+ pool_(pool),
+ within_callback_(false),
+ callback_(base::Bind(&RequestSocketCallback::OnComplete,
+ base::Unretained(this))) {}
+
+ ~RequestSocketCallback() override {}
+
+ const CompletionCallback& callback() const { return callback_; }
+
+ private:
+ void OnComplete(int result) {
+ SetResult(result);
+ ASSERT_EQ(OK, result);
+
+ if (!within_callback_) {
+ // Don't allow reuse of the socket. Disconnect it and then release it and
+ // run through the MessageLoop once to get it completely released.
+ handle_->socket()->Disconnect();
+ handle_->Reset();
+ {
+ base::MessageLoop::ScopedNestableTaskAllower allow(
+ base::MessageLoop::current());
+ base::MessageLoop::current()->RunUntilIdle();
+ }
+ within_callback_ = true;
+ scoped_refptr<TransportSocketParams> dest(
+ new TransportSocketParams(
+ HostPortPair("www.google.com", 80),
+ false,
+ false,
+ OnHostResolutionCallback(),
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT));
+ int rv =
+ handle_->Init("a", dest, LOWEST, callback(), pool_, BoundNetLog());
+ EXPECT_EQ(OK, rv);
+ }
+ }
+
+ ClientSocketHandle* const handle_;
+ WebSocketTransportClientSocketPool* const pool_;
+ bool within_callback_;
+ CompletionCallback callback_;
+
+ DISALLOW_COPY_AND_ASSIGN(RequestSocketCallback);
+};
+
+TEST_F(WebSocketTransportClientSocketPoolTest, RequestTwice) {
+ ClientSocketHandle handle;
+ RequestSocketCallback callback(&handle, &pool_);
+ scoped_refptr<TransportSocketParams> dest(
+ new TransportSocketParams(
+ HostPortPair("www.google.com", 80),
+ false,
+ false,
+ OnHostResolutionCallback(),
+ TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT));
+ int rv = handle.Init(
+ "a", dest, LOWEST, callback.callback(), &pool_, BoundNetLog());
+ ASSERT_EQ(ERR_IO_PENDING, rv);
+
+ // The callback is going to request "www.google.com". We want it to complete
+ // synchronously this time.
+ host_resolver_->set_synchronous_mode(true);
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+
+ handle.Reset();
+}
+
+// Make sure that pending requests get serviced after active requests get
+// cancelled.
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ CancelActiveRequestWithPendingRequests) {
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET);
+
+ // Queue up all the requests
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+
+ // Now, kMaxSocketsPerGroup requests should be active. Let's cancel them.
+ ASSERT_LE(kMaxSocketsPerGroup, static_cast<int>(requests()->size()));
+ for (int i = 0; i < kMaxSocketsPerGroup; i++)
+ request(i)->handle()->Reset();
+
+ // Let's wait for the rest to complete now.
+ for (size_t i = kMaxSocketsPerGroup; i < requests()->size(); ++i) {
+ EXPECT_EQ(OK, request(i)->WaitForResult());
+ request(i)->handle()->Reset();
+ }
+
+ EXPECT_EQ(requests()->size() - kMaxSocketsPerGroup, completion_count());
+}
+
+// Make sure that pending requests get serviced after active requests fail.
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ FailingActiveRequestWithPendingRequests) {
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET);
+
+ const int kNumRequests = 2 * kMaxSocketsPerGroup + 1;
+ ASSERT_LE(kNumRequests, kMaxSockets); // Otherwise the test will hang.
+
+ // Queue up all the requests
+ for (int i = 0; i < kNumRequests; i++)
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+
+ for (int i = 0; i < kNumRequests; i++)
+ EXPECT_EQ(ERR_CONNECTION_FAILED, request(i)->WaitForResult());
+}
+
+// The lock on the endpoint is released when a ClientSocketHandle is reset.
+TEST_F(WebSocketTransportClientSocketPoolTest, LockReleasedOnHandleReset) {
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, request(0)->WaitForResult());
+ EXPECT_FALSE(request(1)->handle()->is_initialized());
+ request(0)->handle()->Reset();
+ base::RunLoop().RunUntilIdle();
+ EXPECT_TRUE(request(1)->handle()->is_initialized());
+}
+
+// The lock on the endpoint is released when a ClientSocketHandle is deleted.
+TEST_F(WebSocketTransportClientSocketPoolTest, LockReleasedOnHandleDelete) {
+ TestCompletionCallback callback;
+ scoped_ptr<ClientSocketHandle> handle(new ClientSocketHandle);
+ int rv = handle->Init(
+ "a", params_, LOW, callback.callback(), &pool_, BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_FALSE(request(0)->handle()->is_initialized());
+ handle.reset();
+ base::RunLoop().RunUntilIdle();
+ EXPECT_TRUE(request(0)->handle()->is_initialized());
+}
+
+// A new connection is performed when the lock on the previous connection is
+// explicitly released.
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ ConnectionProceedsOnExplicitRelease) {
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(OK, request(0)->WaitForResult());
+ EXPECT_FALSE(request(1)->handle()->is_initialized());
+ WebSocketTransportClientSocketPool::UnlockEndpoint(request(0)->handle());
+ base::RunLoop().RunUntilIdle();
+ EXPECT_TRUE(request(1)->handle()->is_initialized());
+}
+
+// A connection which is cancelled before completion does not block subsequent
+// connections.
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ CancelDuringConnectionReleasesLock) {
+ MockTransportClientSocketFactory::ClientSocketType case_types[] = {
+ MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET,
+ MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET};
+
+ client_socket_factory_.set_client_socket_types(case_types,
+ arraysize(case_types));
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ base::RunLoop().RunUntilIdle();
+ pool_.CancelRequest("a", request(0)->handle());
+ EXPECT_EQ(OK, request(1)->WaitForResult());
+}
+
+// Test the case of the IPv6 address stalling, and falling back to the IPv4
+// socket which finishes first.
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ IPv6FallbackSocketIPv4FinishesFirst) {
+ WebSocketTransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+
+ MockTransportClientSocketFactory::ClientSocketType case_types[] = {
+ // This is the IPv6 socket.
+ MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET,
+ // This is the IPv4 socket.
+ MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET};
+
+ client_socket_factory_.set_client_socket_types(case_types, 2);
+
+ // Resolve an AddressList with an IPv6 address first and then an IPv4 address.
+ host_resolver_->rules()->AddIPLiteralRule(
+ "*", "2:abcd::3:4:ff,2.2.2.2", std::string());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv =
+ handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ IPEndPoint endpoint;
+ handle.socket()->GetLocalAddress(&endpoint);
+ EXPECT_EQ(kIPv4AddressSize, endpoint.address().size());
+ EXPECT_EQ(2, client_socket_factory_.allocation_count());
+}
+
+// Test the case of the IPv6 address being slow, thus falling back to trying to
+// connect to the IPv4 address, but having the connect to the IPv6 address
+// finish first.
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ IPv6FallbackSocketIPv6FinishesFirst) {
+ WebSocketTransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+
+ MockTransportClientSocketFactory::ClientSocketType case_types[] = {
+ // This is the IPv6 socket.
+ MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET,
+ // This is the IPv4 socket.
+ MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET};
+
+ client_socket_factory_.set_client_socket_types(case_types, 2);
+ client_socket_factory_.set_delay(base::TimeDelta::FromMilliseconds(
+ TransportConnectJobHelper::kIPv6FallbackTimerInMs + 50));
+
+ // Resolve an AddressList with an IPv6 address first and then an IPv4 address.
+ host_resolver_->rules()->AddIPLiteralRule(
+ "*", "2:abcd::3:4:ff,2.2.2.2", std::string());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv =
+ handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ IPEndPoint endpoint;
+ handle.socket()->GetLocalAddress(&endpoint);
+ EXPECT_EQ(kIPv6AddressSize, endpoint.address().size());
+ EXPECT_EQ(2, client_socket_factory_.allocation_count());
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ IPv6NoIPv4AddressesToFallbackTo) {
+ WebSocketTransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET);
+
+ // Resolve an AddressList with only IPv6 addresses.
+ host_resolver_->rules()->AddIPLiteralRule(
+ "*", "2:abcd::3:4:ff,3:abcd::3:4:ff", std::string());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv =
+ handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ IPEndPoint endpoint;
+ handle.socket()->GetLocalAddress(&endpoint);
+ EXPECT_EQ(kIPv6AddressSize, endpoint.address().size());
+ EXPECT_EQ(1, client_socket_factory_.allocation_count());
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest, IPv4HasNoFallback) {
+ WebSocketTransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET);
+
+ // Resolve an AddressList with only IPv4 addresses.
+ host_resolver_->rules()->AddIPLiteralRule("*", "1.1.1.1", std::string());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv =
+ handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.is_initialized());
+ EXPECT_FALSE(handle.socket());
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_TRUE(handle.is_initialized());
+ EXPECT_TRUE(handle.socket());
+ IPEndPoint endpoint;
+ handle.socket()->GetLocalAddress(&endpoint);
+ EXPECT_EQ(kIPv4AddressSize, endpoint.address().size());
+ EXPECT_EQ(1, client_socket_factory_.allocation_count());
+}
+
+// If all IPv6 addresses fail to connect synchronously, then IPv4 connections
+// proceeed immediately.
+TEST_F(WebSocketTransportClientSocketPoolTest, IPv6InstantFail) {
+ WebSocketTransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+
+ MockTransportClientSocketFactory::ClientSocketType case_types[] = {
+ // First IPv6 socket.
+ MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET,
+ // Second IPv6 socket.
+ MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET,
+ // This is the IPv4 socket.
+ MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET};
+
+ client_socket_factory_.set_client_socket_types(case_types,
+ arraysize(case_types));
+
+ // Resolve an AddressList with two IPv6 addresses and then an IPv4 address.
+ host_resolver_->rules()->AddIPLiteralRule(
+ "*", "2:abcd::3:4:ff,2:abcd::3:5:ff,2.2.2.2", std::string());
+ host_resolver_->set_synchronous_mode(true);
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv =
+ handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog());
+ EXPECT_EQ(OK, rv);
+ ASSERT_TRUE(handle.socket());
+
+ IPEndPoint endpoint;
+ handle.socket()->GetPeerAddress(&endpoint);
+ EXPECT_EQ("2.2.2.2", endpoint.ToStringWithoutPort());
+}
+
+// If all IPv6 addresses fail before the IPv4 fallback timeout, then the IPv4
+// connections proceed immediately.
+TEST_F(WebSocketTransportClientSocketPoolTest, IPv6RapidFail) {
+ WebSocketTransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+
+ MockTransportClientSocketFactory::ClientSocketType case_types[] = {
+ // First IPv6 socket.
+ MockTransportClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET,
+ // Second IPv6 socket.
+ MockTransportClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET,
+ // This is the IPv4 socket.
+ MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET};
+
+ client_socket_factory_.set_client_socket_types(case_types,
+ arraysize(case_types));
+
+ // Resolve an AddressList with two IPv6 addresses and then an IPv4 address.
+ host_resolver_->rules()->AddIPLiteralRule(
+ "*", "2:abcd::3:4:ff,2:abcd::3:5:ff,2.2.2.2", std::string());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv =
+ handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ EXPECT_FALSE(handle.socket());
+
+ base::Time start(base::Time::NowFromSystemTime());
+ EXPECT_EQ(OK, callback.WaitForResult());
+ EXPECT_LT(base::Time::NowFromSystemTime() - start,
+ base::TimeDelta::FromMilliseconds(
+ TransportConnectJobHelper::kIPv6FallbackTimerInMs));
+ ASSERT_TRUE(handle.socket());
+
+ IPEndPoint endpoint;
+ handle.socket()->GetPeerAddress(&endpoint);
+ EXPECT_EQ("2.2.2.2", endpoint.ToStringWithoutPort());
+}
+
+// If two sockets connect successfully, the one which connected first wins (this
+// can only happen if the sockets are different types, since sockets of the same
+// type do not race).
+TEST_F(WebSocketTransportClientSocketPoolTest, FirstSuccessWins) {
+ WebSocketTransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_TRIGGERABLE_CLIENT_SOCKET);
+
+ // Resolve an AddressList with an IPv6 addresses and an IPv4 address.
+ host_resolver_->rules()->AddIPLiteralRule(
+ "*", "2:abcd::3:4:ff,2.2.2.2", std::string());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ int rv =
+ handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+ ASSERT_FALSE(handle.socket());
+
+ base::Closure ipv6_connect_trigger =
+ client_socket_factory_.WaitForTriggerableSocketCreation();
+ base::Closure ipv4_connect_trigger =
+ client_socket_factory_.WaitForTriggerableSocketCreation();
+
+ ipv4_connect_trigger.Run();
+ ipv6_connect_trigger.Run();
+
+ EXPECT_EQ(OK, callback.WaitForResult());
+ ASSERT_TRUE(handle.socket());
+
+ IPEndPoint endpoint;
+ handle.socket()->GetPeerAddress(&endpoint);
+ EXPECT_EQ("2.2.2.2", endpoint.ToStringWithoutPort());
+}
+
+// We should not report failure until all connections have failed.
+TEST_F(WebSocketTransportClientSocketPoolTest, LastFailureWins) {
+ WebSocketTransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_DELAYED_FAILING_CLIENT_SOCKET);
+ base::TimeDelta delay = base::TimeDelta::FromMilliseconds(
+ TransportConnectJobHelper::kIPv6FallbackTimerInMs / 3);
+ client_socket_factory_.set_delay(delay);
+
+ // Resolve an AddressList with 4 IPv6 addresses and 2 IPv4 addresses.
+ host_resolver_->rules()->AddIPLiteralRule("*",
+ "1:abcd::3:4:ff,2:abcd::3:4:ff,"
+ "3:abcd::3:4:ff,4:abcd::3:4:ff,"
+ "1.1.1.1,2.2.2.2",
+ std::string());
+
+ // Expected order of events:
+ // After 100ms: Connect to 1:abcd::3:4:ff times out
+ // After 200ms: Connect to 2:abcd::3:4:ff times out
+ // After 300ms: Connect to 3:abcd::3:4:ff times out, IPv4 fallback starts
+ // After 400ms: Connect to 4:abcd::3:4:ff and 1.1.1.1 time out
+ // After 500ms: Connect to 2.2.2.2 times out
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+ base::Time start(base::Time::NowFromSystemTime());
+ int rv =
+ handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult());
+
+ EXPECT_GE(base::Time::NowFromSystemTime() - start, delay * 5);
+}
+
+// Global timeout for all connects applies. This test is disabled by default
+// because it takes 4 minutes. Run with --gtest_also_run_disabled_tests if you
+// want to run it.
+TEST_F(WebSocketTransportClientSocketPoolTest, DISABLED_OverallTimeoutApplies) {
+ WebSocketTransportClientSocketPool pool(kMaxSockets,
+ kMaxSocketsPerGroup,
+ histograms_.get(),
+ host_resolver_.get(),
+ &client_socket_factory_,
+ NULL);
+ const base::TimeDelta connect_job_timeout = pool.ConnectionTimeout();
+
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_DELAYED_FAILING_CLIENT_SOCKET);
+ client_socket_factory_.set_delay(base::TimeDelta::FromSeconds(1) +
+ connect_job_timeout / 6);
+
+ // Resolve an AddressList with 6 IPv6 addresses and 6 IPv4 addresses.
+ host_resolver_->rules()->AddIPLiteralRule("*",
+ "1:abcd::3:4:ff,2:abcd::3:4:ff,"
+ "3:abcd::3:4:ff,4:abcd::3:4:ff,"
+ "5:abcd::3:4:ff,6:abcd::3:4:ff,"
+ "1.1.1.1,2.2.2.2,3.3.3.3,"
+ "4.4.4.4,5.5.5.5,6.6.6.6",
+ std::string());
+
+ TestCompletionCallback callback;
+ ClientSocketHandle handle;
+
+ int rv =
+ handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog());
+ EXPECT_EQ(ERR_IO_PENDING, rv);
+
+ EXPECT_EQ(ERR_TIMED_OUT, callback.WaitForResult());
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest, MaxSocketsEnforced) {
+ host_resolver_->set_synchronous_mode(true);
+ for (int i = 0; i < kMaxSockets; ++i) {
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ WebSocketTransportClientSocketPool::UnlockEndpoint(request(i)->handle());
+ }
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest, MaxSocketsEnforcedWhenPending) {
+ for (int i = 0; i < kMaxSockets + 1; ++i) {
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ }
+ // Now there are 32 sockets waiting to connect, and one stalled.
+ for (int i = 0; i < kMaxSockets; ++i) {
+ base::RunLoop().RunUntilIdle();
+ EXPECT_TRUE(request(i)->handle()->is_initialized());
+ EXPECT_TRUE(request(i)->handle()->socket());
+ WebSocketTransportClientSocketPool::UnlockEndpoint(request(i)->handle());
+ }
+ // Now there are 32 sockets connected, and one stalled.
+ base::RunLoop().RunUntilIdle();
+ EXPECT_FALSE(request(kMaxSockets)->handle()->is_initialized());
+ EXPECT_FALSE(request(kMaxSockets)->handle()->socket());
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest, StalledSocketReleased) {
+ host_resolver_->set_synchronous_mode(true);
+ for (int i = 0; i < kMaxSockets; ++i) {
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ WebSocketTransportClientSocketPool::UnlockEndpoint(request(i)->handle());
+ }
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ ReleaseOneConnection(ClientSocketPoolTest::NO_KEEP_ALIVE);
+ EXPECT_TRUE(request(kMaxSockets)->handle()->is_initialized());
+ EXPECT_TRUE(request(kMaxSockets)->handle()->socket());
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest, IsStalledTrueWhenStalled) {
+ for (int i = 0; i < kMaxSockets + 1; ++i) {
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ }
+ EXPECT_EQ(OK, request(0)->WaitForResult());
+ EXPECT_TRUE(pool_.IsStalled());
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ CancellingPendingSocketUnstallsStalledSocket) {
+ for (int i = 0; i < kMaxSockets + 1; ++i) {
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ }
+ EXPECT_EQ(OK, request(0)->WaitForResult());
+ request(1)->handle()->Reset();
+ base::RunLoop().RunUntilIdle();
+ EXPECT_FALSE(pool_.IsStalled());
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ LoadStateOfStalledSocketIsWaitingForAvailableSocket) {
+ for (int i = 0; i < kMaxSockets + 1; ++i) {
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ }
+ EXPECT_EQ(LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET,
+ pool_.GetLoadState("a", request(kMaxSockets)->handle()));
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ CancellingStalledSocketUnstallsPool) {
+ for (int i = 0; i < kMaxSockets + 1; ++i) {
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ }
+ request(kMaxSockets)->handle()->Reset();
+ EXPECT_FALSE(pool_.IsStalled());
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ FlushWithErrorFlushesPendingConnections) {
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ pool_.FlushWithError(ERR_FAILED);
+ EXPECT_EQ(ERR_FAILED, request(0)->WaitForResult());
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ FlushWithErrorFlushesStalledConnections) {
+ for (int i = 0; i < kMaxSockets + 1; ++i) {
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ }
+ pool_.FlushWithError(ERR_FAILED);
+ EXPECT_EQ(ERR_FAILED, request(kMaxSockets)->WaitForResult());
+}
+
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ AfterFlushWithErrorCanMakeNewConnections) {
+ for (int i = 0; i < kMaxSockets + 1; ++i) {
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ }
+ pool_.FlushWithError(ERR_FAILED);
+ host_resolver_->set_synchronous_mode(true);
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+}
+
+// Deleting pending connections can release the lock on the endpoint, which can
+// in principle lead to other pending connections succeeding. However, when we
+// call FlushWithError(), everything should fail.
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ FlushWithErrorDoesNotCauseSuccessfulConnections) {
+ host_resolver_->set_synchronous_mode(true);
+ MockTransportClientSocketFactory::ClientSocketType first_type[] = {
+ // First socket
+ MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET
+ };
+ client_socket_factory_.set_client_socket_types(first_type,
+ arraysize(first_type));
+ // The rest of the sockets will connect synchronously.
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET);
+ for (int i = 0; i < kMaxSockets; ++i) {
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ }
+ // Now we have one socket in STATE_TRANSPORT_CONNECT and the rest in
+ // STATE_OBTAIN_LOCK. If any of the sockets in STATE_OBTAIN_LOCK is given the
+ // lock, they will synchronously connect.
+ pool_.FlushWithError(ERR_FAILED);
+ for (int i = 0; i < kMaxSockets; ++i) {
+ EXPECT_EQ(ERR_FAILED, request(i)->WaitForResult());
+ }
+}
+
+// This is a regression test for the first attempted fix for
+// FlushWithErrorDoesNotCauseSuccessfulConnections. Because a ConnectJob can
+// have both IPv4 and IPv6 subjobs, it can be both connecting and waiting for
+// the lock at the same time.
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ FlushWithErrorDoesNotCauseSuccessfulConnectionsMultipleAddressTypes) {
+ host_resolver_->set_synchronous_mode(true);
+ // The first |kMaxSockets| sockets to connect will be IPv6. Then we will have
+ // one IPv4.
+ std::vector<MockTransportClientSocketFactory::ClientSocketType> socket_types(
+ kMaxSockets + 1,
+ MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET);
+ client_socket_factory_.set_client_socket_types(&socket_types[0],
+ socket_types.size());
+ // The rest of the sockets will connect synchronously.
+ client_socket_factory_.set_default_client_socket_type(
+ MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET);
+ for (int i = 0; i < kMaxSockets; ++i) {
+ host_resolver_->rules()->ClearRules();
+ // Each connect job has a different IPv6 address but the same IPv4 address.
+ // So the IPv6 connections happen in parallel but the IPv4 ones are
+ // serialised.
+ host_resolver_->rules()->AddIPLiteralRule("*",
+ base::StringPrintf(
+ "%x:abcd::3:4:ff,"
+ "1.1.1.1",
+ i + 1),
+ std::string());
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ }
+ // Now we have |kMaxSockets| IPv6 sockets stalled in connect. No IPv4 sockets
+ // are started yet.
+ RunLoopForTimePeriod(base::TimeDelta::FromMilliseconds(
+ TransportConnectJobHelper::kIPv6FallbackTimerInMs));
+ // Now we have |kMaxSockets| IPv6 sockets and one IPv4 socket stalled in
+ // connect, and |kMaxSockets - 1| IPv4 sockets waiting for the endpoint lock.
+ pool_.FlushWithError(ERR_FAILED);
+ for (int i = 0; i < kMaxSockets; ++i) {
+ EXPECT_EQ(ERR_FAILED, request(i)->WaitForResult());
+ }
+}
+
+// Sockets that have had ownership transferred to a ClientSocketHandle should
+// not be affected by FlushWithError.
+TEST_F(WebSocketTransportClientSocketPoolTest,
+ FlushWithErrorDoesNotAffectHandedOutSockets) {
+ host_resolver_->set_synchronous_mode(true);
+ MockTransportClientSocketFactory::ClientSocketType socket_types[] = {
+ MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET,
+ MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET};
+ client_socket_factory_.set_client_socket_types(socket_types,
+ arraysize(socket_types));
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ // Socket has been "handed out".
+ EXPECT_TRUE(request(0)->handle()->socket());
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ // Now we have one socket handed out, and one pending.
+ pool_.FlushWithError(ERR_FAILED);
+ EXPECT_EQ(ERR_FAILED, request(1)->WaitForResult());
+ // Socket owned by ClientSocketHandle is unaffected:
+ EXPECT_TRUE(request(0)->handle()->socket());
+ // Return it to the pool (which deletes it).
+ request(0)->handle()->Reset();
+}
+
+// Sockets should not be leaked if CancelRequest() is called in between
+// SetSocket() being called on the ClientSocketHandle and InvokeUserCallback().
+TEST_F(WebSocketTransportClientSocketPoolTest, CancelRequestReclaimsSockets) {
+ host_resolver_->set_synchronous_mode(true);
+ MockTransportClientSocketFactory::ClientSocketType socket_types[] = {
+ MockTransportClientSocketFactory::MOCK_TRIGGERABLE_CLIENT_SOCKET,
+ MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET};
+
+ client_socket_factory_.set_client_socket_types(socket_types,
+ arraysize(socket_types));
+
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+
+ base::Closure connect_trigger =
+ client_socket_factory_.WaitForTriggerableSocketCreation();
+
+ connect_trigger.Run(); // Calls InvokeUserCallbackLater()
+
+ request(0)->handle()->Reset(); // calls CancelRequest()
+
+ // We should now be able to create a new connection without blocking on the
+ // endpoint lock.
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+}
+
+// A handshake completing and then the WebSocket closing should only release one
+// Endpoint, not two.
+TEST_F(WebSocketTransportClientSocketPoolTest, EndpointLockIsOnlyReleasedOnce) {
+ host_resolver_->set_synchronous_mode(true);
+ EXPECT_EQ(OK, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority));
+ // First socket completes handshake.
+ WebSocketTransportClientSocketPool::UnlockEndpoint(request(0)->handle());
+ // First socket is closed.
+ request(0)->handle()->Reset();
+ // Second socket should have been released.
+ EXPECT_EQ(OK, request(1)->WaitForResult());
+ // Third socket should still be waiting for endpoint.
+ ASSERT_FALSE(request(2)->handle()->is_initialized());
+ EXPECT_EQ(LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET,
+ request(2)->handle()->GetLoadState());
+}
+
+} // namespace
+
+} // namespace net
diff --git a/chromium/net/socket/websocket_transport_connect_sub_job.cc b/chromium/net/socket/websocket_transport_connect_sub_job.cc
new file mode 100644
index 00000000000..fbe8bbcc82c
--- /dev/null
+++ b/chromium/net/socket/websocket_transport_connect_sub_job.cc
@@ -0,0 +1,170 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/socket/websocket_transport_connect_sub_job.h"
+
+#include "base/logging.h"
+#include "net/base/ip_endpoint.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/socket/client_socket_factory.h"
+#include "net/socket/websocket_endpoint_lock_manager.h"
+
+namespace net {
+
+WebSocketTransportConnectSubJob::WebSocketTransportConnectSubJob(
+ const AddressList& addresses,
+ WebSocketTransportConnectJob* parent_job,
+ SubJobType type)
+ : parent_job_(parent_job),
+ addresses_(addresses),
+ current_address_index_(0),
+ next_state_(STATE_NONE),
+ type_(type) {}
+
+WebSocketTransportConnectSubJob::~WebSocketTransportConnectSubJob() {
+ // We don't worry about cancelling the TCP connect, since ~StreamSocket will
+ // take care of it.
+ if (next()) {
+ DCHECK_EQ(STATE_OBTAIN_LOCK_COMPLETE, next_state_);
+ // The ~Waiter destructor will remove this object from the waiting list.
+ } else if (next_state_ == STATE_TRANSPORT_CONNECT_COMPLETE) {
+ WebSocketEndpointLockManager::GetInstance()->UnlockEndpoint(
+ CurrentAddress());
+ }
+}
+
+// Start connecting.
+int WebSocketTransportConnectSubJob::Start() {
+ DCHECK_EQ(STATE_NONE, next_state_);
+ next_state_ = STATE_OBTAIN_LOCK;
+ return DoLoop(OK);
+}
+
+// Called by WebSocketEndpointLockManager when the lock becomes available.
+void WebSocketTransportConnectSubJob::GotEndpointLock() {
+ DCHECK_EQ(STATE_OBTAIN_LOCK_COMPLETE, next_state_);
+ OnIOComplete(OK);
+}
+
+LoadState WebSocketTransportConnectSubJob::GetLoadState() const {
+ switch (next_state_) {
+ case STATE_OBTAIN_LOCK:
+ case STATE_OBTAIN_LOCK_COMPLETE:
+ // TODO(ricea): Add a WebSocket-specific LOAD_STATE ?
+ return LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET;
+ case STATE_TRANSPORT_CONNECT:
+ case STATE_TRANSPORT_CONNECT_COMPLETE:
+ case STATE_DONE:
+ return LOAD_STATE_CONNECTING;
+ case STATE_NONE:
+ return LOAD_STATE_IDLE;
+ }
+ NOTREACHED();
+ return LOAD_STATE_IDLE;
+}
+
+ClientSocketFactory* WebSocketTransportConnectSubJob::client_socket_factory()
+ const {
+ return parent_job_->helper_.client_socket_factory();
+}
+
+const BoundNetLog& WebSocketTransportConnectSubJob::net_log() const {
+ return parent_job_->net_log();
+}
+
+const IPEndPoint& WebSocketTransportConnectSubJob::CurrentAddress() const {
+ DCHECK_LT(current_address_index_, addresses_.size());
+ return addresses_[current_address_index_];
+}
+
+void WebSocketTransportConnectSubJob::OnIOComplete(int result) {
+ int rv = DoLoop(result);
+ if (rv != ERR_IO_PENDING)
+ parent_job_->OnSubJobComplete(rv, this); // |this| deleted
+}
+
+int WebSocketTransportConnectSubJob::DoLoop(int result) {
+ DCHECK_NE(next_state_, STATE_NONE);
+
+ int rv = result;
+ do {
+ State state = next_state_;
+ next_state_ = STATE_NONE;
+ switch (state) {
+ case STATE_OBTAIN_LOCK:
+ DCHECK_EQ(OK, rv);
+ rv = DoEndpointLock();
+ break;
+ case STATE_OBTAIN_LOCK_COMPLETE:
+ DCHECK_EQ(OK, rv);
+ rv = DoEndpointLockComplete();
+ break;
+ case STATE_TRANSPORT_CONNECT:
+ DCHECK_EQ(OK, rv);
+ rv = DoTransportConnect();
+ break;
+ case STATE_TRANSPORT_CONNECT_COMPLETE:
+ rv = DoTransportConnectComplete(rv);
+ break;
+ default:
+ NOTREACHED();
+ rv = ERR_FAILED;
+ break;
+ }
+ } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE &&
+ next_state_ != STATE_DONE);
+
+ return rv;
+}
+
+int WebSocketTransportConnectSubJob::DoEndpointLock() {
+ int rv = WebSocketEndpointLockManager::GetInstance()->LockEndpoint(
+ CurrentAddress(), this);
+ next_state_ = STATE_OBTAIN_LOCK_COMPLETE;
+ return rv;
+}
+
+int WebSocketTransportConnectSubJob::DoEndpointLockComplete() {
+ next_state_ = STATE_TRANSPORT_CONNECT;
+ return OK;
+}
+
+int WebSocketTransportConnectSubJob::DoTransportConnect() {
+ // TODO(ricea): Update global g_last_connect_time and report
+ // ConnectInterval.
+ next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
+ AddressList one_address(CurrentAddress());
+ transport_socket_ = client_socket_factory()->CreateTransportClientSocket(
+ one_address, net_log().net_log(), net_log().source());
+ // This use of base::Unretained() is safe because transport_socket_ is
+ // destroyed in the destructor.
+ return transport_socket_->Connect(base::Bind(
+ &WebSocketTransportConnectSubJob::OnIOComplete, base::Unretained(this)));
+}
+
+int WebSocketTransportConnectSubJob::DoTransportConnectComplete(int result) {
+ next_state_ = STATE_DONE;
+ WebSocketEndpointLockManager* endpoint_lock_manager =
+ WebSocketEndpointLockManager::GetInstance();
+ if (result != OK) {
+ endpoint_lock_manager->UnlockEndpoint(CurrentAddress());
+
+ if (current_address_index_ + 1 < addresses_.size()) {
+ // Try falling back to the next address in the list.
+ next_state_ = STATE_OBTAIN_LOCK;
+ ++current_address_index_;
+ result = OK;
+ }
+
+ return result;
+ }
+
+ endpoint_lock_manager->RememberSocket(transport_socket_.get(),
+ CurrentAddress());
+
+ return result;
+}
+
+} // namespace net
diff --git a/chromium/net/socket/websocket_transport_connect_sub_job.h b/chromium/net/socket/websocket_transport_connect_sub_job.h
new file mode 100644
index 00000000000..5709a461caf
--- /dev/null
+++ b/chromium/net/socket/websocket_transport_connect_sub_job.h
@@ -0,0 +1,90 @@
+// Copyright 2014 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef NET_SOCKET_WEBSOCKET_TRANSPORT_CONNECT_SUB_JOB_H_
+#define NET_SOCKET_WEBSOCKET_TRANSPORT_CONNECT_SUB_JOB_H_
+
+#include "base/compiler_specific.h"
+#include "base/macros.h"
+#include "base/memory/scoped_ptr.h"
+#include "net/base/address_list.h"
+#include "net/base/load_states.h"
+#include "net/socket/websocket_endpoint_lock_manager.h"
+#include "net/socket/websocket_transport_client_socket_pool.h"
+
+namespace net {
+
+class BoundNetLog;
+class ClientSocketFactory;
+class IPEndPoint;
+class StreamSocket;
+
+// Attempts to connect to a subset of the addresses required by a
+// WebSocketTransportConnectJob, specifically either the IPv4 or IPv6
+// addresses. Each address is tried in turn, and parent_job->OnSubJobComplete()
+// is called when the first address succeeds or the last address fails.
+class WebSocketTransportConnectSubJob
+ : public WebSocketEndpointLockManager::Waiter {
+ public:
+ typedef WebSocketTransportConnectJob::SubJobType SubJobType;
+
+ WebSocketTransportConnectSubJob(const AddressList& addresses,
+ WebSocketTransportConnectJob* parent_job,
+ SubJobType type);
+
+ ~WebSocketTransportConnectSubJob() override;
+
+ // Start connecting.
+ int Start();
+
+ bool started() { return next_state_ != STATE_NONE; }
+
+ LoadState GetLoadState() const;
+
+ SubJobType type() const { return type_; }
+
+ scoped_ptr<StreamSocket> PassSocket() { return transport_socket_.Pass(); }
+
+ // Implementation of WebSocketEndpointLockManager::EndpointWaiter.
+ void GotEndpointLock() override;
+
+ private:
+ enum State {
+ STATE_NONE,
+ STATE_OBTAIN_LOCK,
+ STATE_OBTAIN_LOCK_COMPLETE,
+ STATE_TRANSPORT_CONNECT,
+ STATE_TRANSPORT_CONNECT_COMPLETE,
+ STATE_DONE,
+ };
+
+ ClientSocketFactory* client_socket_factory() const;
+
+ const BoundNetLog& net_log() const;
+
+ const IPEndPoint& CurrentAddress() const;
+
+ void OnIOComplete(int result);
+ int DoLoop(int result);
+ int DoEndpointLock();
+ int DoEndpointLockComplete();
+ int DoTransportConnect();
+ int DoTransportConnectComplete(int result);
+
+ WebSocketTransportConnectJob* const parent_job_;
+
+ const AddressList addresses_;
+ size_t current_address_index_;
+
+ State next_state_;
+ const SubJobType type_;
+
+ scoped_ptr<StreamSocket> transport_socket_;
+
+ DISALLOW_COPY_AND_ASSIGN(WebSocketTransportConnectSubJob);
+};
+
+} // namespace net
+
+#endif // NET_SOCKET_WEBSOCKET_TRANSPORT_CONNECT_SUB_JOB_H_