diff options
author | Zeno Albisser <zeno.albisser@theqtcompany.com> | 2014-12-05 15:04:29 +0100 |
---|---|---|
committer | Andras Becsi <andras.becsi@theqtcompany.com> | 2014-12-09 10:49:28 +0100 |
commit | af6588f8d723931a298c995fa97259bb7f7deb55 (patch) | |
tree | 060ca707847ba1735f01af2372e0d5e494dc0366 /chromium/net/socket | |
parent | 2fff84d821cc7b1c785f6404e0f8091333283e74 (diff) | |
download | qtwebengine-chromium-af6588f8d723931a298c995fa97259bb7f7deb55.tar.gz |
BASELINE: Update chromium to 40.0.2214.28 and ninja to 1.5.3.
Change-Id: I759465284fd64d59ad120219cbe257f7402c4181
Reviewed-by: Andras Becsi <andras.becsi@theqtcompany.com>
Diffstat (limited to 'chromium/net/socket')
93 files changed, 10318 insertions, 3962 deletions
diff --git a/chromium/net/socket/client_socket_factory.cc b/chromium/net/socket/client_socket_factory.cc index 953914581fc..51aea715f4d 100644 --- a/chromium/net/socket/client_socket_factory.cc +++ b/chromium/net/socket/client_socket_factory.cc @@ -50,45 +50,45 @@ class DefaultClientSocketFactory : public ClientSocketFactory, CertDatabase::GetInstance()->AddObserver(this); } - virtual ~DefaultClientSocketFactory() { + ~DefaultClientSocketFactory() override { // Note: This code never runs, as the factory is defined as a Leaky // singleton. CertDatabase::GetInstance()->RemoveObserver(this); } - virtual void OnCertAdded(const X509Certificate* cert) OVERRIDE { + void OnCertAdded(const X509Certificate* cert) override { ClearSSLSessionCache(); } - virtual void OnCACertChanged(const X509Certificate* cert) OVERRIDE { + void OnCACertChanged(const X509Certificate* cert) override { // Per wtc, we actually only need to flush when trust is reduced. // Always flush now because OnCACertChanged does not tell us this. // See comments in ClientSocketPoolManager::OnCACertChanged. ClearSSLSessionCache(); } - virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( + scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, - const NetLog::Source& source) OVERRIDE { + const NetLog::Source& source) override { return scoped_ptr<DatagramClientSocket>( new UDPClientSocket(bind_type, rand_int_cb, net_log, source)); } - virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( + scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* net_log, - const NetLog::Source& source) OVERRIDE { + const NetLog::Source& source) override { return scoped_ptr<StreamSocket>( new TCPClientSocket(addresses, net_log, source)); } - virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<SSLClientSocket> CreateSSLClientSocket( scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, - const SSLClientSocketContext& context) OVERRIDE { + const SSLClientSocketContext& context) override { // nss_thread_task_runner_ may be NULL if g_use_dedicated_nss_thread is // false or if the dedicated NSS thread failed to start. If so, cause NSS // functions to execute on the current task runner. @@ -120,9 +120,7 @@ class DefaultClientSocketFactory : public ClientSocketFactory, #endif } - virtual void ClearSSLSessionCache() OVERRIDE { - SSLClientSocket::ClearSessionCache(); - } + void ClearSSLSessionCache() override { SSLClientSocket::ClearSessionCache(); } private: scoped_refptr<base::SequencedWorkerPool> worker_pool_; diff --git a/chromium/net/socket/client_socket_pool.cc b/chromium/net/socket/client_socket_pool.cc index 0eebd11b9fb..261e87fc7fe 100644 --- a/chromium/net/socket/client_socket_pool.cc +++ b/chromium/net/socket/client_socket_pool.cc @@ -12,10 +12,10 @@ namespace { // alive. // TODO(ziadh): Change this timeout after getting histogram data on how long it // should be. -int g_unused_idle_socket_timeout_s = 10; +int64 g_unused_idle_socket_timeout_s = 10; // The maximum duration, in seconds, to keep used idle persistent sockets alive. -int g_used_idle_socket_timeout_s = 300; // 5 minutes +int64 g_used_idle_socket_timeout_s = 300; // 5 minutes } // namespace diff --git a/chromium/net/socket/client_socket_pool.h b/chromium/net/socket/client_socket_pool.h index 715cddb94d4..2a2be36c8cd 100644 --- a/chromium/net/socket/client_socket_pool.h +++ b/chromium/net/socket/client_socket_pool.h @@ -182,7 +182,7 @@ class NET_EXPORT ClientSocketPool : public LowerLayeredPool { protected: ClientSocketPool(); - virtual ~ClientSocketPool(); + ~ClientSocketPool() override; // Return the connection timeout for this pool. virtual base::TimeDelta ConnectionTimeout() const = 0; diff --git a/chromium/net/socket/client_socket_pool_base.h b/chromium/net/socket/client_socket_pool_base.h index 8079cd4e5c7..ec4e33cc0ed 100644 --- a/chromium/net/socket/client_socket_pool_base.h +++ b/chromium/net/socket/client_socket_pool_base.h @@ -219,7 +219,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper base::TimeDelta used_idle_socket_timeout, ConnectJobFactory* connect_job_factory); - virtual ~ClientSocketPoolBaseHelper(); + ~ClientSocketPoolBaseHelper() override; // Adds a lower layered pool to |this|, and adds |this| as a higher layered // pool on top of |lower_pool|. @@ -327,10 +327,10 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper void EnableConnectBackupJobs(); // ConnectJob::Delegate methods: - virtual void OnConnectJobComplete(int result, ConnectJob* job) OVERRIDE; + void OnConnectJobComplete(int result, ConnectJob* job) override; // NetworkChangeNotifier::IPAddressObserver methods: - virtual void OnIPAddressChanged() OVERRIDE; + void OnIPAddressChanged() override; private: friend class base::RefCounted<ClientSocketPoolBaseHelper>; @@ -741,10 +741,7 @@ class ClientSocketPoolBase { internal::ClientSocketPoolBaseHelper::NORMAL, params->ignore_limits(), params, net_log)); - return helper_.RequestSocket( - group_name, - request.template PassAs< - const internal::ClientSocketPoolBaseHelper::Request>()); + return helper_.RequestSocket(group_name, request.Pass()); } // RequestSockets bundles up the parameters into a Request and then forwards @@ -856,7 +853,7 @@ class ClientSocketPoolBase { virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const internal::ClientSocketPoolBaseHelper::Request& request, - ConnectJob::Delegate* delegate) const OVERRIDE { + ConnectJob::Delegate* delegate) const override { const Request& casted_request = static_cast<const Request&>(request); return connect_job_factory_->NewConnectJob( group_name, casted_request, delegate); diff --git a/chromium/net/socket/client_socket_pool_base_unittest.cc b/chromium/net/socket/client_socket_pool_base_unittest.cc index 5a672b44a4e..c4a28459a1e 100644 --- a/chromium/net/socket/client_socket_pool_base_unittest.cc +++ b/chromium/net/socket/client_socket_pool_base_unittest.cc @@ -128,9 +128,9 @@ class MockClientSocket : public StreamSocket { } // Socket implementation. - virtual int Read( - IOBuffer* /* buf */, int len, - const CompletionCallback& /* callback */) OVERRIDE { + int Read(IOBuffer* /* buf */, + int len, + const CompletionCallback& /* callback */) override { if (has_unread_data_ && len > 0) { has_unread_data_ = false; was_used_to_convey_data_ = true; @@ -139,54 +139,44 @@ class MockClientSocket : public StreamSocket { return ERR_UNEXPECTED; } - virtual int Write( - IOBuffer* /* buf */, int len, - const CompletionCallback& /* callback */) OVERRIDE { + int Write(IOBuffer* /* buf */, + int len, + const CompletionCallback& /* callback */) override { was_used_to_convey_data_ = true; return len; } - virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; } - virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; } + int SetReceiveBufferSize(int32 size) override { return OK; } + int SetSendBufferSize(int32 size) override { return OK; } // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE { + int Connect(const CompletionCallback& callback) override { connected_ = true; return OK; } - virtual void Disconnect() OVERRIDE { connected_ = false; } - virtual bool IsConnected() const OVERRIDE { return connected_; } - virtual bool IsConnectedAndIdle() const OVERRIDE { + void Disconnect() override { connected_ = false; } + bool IsConnected() const override { return connected_; } + bool IsConnectedAndIdle() const override { return connected_ && !has_unread_data_; } - virtual int GetPeerAddress(IPEndPoint* /* address */) const OVERRIDE { + int GetPeerAddress(IPEndPoint* /* address */) const override { return ERR_UNEXPECTED; } - virtual int GetLocalAddress(IPEndPoint* /* address */) const OVERRIDE { + int GetLocalAddress(IPEndPoint* /* address */) const override { return ERR_UNEXPECTED; } - virtual const BoundNetLog& NetLog() const OVERRIDE { - return net_log_; - } + const BoundNetLog& NetLog() const override { return net_log_; } - virtual void SetSubresourceSpeculation() OVERRIDE {} - virtual void SetOmniboxSpeculation() OVERRIDE {} - virtual bool WasEverUsed() const OVERRIDE { - return was_used_to_convey_data_; - } - virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } - virtual bool WasNpnNegotiated() const OVERRIDE { - return false; - } - virtual NextProto GetNegotiatedProtocol() const OVERRIDE { - return kProtoUnknown; - } - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { - return false; - } + void SetSubresourceSpeculation() override {} + void SetOmniboxSpeculation() override {} + bool WasEverUsed() const override { return was_used_to_convey_data_; } + bool UsingTCPFastOpen() const override { return false; } + bool WasNpnNegotiated() const override { return false; } + NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } + bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } private: bool connected_; @@ -203,35 +193,33 @@ class MockClientSocketFactory : public ClientSocketFactory { public: MockClientSocketFactory() : allocation_count_(0) {} - virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( + scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, - const NetLog::Source& source) OVERRIDE { + const NetLog::Source& source) override { NOTREACHED(); return scoped_ptr<DatagramClientSocket>(); } - virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( + scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* /* net_log */, - const NetLog::Source& /*source*/) OVERRIDE { + const NetLog::Source& /*source*/) override { allocation_count_++; return scoped_ptr<StreamSocket>(); } - virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<SSLClientSocket> CreateSSLClientSocket( scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, - const SSLClientSocketContext& context) OVERRIDE { + const SSLClientSocketContext& context) override { NOTIMPLEMENTED(); return scoped_ptr<SSLClientSocket>(); } - virtual void ClearSSLSessionCache() OVERRIDE { - NOTIMPLEMENTED(); - } + void ClearSSLSessionCache() override { NOTIMPLEMENTED(); } void WaitForSignal(TestConnectJob* job) { waiting_jobs_.push_back(job); } @@ -291,9 +279,9 @@ class TestConnectJob : public ConnectJob { // From ConnectJob: - virtual LoadState GetLoadState() const OVERRIDE { return load_state_; } + LoadState GetLoadState() const override { return load_state_; } - virtual void GetAdditionalErrorState(ClientSocketHandle* handle) OVERRIDE { + void GetAdditionalErrorState(ClientSocketHandle* handle) override { if (store_additional_error_state_) { // Set all of the additional error state fields in some way. handle->set_is_ssl_error(true); @@ -306,7 +294,7 @@ class TestConnectJob : public ConnectJob { private: // From ConnectJob: - virtual int ConnectInternal() OVERRIDE { + int ConnectInternal() override { AddressList ignored; client_socket_factory_->CreateTransportClientSocket( ignored, NULL, net::NetLog::Source()); @@ -439,7 +427,7 @@ class TestConnectJobFactory net_log_(net_log) { } - virtual ~TestConnectJobFactory() {} + ~TestConnectJobFactory() override {} void set_job_type(TestConnectJob::JobType job_type) { job_type_ = job_type; } @@ -454,10 +442,10 @@ class TestConnectJobFactory // ConnectJobFactory implementation. - virtual scoped_ptr<ConnectJob> NewConnectJob( + scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const TestClientSocketPoolBase::Request& request, - ConnectJob::Delegate* delegate) const OVERRIDE { + ConnectJob::Delegate* delegate) const override { EXPECT_TRUE(!job_types_ || !job_types_->empty()); TestConnectJob::JobType job_type = job_type_; if (job_types_ && !job_types_->empty()) { @@ -473,7 +461,7 @@ class TestConnectJobFactory net_log_)); } - virtual base::TimeDelta ConnectionTimeout() const OVERRIDE { + base::TimeDelta ConnectionTimeout() const override { return timeout_duration_; } @@ -502,92 +490,78 @@ class TestClientSocketPool : public ClientSocketPool { unused_idle_socket_timeout, used_idle_socket_timeout, connect_job_factory) {} - virtual ~TestClientSocketPool() {} + ~TestClientSocketPool() override {} - virtual int RequestSocket( - const std::string& group_name, - const void* params, - net::RequestPriority priority, - ClientSocketHandle* handle, - const CompletionCallback& callback, - const BoundNetLog& net_log) OVERRIDE { + int RequestSocket(const std::string& group_name, + const void* params, + net::RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) override { const scoped_refptr<TestSocketParams>* casted_socket_params = static_cast<const scoped_refptr<TestSocketParams>*>(params); return base_.RequestSocket(group_name, *casted_socket_params, priority, handle, callback, net_log); } - virtual void RequestSockets(const std::string& group_name, - const void* params, - int num_sockets, - const BoundNetLog& net_log) OVERRIDE { + void RequestSockets(const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) override { const scoped_refptr<TestSocketParams>* casted_params = static_cast<const scoped_refptr<TestSocketParams>*>(params); base_.RequestSockets(group_name, *casted_params, num_sockets, net_log); } - virtual void CancelRequest( - const std::string& group_name, - ClientSocketHandle* handle) OVERRIDE { + void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) override { base_.CancelRequest(group_name, handle); } - virtual void ReleaseSocket( - const std::string& group_name, - scoped_ptr<StreamSocket> socket, - int id) OVERRIDE { + void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) override { base_.ReleaseSocket(group_name, socket.Pass(), id); } - virtual void FlushWithError(int error) OVERRIDE { - base_.FlushWithError(error); - } + void FlushWithError(int error) override { base_.FlushWithError(error); } - virtual bool IsStalled() const OVERRIDE { - return base_.IsStalled(); - } + bool IsStalled() const override { return base_.IsStalled(); } - virtual void CloseIdleSockets() OVERRIDE { - base_.CloseIdleSockets(); - } + void CloseIdleSockets() override { base_.CloseIdleSockets(); } - virtual int IdleSocketCount() const OVERRIDE { - return base_.idle_socket_count(); - } + int IdleSocketCount() const override { return base_.idle_socket_count(); } - virtual int IdleSocketCountInGroup( - const std::string& group_name) const OVERRIDE { + int IdleSocketCountInGroup(const std::string& group_name) const override { return base_.IdleSocketCountInGroup(group_name); } - virtual LoadState GetLoadState( - const std::string& group_name, - const ClientSocketHandle* handle) const OVERRIDE { + LoadState GetLoadState(const std::string& group_name, + const ClientSocketHandle* handle) const override { return base_.GetLoadState(group_name, handle); } - virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE { + void AddHigherLayeredPool(HigherLayeredPool* higher_pool) override { base_.AddHigherLayeredPool(higher_pool); } - virtual void RemoveHigherLayeredPool( - HigherLayeredPool* higher_pool) OVERRIDE { + void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) override { base_.RemoveHigherLayeredPool(higher_pool); } - virtual base::DictionaryValue* GetInfoAsValue( + base::DictionaryValue* GetInfoAsValue( const std::string& name, const std::string& type, - bool include_nested_pools) const OVERRIDE { + bool include_nested_pools) const override { return base_.GetInfoAsValue(name, type); } - virtual base::TimeDelta ConnectionTimeout() const OVERRIDE { + base::TimeDelta ConnectionTimeout() const override { return base_.ConnectionTimeout(); } - virtual ClientSocketPoolHistograms* histograms() const OVERRIDE { + ClientSocketPoolHistograms* histograms() const override { return base_.histograms(); } @@ -651,9 +625,9 @@ class TestConnectJobDelegate : public ConnectJob::Delegate { public: TestConnectJobDelegate() : have_result_(false), waiting_for_result_(false), result_(OK) {} - virtual ~TestConnectJobDelegate() {} + ~TestConnectJobDelegate() override {} - virtual void OnConnectJobComplete(int result, ConnectJob* job) OVERRIDE { + void OnConnectJobComplete(int result, ConnectJob* job) override { result_ = result; scoped_ptr<ConnectJob> owned_job(job); scoped_ptr<StreamSocket> socket = owned_job->PassSocket(); @@ -693,7 +667,7 @@ class ClientSocketPoolBaseTest : public testing::Test { internal::ClientSocketPoolBaseHelper::cleanup_timer_enabled(); } - virtual ~ClientSocketPoolBaseTest() { + ~ClientSocketPoolBaseTest() override { internal::ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled( connect_backup_jobs_enabled_); internal::ClientSocketPoolBaseHelper::set_cleanup_timer_enabled( @@ -1473,7 +1447,7 @@ class RequestSocketCallback : public TestCompletionCallbackBase { base::Unretained(this))) { } - virtual ~RequestSocketCallback() {} + ~RequestSocketCallback() override {} const CompletionCallback& callback() const { return callback_; } @@ -2591,7 +2565,7 @@ class TestReleasingSocketRequest : public TestCompletionCallbackBase { base::Unretained(this))) { } - virtual ~TestReleasingSocketRequest() {} + ~TestReleasingSocketRequest() override {} ClientSocketHandle* handle() { return &handle_; } @@ -2716,7 +2690,7 @@ class ConnectWithinCallback : public TestCompletionCallbackBase { base::Unretained(this))) { } - virtual ~ConnectWithinCallback() {} + ~ConnectWithinCallback() override {} int WaitForNestedResult() { return nested_callback_.WaitForResult(); diff --git a/chromium/net/socket/client_socket_pool_manager.cc b/chromium/net/socket/client_socket_pool_manager.cc index f81b4bc1db5..b99612718e4 100644 --- a/chromium/net/socket/client_socket_pool_manager.cc +++ b/chromium/net/socket/client_socket_pool_manager.cc @@ -84,7 +84,6 @@ int InitSocketPoolHelper(const GURL& request_url, HttpNetworkSession::SocketPoolType socket_pool_type, const OnHostResolutionCallback& resolution_callback, const CompletionCallback& callback) { - scoped_refptr<TransportSocketParams> tcp_params; scoped_refptr<HttpProxySocketParams> http_proxy_params; scoped_refptr<SOCKSSocketParams> socks_params; scoped_ptr<HostPortPair> proxy_host_port; @@ -132,7 +131,7 @@ int InitSocketPoolHelper(const GURL& request_url, // should be the same for all connections, whereas version_max may // change for version fallbacks. std::string prefix = "ssl/"; - if (ssl_config_for_origin.version_max != net::kDefaultSSLVersionMax) { + if (ssl_config_for_origin.version_max != kDefaultSSLVersionMax) { switch (ssl_config_for_origin.version_max) { case SSL_PROTOCOL_VERSION_TLS1_2: prefix = "ssl(max:3.3)/"; @@ -155,19 +154,16 @@ int InitSocketPoolHelper(const GURL& request_url, } bool ignore_limits = (request_load_flags & LOAD_IGNORE_LIMITS) != 0; - if (proxy_info.is_direct()) { - tcp_params = new TransportSocketParams(origin_host_port, - disable_resolver_cache, - ignore_limits, - resolution_callback); - } else { + if (!proxy_info.is_direct()) { ProxyServer proxy_server = proxy_info.proxy_server(); proxy_host_port.reset(new HostPortPair(proxy_server.host_port_pair())); scoped_refptr<TransportSocketParams> proxy_tcp_params( - new TransportSocketParams(*proxy_host_port, - disable_resolver_cache, - ignore_limits, - resolution_callback)); + new TransportSocketParams( + *proxy_host_port, + disable_resolver_cache, + ignore_limits, + resolution_callback, + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); if (proxy_info.is_http() || proxy_info.is_https()) { std::string user_agent; @@ -175,6 +171,18 @@ int InitSocketPoolHelper(const GURL& request_url, &user_agent); scoped_refptr<SSLSocketParams> ssl_params; if (proxy_info.is_https()) { + // Combine connect and write for SSL sockets in TCP FastOpen + // field trial. + TransportSocketParams::CombineConnectAndWritePolicy + combine_connect_and_write = + session->params().enable_tcp_fast_open_for_ssl ? + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DESIRED : + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT; + proxy_tcp_params = new TransportSocketParams(*proxy_host_port, + disable_resolver_cache, + ignore_limits, + resolution_callback, + combine_connect_and_write); // Set ssl_params, and unset proxy_tcp_params ssl_params = new SSLSocketParams(proxy_tcp_params, NULL, @@ -197,7 +205,8 @@ int InitSocketPoolHelper(const GURL& request_url, session->http_auth_cache(), session->http_auth_handler_factory(), session->spdy_session_pool(), - force_tunnel || using_ssl); + force_tunnel || using_ssl, + session->params().proxy_delegate); } else { DCHECK(proxy_info.is_socks()); char socks_version; @@ -220,8 +229,23 @@ int InitSocketPoolHelper(const GURL& request_url, // Deal with SSL - which layers on top of any given proxy. if (using_ssl) { + scoped_refptr<TransportSocketParams> ssl_tcp_params; + if (proxy_info.is_direct()) { + // Setup TCP params if non-proxied SSL connection. + // Combine connect and write for SSL sockets in TCP FastOpen field trial. + TransportSocketParams::CombineConnectAndWritePolicy + combine_connect_and_write = + session->params().enable_tcp_fast_open_for_ssl ? + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DESIRED : + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT; + ssl_tcp_params = new TransportSocketParams(origin_host_port, + disable_resolver_cache, + ignore_limits, + resolution_callback, + combine_connect_and_write); + } scoped_refptr<SSLSocketParams> ssl_params = - new SSLSocketParams(tcp_params, + new SSLSocketParams(ssl_tcp_params, socks_params, http_proxy_params, origin_host_port, @@ -280,7 +304,13 @@ int InitSocketPoolHelper(const GURL& request_url, } DCHECK(proxy_info.is_direct()); - + scoped_refptr<TransportSocketParams> tcp_params = + new TransportSocketParams( + origin_host_port, + disable_resolver_cache, + ignore_limits, + resolution_callback, + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT); TransportClientSocketPool* pool = session->GetTransportSocketPool(socket_pool_type); if (num_preconnect_streams) { diff --git a/chromium/net/socket/client_socket_pool_manager_impl.cc b/chromium/net/socket/client_socket_pool_manager_impl.cc index 991278d7341..5ed31fc9a25 100644 --- a/chromium/net/socket/client_socket_pool_manager_impl.cc +++ b/chromium/net/socket/client_socket_pool_manager_impl.cc @@ -11,6 +11,7 @@ #include "net/socket/socks_client_socket_pool.h" #include "net/socket/ssl_client_socket_pool.h" #include "net/socket/transport_client_socket_pool.h" +#include "net/socket/websocket_transport_client_socket_pool.h" #include "net/ssl/ssl_config_service.h" namespace net { @@ -38,54 +39,68 @@ ClientSocketPoolManagerImpl::ClientSocketPoolManagerImpl( ClientSocketFactory* socket_factory, HostResolver* host_resolver, CertVerifier* cert_verifier, - ServerBoundCertService* server_bound_cert_service, + ChannelIDService* channel_id_service, TransportSecurityState* transport_security_state, CTVerifier* cert_transparency_verifier, const std::string& ssl_session_cache_shard, ProxyService* proxy_service, SSLConfigService* ssl_config_service, + bool enable_ssl_connect_job_waiting, + ProxyDelegate* proxy_delegate, HttpNetworkSession::SocketPoolType pool_type) : net_log_(net_log), socket_factory_(socket_factory), host_resolver_(host_resolver), cert_verifier_(cert_verifier), - server_bound_cert_service_(server_bound_cert_service), + channel_id_service_(channel_id_service), transport_security_state_(transport_security_state), cert_transparency_verifier_(cert_transparency_verifier), ssl_session_cache_shard_(ssl_session_cache_shard), proxy_service_(proxy_service), ssl_config_service_(ssl_config_service), + enable_ssl_connect_job_waiting_(enable_ssl_connect_job_waiting), pool_type_(pool_type), transport_pool_histograms_("TCP"), - transport_socket_pool_(new TransportClientSocketPool( - max_sockets_per_pool(pool_type), max_sockets_per_group(pool_type), - &transport_pool_histograms_, - host_resolver, - socket_factory_, - net_log)), + transport_socket_pool_( + pool_type == HttpNetworkSession::WEBSOCKET_SOCKET_POOL + ? new WebSocketTransportClientSocketPool( + max_sockets_per_pool(pool_type), + max_sockets_per_group(pool_type), + &transport_pool_histograms_, + host_resolver, + socket_factory_, + net_log) + : new TransportClientSocketPool(max_sockets_per_pool(pool_type), + max_sockets_per_group(pool_type), + &transport_pool_histograms_, + host_resolver, + socket_factory_, + net_log)), ssl_pool_histograms_("SSL2"), - ssl_socket_pool_(new SSLClientSocketPool( - max_sockets_per_pool(pool_type), max_sockets_per_group(pool_type), - &ssl_pool_histograms_, - host_resolver, - cert_verifier, - server_bound_cert_service, - transport_security_state, - cert_transparency_verifier, - ssl_session_cache_shard, - socket_factory, - transport_socket_pool_.get(), - NULL /* no socks proxy */, - NULL /* no http proxy */, - ssl_config_service, - net_log)), + ssl_socket_pool_(new SSLClientSocketPool(max_sockets_per_pool(pool_type), + max_sockets_per_group(pool_type), + &ssl_pool_histograms_, + host_resolver, + cert_verifier, + channel_id_service, + transport_security_state, + cert_transparency_verifier, + ssl_session_cache_shard, + socket_factory, + transport_socket_pool_.get(), + NULL /* no socks proxy */, + NULL /* no http proxy */, + ssl_config_service, + enable_ssl_connect_job_waiting, + net_log)), transport_for_socks_pool_histograms_("TCPforSOCKS"), socks_pool_histograms_("SOCK"), transport_for_http_proxy_pool_histograms_("TCPforHTTPProxy"), transport_for_https_proxy_pool_histograms_("TCPforHTTPSProxy"), ssl_for_https_proxy_pool_histograms_("SSLforHTTPSProxy"), http_proxy_pool_histograms_("HTTPProxy"), - ssl_socket_pool_for_proxies_histograms_("SSLForProxies") { + ssl_socket_pool_for_proxies_histograms_("SSLForProxies"), + proxy_delegate_(proxy_delegate) { CertDatabase::GetInstance()->AddObserver(this); } @@ -287,7 +302,7 @@ ClientSocketPoolManagerImpl::GetSocketPoolForHTTPProxy( &ssl_for_https_proxy_pool_histograms_, host_resolver_, cert_verifier_, - server_bound_cert_service_, + channel_id_service_, transport_security_state_, cert_transparency_verifier_, ssl_session_cache_shard_, @@ -296,6 +311,7 @@ ClientSocketPoolManagerImpl::GetSocketPoolForHTTPProxy( NULL /* no socks proxy */, NULL /* no http proxy */, ssl_config_service_.get(), + enable_ssl_connect_job_waiting_, net_log_))); DCHECK(tcp_https_ret.second); @@ -310,6 +326,7 @@ ClientSocketPoolManagerImpl::GetSocketPoolForHTTPProxy( host_resolver_, tcp_http_ret.first->second, ssl_https_ret.first->second, + proxy_delegate_, net_log_))); return ret.first->second; @@ -328,7 +345,7 @@ SSLClientSocketPool* ClientSocketPoolManagerImpl::GetSocketPoolForSSLWithProxy( &ssl_pool_histograms_, host_resolver_, cert_verifier_, - server_bound_cert_service_, + channel_id_service_, transport_security_state_, cert_transparency_verifier_, ssl_session_cache_shard_, @@ -337,6 +354,7 @@ SSLClientSocketPool* ClientSocketPoolManagerImpl::GetSocketPoolForSSLWithProxy( GetSocketPoolForSOCKSProxy(proxy_server), GetSocketPoolForHTTPProxy(proxy_server), ssl_config_service_.get(), + enable_ssl_connect_job_waiting_, net_log_); std::pair<SSLSocketPoolMap::iterator, bool> ret = diff --git a/chromium/net/socket/client_socket_pool_manager_impl.h b/chromium/net/socket/client_socket_pool_manager_impl.h index 06d4d244a36..f9f8d3b8f15 100644 --- a/chromium/net/socket/client_socket_pool_manager_impl.h +++ b/chromium/net/socket/client_socket_pool_manager_impl.h @@ -21,13 +21,14 @@ namespace net { class CertVerifier; +class ChannelIDService; class ClientSocketFactory; class ClientSocketPoolHistograms; class CTVerifier; class HttpProxyClientSocketPool; class HostResolver; class NetLog; -class ServerBoundCertService; +class ProxyDelegate; class ProxyService; class SOCKSClientSocketPool; class SSLClientSocketPool; @@ -61,38 +62,40 @@ class ClientSocketPoolManagerImpl : public base::NonThreadSafe, ClientSocketFactory* socket_factory, HostResolver* host_resolver, CertVerifier* cert_verifier, - ServerBoundCertService* server_bound_cert_service, + ChannelIDService* channel_id_service, TransportSecurityState* transport_security_state, CTVerifier* cert_transparency_verifier, const std::string& ssl_session_cache_shard, ProxyService* proxy_service, SSLConfigService* ssl_config_service, + bool enable_ssl_connect_job_waiting, + ProxyDelegate* proxy_delegate, HttpNetworkSession::SocketPoolType pool_type); - virtual ~ClientSocketPoolManagerImpl(); + ~ClientSocketPoolManagerImpl() override; - virtual void FlushSocketPoolsWithError(int error) OVERRIDE; - virtual void CloseIdleSockets() OVERRIDE; + void FlushSocketPoolsWithError(int error) override; + void CloseIdleSockets() override; - virtual TransportClientSocketPool* GetTransportSocketPool() OVERRIDE; + TransportClientSocketPool* GetTransportSocketPool() override; - virtual SSLClientSocketPool* GetSSLSocketPool() OVERRIDE; + SSLClientSocketPool* GetSSLSocketPool() override; - virtual SOCKSClientSocketPool* GetSocketPoolForSOCKSProxy( - const HostPortPair& socks_proxy) OVERRIDE; + SOCKSClientSocketPool* GetSocketPoolForSOCKSProxy( + const HostPortPair& socks_proxy) override; - virtual HttpProxyClientSocketPool* GetSocketPoolForHTTPProxy( - const HostPortPair& http_proxy) OVERRIDE; + HttpProxyClientSocketPool* GetSocketPoolForHTTPProxy( + const HostPortPair& http_proxy) override; - virtual SSLClientSocketPool* GetSocketPoolForSSLWithProxy( - const HostPortPair& proxy_server) OVERRIDE; + SSLClientSocketPool* GetSocketPoolForSSLWithProxy( + const HostPortPair& proxy_server) override; // Creates a Value summary of the state of the socket pools. The caller is // responsible for deleting the returned value. - virtual base::Value* SocketPoolInfoToValue() const OVERRIDE; + base::Value* SocketPoolInfoToValue() const override; // CertDatabase::Observer methods: - virtual void OnCertAdded(const X509Certificate* cert) OVERRIDE; - virtual void OnCACertChanged(const X509Certificate* cert) OVERRIDE; + void OnCertAdded(const X509Certificate* cert) override; + void OnCACertChanged(const X509Certificate* cert) override; private: typedef internal::OwnedPoolMap<HostPortPair, TransportClientSocketPool*> @@ -108,12 +111,13 @@ class ClientSocketPoolManagerImpl : public base::NonThreadSafe, ClientSocketFactory* const socket_factory_; HostResolver* const host_resolver_; CertVerifier* const cert_verifier_; - ServerBoundCertService* const server_bound_cert_service_; + ChannelIDService* const channel_id_service_; TransportSecurityState* const transport_security_state_; CTVerifier* const cert_transparency_verifier_; const std::string ssl_session_cache_shard_; ProxyService* const proxy_service_; const scoped_refptr<SSLConfigService> ssl_config_service_; + bool enable_ssl_connect_job_waiting_; const HttpNetworkSession::SocketPoolType pool_type_; // Note: this ordering is important. @@ -145,6 +149,8 @@ class ClientSocketPoolManagerImpl : public base::NonThreadSafe, ClientSocketPoolHistograms ssl_socket_pool_for_proxies_histograms_; SSLSocketPoolMap ssl_socket_pools_for_proxies_; + const ProxyDelegate* proxy_delegate_; + DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolManagerImpl); }; diff --git a/chromium/net/socket/deterministic_socket_data_unittest.cc b/chromium/net/socket/deterministic_socket_data_unittest.cc index c51427e25a7..bdeba2bef55 100644 --- a/chromium/net/socket/deterministic_socket_data_unittest.cc +++ b/chromium/net/socket/deterministic_socket_data_unittest.cc @@ -16,7 +16,7 @@ namespace { static const char kMsg1[] = "\0hello!\xff"; static const int kLen1 = arraysize(kMsg1); -static const char kMsg2[] = "\012345678\0"; +static const char kMsg2[] = "\0a2345678\0"; static const int kLen2 = arraysize(kMsg2); static const char kMsg3[] = "bye!"; static const int kLen3 = arraysize(kMsg3); @@ -29,7 +29,7 @@ class DeterministicSocketDataTest : public PlatformTest { public: DeterministicSocketDataTest(); - virtual void TearDown(); + void TearDown() override; void ReentrantReadCallback(int len, int rv); void ReentrantWriteCallback(const char* data, int len, int rv); @@ -71,10 +71,12 @@ DeterministicSocketDataTest::DeterministicSocketDataTest() read_buf_(NULL), connect_data_(SYNCHRONOUS, OK), endpoint_("www.google.com", 443), - tcp_params_(new TransportSocketParams(endpoint_, - false, - false, - OnHostResolutionCallback())), + tcp_params_(new TransportSocketParams( + endpoint_, + false, + false, + OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), histograms_(std::string()), socket_pool_(10, 10, &histograms_, &socket_factory_) {} diff --git a/chromium/net/socket/mock_client_socket_pool_manager.h b/chromium/net/socket/mock_client_socket_pool_manager.h index c2c3792a4f6..76d7704316a 100644 --- a/chromium/net/socket/mock_client_socket_pool_manager.h +++ b/chromium/net/socket/mock_client_socket_pool_manager.h @@ -14,7 +14,7 @@ namespace net { class MockClientSocketPoolManager : public ClientSocketPoolManager { public: MockClientSocketPoolManager(); - virtual ~MockClientSocketPoolManager(); + ~MockClientSocketPoolManager() override; // Sets "override" socket pools that get used instead. void SetTransportSocketPool(TransportClientSocketPool* pool); @@ -27,17 +27,17 @@ class MockClientSocketPoolManager : public ClientSocketPoolManager { SSLClientSocketPool* pool); // ClientSocketPoolManager methods: - virtual void FlushSocketPoolsWithError(int error) OVERRIDE; - virtual void CloseIdleSockets() OVERRIDE; - virtual TransportClientSocketPool* GetTransportSocketPool() OVERRIDE; - virtual SSLClientSocketPool* GetSSLSocketPool() OVERRIDE; - virtual SOCKSClientSocketPool* GetSocketPoolForSOCKSProxy( - const HostPortPair& socks_proxy) OVERRIDE; - virtual HttpProxyClientSocketPool* GetSocketPoolForHTTPProxy( - const HostPortPair& http_proxy) OVERRIDE; - virtual SSLClientSocketPool* GetSocketPoolForSSLWithProxy( - const HostPortPair& proxy_server) OVERRIDE; - virtual base::Value* SocketPoolInfoToValue() const OVERRIDE; + void FlushSocketPoolsWithError(int error) override; + void CloseIdleSockets() override; + TransportClientSocketPool* GetTransportSocketPool() override; + SSLClientSocketPool* GetSSLSocketPool() override; + SOCKSClientSocketPool* GetSocketPoolForSOCKSProxy( + const HostPortPair& socks_proxy) override; + HttpProxyClientSocketPool* GetSocketPoolForHTTPProxy( + const HostPortPair& http_proxy) override; + SSLClientSocketPool* GetSocketPoolForSSLWithProxy( + const HostPortPair& proxy_server) override; + base::Value* SocketPoolInfoToValue() const override; private: typedef internal::OwnedPoolMap<HostPortPair, TransportClientSocketPool*> diff --git a/chromium/net/socket/next_proto.cc b/chromium/net/socket/next_proto.cc index c9172365c38..1dcfb5d58f3 100644 --- a/chromium/net/socket/next_proto.cc +++ b/chromium/net/socket/next_proto.cc @@ -15,7 +15,6 @@ NextProtoVector NextProtosHttpOnly() { NextProtoVector NextProtosDefaults() { NextProtoVector next_protos; next_protos.push_back(kProtoHTTP11); - next_protos.push_back(kProtoSPDY3); next_protos.push_back(kProtoSPDY31); return next_protos; } @@ -27,35 +26,15 @@ NextProtoVector NextProtosWithSpdyAndQuic(bool spdy_enabled, if (quic_enabled) next_protos.push_back(kProtoQUIC1SPDY3); if (spdy_enabled) { - next_protos.push_back(kProtoSPDY3); next_protos.push_back(kProtoSPDY31); } return next_protos; } -NextProtoVector NextProtosSpdy3() { - NextProtoVector next_protos; - next_protos.push_back(kProtoHTTP11); - next_protos.push_back(kProtoQUIC1SPDY3); - next_protos.push_back(kProtoSPDY3); - return next_protos; -} - NextProtoVector NextProtosSpdy31() { NextProtoVector next_protos; next_protos.push_back(kProtoHTTP11); next_protos.push_back(kProtoQUIC1SPDY3); - next_protos.push_back(kProtoSPDY3); - next_protos.push_back(kProtoSPDY31); - return next_protos; -} - -NextProtoVector NextProtosSpdy31WithSpdy2() { - NextProtoVector next_protos; - next_protos.push_back(kProtoHTTP11); - next_protos.push_back(kProtoQUIC1SPDY3); - next_protos.push_back(kProtoDeprecatedSPDY2); - next_protos.push_back(kProtoSPDY3); next_protos.push_back(kProtoSPDY31); return next_protos; } @@ -64,7 +43,6 @@ NextProtoVector NextProtosSpdy4Http2() { NextProtoVector next_protos; next_protos.push_back(kProtoHTTP11); next_protos.push_back(kProtoQUIC1SPDY3); - next_protos.push_back(kProtoSPDY3); next_protos.push_back(kProtoSPDY31); next_protos.push_back(kProtoSPDY4); return next_protos; diff --git a/chromium/net/socket/next_proto.h b/chromium/net/socket/next_proto.h index 19ff55e0bc0..4df6e9b9cd5 100644 --- a/chromium/net/socket/next_proto.h +++ b/chromium/net/socket/next_proto.h @@ -14,20 +14,23 @@ namespace net { // Next Protocol Negotiation (NPN), if successful, results in agreement on an // application-level string that specifies the application level protocol to // use over the TLS connection. NextProto enumerates the application level -// protocols that we recognise. +// protocols that we recognize. Do not change or reuse values, because they +// are used to collect statistics on UMA. Also, values must be in [0,499), +// because of the way TLS protocol negotiation extension information is added to +// UMA histogram. enum NextProto { kProtoUnknown = 0, - kProtoHTTP11, + kProtoHTTP11 = 1, kProtoMinimumVersion = kProtoHTTP11, - kProtoDeprecatedSPDY2, + kProtoDeprecatedSPDY2 = 100, kProtoSPDYMinimumVersion = kProtoDeprecatedSPDY2, - kProtoSPDY3, - kProtoSPDY31, - kProtoSPDY4, // SPDY4 is HTTP/2. + kProtoSPDY3 = 101, + kProtoSPDY31 = 102, + kProtoSPDY4 = 103, // SPDY4 is HTTP/2. kProtoSPDYMaximumVersion = kProtoSPDY4, - kProtoQUIC1SPDY3, + kProtoQUIC1SPDY3 = 200, kProtoMaximumVersion = kProtoQUIC1SPDY3, }; @@ -47,9 +50,7 @@ NET_EXPORT NextProtoVector NextProtosWithSpdyAndQuic(bool spdy_enabled, bool quic_enabled); // All of these also enable QUIC. -NET_EXPORT NextProtoVector NextProtosSpdy3(); NET_EXPORT NextProtoVector NextProtosSpdy31(); -NET_EXPORT NextProtoVector NextProtosSpdy31WithSpdy2(); NET_EXPORT NextProtoVector NextProtosSpdy4Http2(); } // namespace net diff --git a/chromium/net/socket/nss_ssl_util.cc b/chromium/net/socket/nss_ssl_util.cc index 7b068545a55..a238a25d2d4 100644 --- a/chromium/net/socket/nss_ssl_util.cc +++ b/chromium/net/socket/nss_ssl_util.cc @@ -29,6 +29,8 @@ #include "base/win/windows_version.h" #endif +namespace net { + namespace { // CiphersRemove takes a zero-terminated array of cipher suite ids in @@ -77,9 +79,15 @@ size_t CiphersCopy(const uint16* in, uint16* out) { } } -} // anonymous namespace - -namespace net { +base::Value* NetLogSSLErrorCallback(int net_error, + int ssl_lib_error, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetInteger("net_error", net_error); + if (ssl_lib_error) + dict->SetInteger("ssl_lib_error", ssl_lib_error); + return dict; +} class NSSSSLInitSingleton { public: @@ -201,9 +209,11 @@ class NSSSSLInitSingleton { PRFileDesc* model_fd_; }; -static base::LazyInstance<NSSSSLInitSingleton>::Leaky g_nss_ssl_init_singleton = +base::LazyInstance<NSSSSLInitSingleton>::Leaky g_nss_ssl_init_singleton = LAZY_INSTANCE_INITIALIZER; +} // anonymous namespace + // Initialize the NSS SSL library if it isn't already initialized. This must // be called before any other NSS SSL functions. This function is // thread-safe, and the NSS SSL library will only ever be initialized once. @@ -399,4 +409,9 @@ void LogFailedNSSFunction(const BoundNetLog& net_log, function, param, PR_GetError())); } +NetLog::ParametersCallback CreateNetLogSSLErrorCallback(int net_error, + int ssl_lib_error) { + return base::Bind(&NetLogSSLErrorCallback, net_error, ssl_lib_error); +} + } // namespace net diff --git a/chromium/net/socket/nss_ssl_util.h b/chromium/net/socket/nss_ssl_util.h index 3aed7bf6b4a..7b046ffd282 100644 --- a/chromium/net/socket/nss_ssl_util.h +++ b/chromium/net/socket/nss_ssl_util.h @@ -12,6 +12,7 @@ #include <prio.h> #include "net/base/net_export.h" +#include "net/base/net_log.h" namespace net { @@ -35,6 +36,11 @@ PRFileDesc* GetNSSModelSocket(); // Map NSS error code to network error code. int MapNSSError(PRErrorCode err); +// Creates a NetLog callback for an SSL error. +NetLog::ParametersCallback CreateNetLogSSLErrorCallback(int net_error, + int ssl_lib_error); + + } // namespace net #endif // NET_SOCKET_NSS_SSL_UTIL_H_ diff --git a/chromium/net/socket/openssl_ssl_util.cc b/chromium/net/socket/openssl_ssl_util.cc deleted file mode 100644 index 36b8e6ca0e3..00000000000 --- a/chromium/net/socket/openssl_ssl_util.cc +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright 2014 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/socket/openssl_ssl_util.h" - -#include <openssl/err.h> -#include <openssl/ssl.h> - -#include "base/logging.h" -#include "crypto/openssl_util.h" -#include "net/base/net_errors.h" - -namespace net { - -SslSetClearMask::SslSetClearMask() - : set_mask(0), - clear_mask(0) { -} - -void SslSetClearMask::ConfigureFlag(long flag, bool state) { - (state ? set_mask : clear_mask) |= flag; - // Make sure we haven't got any intersection in the set & clear options. - DCHECK_EQ(0, set_mask & clear_mask) << flag << ":" << state; -} - -namespace { - -int MapOpenSSLErrorSSL() { - // Walk down the error stack to find the SSLerr generated reason. - unsigned long error_code; - do { - error_code = ERR_get_error(); - if (error_code == 0) - return ERR_SSL_PROTOCOL_ERROR; - } while (ERR_GET_LIB(error_code) != ERR_LIB_SSL); - - DVLOG(1) << "OpenSSL SSL error, reason: " << ERR_GET_REASON(error_code) - << ", name: " << ERR_error_string(error_code, NULL); - switch (ERR_GET_REASON(error_code)) { - case SSL_R_READ_TIMEOUT_EXPIRED: - return ERR_TIMED_OUT; - case SSL_R_BAD_RESPONSE_ARGUMENT: - return ERR_INVALID_ARGUMENT; - case SSL_R_UNKNOWN_CERTIFICATE_TYPE: - case SSL_R_UNKNOWN_CIPHER_TYPE: - case SSL_R_UNKNOWN_KEY_EXCHANGE_TYPE: - case SSL_R_UNKNOWN_PKEY_TYPE: - case SSL_R_UNKNOWN_REMOTE_ERROR_TYPE: - case SSL_R_UNKNOWN_SSL_VERSION: - return ERR_NOT_IMPLEMENTED; - case SSL_R_UNSUPPORTED_SSL_VERSION: - case SSL_R_NO_CIPHER_MATCH: - case SSL_R_NO_SHARED_CIPHER: - case SSL_R_TLSV1_ALERT_INSUFFICIENT_SECURITY: - case SSL_R_TLSV1_ALERT_PROTOCOL_VERSION: - case SSL_R_UNSUPPORTED_PROTOCOL: - return ERR_SSL_VERSION_OR_CIPHER_MISMATCH; - case SSL_R_SSLV3_ALERT_BAD_CERTIFICATE: - case SSL_R_SSLV3_ALERT_UNSUPPORTED_CERTIFICATE: - case SSL_R_SSLV3_ALERT_CERTIFICATE_REVOKED: - case SSL_R_SSLV3_ALERT_CERTIFICATE_EXPIRED: - case SSL_R_SSLV3_ALERT_CERTIFICATE_UNKNOWN: - case SSL_R_TLSV1_ALERT_ACCESS_DENIED: - case SSL_R_TLSV1_ALERT_UNKNOWN_CA: - return ERR_BAD_SSL_CLIENT_AUTH_CERT; - case SSL_R_BAD_DECOMPRESSION: - case SSL_R_SSLV3_ALERT_DECOMPRESSION_FAILURE: - return ERR_SSL_DECOMPRESSION_FAILURE_ALERT; - case SSL_R_SSLV3_ALERT_BAD_RECORD_MAC: - return ERR_SSL_BAD_RECORD_MAC_ALERT; - case SSL_R_TLSV1_ALERT_DECRYPT_ERROR: - return ERR_SSL_DECRYPT_ERROR_ALERT; - case SSL_R_TLSV1_UNRECOGNIZED_NAME: - return ERR_SSL_UNRECOGNIZED_NAME_ALERT; - case SSL_R_UNSAFE_LEGACY_RENEGOTIATION_DISABLED: - return ERR_SSL_UNSAFE_NEGOTIATION; - case SSL_R_WRONG_NUMBER_OF_KEY_BITS: - return ERR_SSL_WEAK_SERVER_EPHEMERAL_DH_KEY; - // SSL_R_UNKNOWN_PROTOCOL is reported if premature application data is - // received (see http://crbug.com/42538), and also if all the protocol - // versions supported by the server were disabled in this socket instance. - // Mapped to ERR_SSL_PROTOCOL_ERROR for compatibility with other SSL sockets - // in the former scenario. - case SSL_R_UNKNOWN_PROTOCOL: - case SSL_R_SSL_HANDSHAKE_FAILURE: - case SSL_R_DECRYPTION_FAILED: - case SSL_R_DECRYPTION_FAILED_OR_BAD_RECORD_MAC: - case SSL_R_DH_PUBLIC_VALUE_LENGTH_IS_WRONG: - case SSL_R_DIGEST_CHECK_FAILED: - case SSL_R_DUPLICATE_COMPRESSION_ID: - case SSL_R_ECGROUP_TOO_LARGE_FOR_CIPHER: - case SSL_R_ENCRYPTED_LENGTH_TOO_LONG: - case SSL_R_ERROR_IN_RECEIVED_CIPHER_LIST: - case SSL_R_EXCESSIVE_MESSAGE_SIZE: - case SSL_R_EXTRA_DATA_IN_MESSAGE: - case SSL_R_GOT_A_FIN_BEFORE_A_CCS: - case SSL_R_ILLEGAL_PADDING: - case SSL_R_INVALID_CHALLENGE_LENGTH: - case SSL_R_INVALID_COMMAND: - case SSL_R_INVALID_PURPOSE: - case SSL_R_INVALID_STATUS_RESPONSE: - case SSL_R_INVALID_TICKET_KEYS_LENGTH: - case SSL_R_KEY_ARG_TOO_LONG: - case SSL_R_READ_WRONG_PACKET_TYPE: - // SSL_do_handshake reports this error when the server responds to a - // ClientHello with a fatal close_notify alert. - case SSL_AD_REASON_OFFSET + SSL_AD_CLOSE_NOTIFY: - case SSL_R_SSLV3_ALERT_UNEXPECTED_MESSAGE: - // TODO(joth): SSL_R_SSLV3_ALERT_HANDSHAKE_FAILURE may be returned from the - // server after receiving ClientHello if there's no common supported cipher. - // Ideally we'd map that specific case to ERR_SSL_VERSION_OR_CIPHER_MISMATCH - // to match the NSS implementation. See also http://goo.gl/oMtZW - case SSL_R_SSLV3_ALERT_HANDSHAKE_FAILURE: - case SSL_R_SSLV3_ALERT_NO_CERTIFICATE: - case SSL_R_SSLV3_ALERT_ILLEGAL_PARAMETER: - case SSL_R_TLSV1_ALERT_DECODE_ERROR: - case SSL_R_TLSV1_ALERT_DECRYPTION_FAILED: - case SSL_R_TLSV1_ALERT_EXPORT_RESTRICTION: - case SSL_R_TLSV1_ALERT_INTERNAL_ERROR: - case SSL_R_TLSV1_ALERT_NO_RENEGOTIATION: - case SSL_R_TLSV1_ALERT_RECORD_OVERFLOW: - case SSL_R_TLSV1_ALERT_USER_CANCELLED: - return ERR_SSL_PROTOCOL_ERROR; - case SSL_R_CERTIFICATE_VERIFY_FAILED: - // The only way that the certificate verify callback can fail is if - // the leaf certificate changed during a renegotiation. - return ERR_SSL_SERVER_CERT_CHANGED; - default: - LOG(WARNING) << "Unmapped error reason: " << ERR_GET_REASON(error_code); - return ERR_FAILED; - } -} - -} // namespace - -int MapOpenSSLError(int err, const crypto::OpenSSLErrStackTracer& tracer) { - switch (err) { - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_WRITE: - return ERR_IO_PENDING; - case SSL_ERROR_SYSCALL: - LOG(ERROR) << "OpenSSL SYSCALL error, earliest error code in " - "error queue: " << ERR_peek_error() << ", errno: " - << errno; - return ERR_SSL_PROTOCOL_ERROR; - case SSL_ERROR_SSL: - return MapOpenSSLErrorSSL(); - default: - // TODO(joth): Implement full mapping. - LOG(WARNING) << "Unknown OpenSSL error " << err; - return ERR_SSL_PROTOCOL_ERROR; - } -} - -} // namespace net diff --git a/chromium/net/socket/openssl_ssl_util.h b/chromium/net/socket/openssl_ssl_util.h deleted file mode 100644 index e459a445ef3..00000000000 --- a/chromium/net/socket/openssl_ssl_util.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2014 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_SOCKET_OPENSSL_SSL_UTIL_H_ -#define NET_SOCKET_OPENSSL_SSL_UTIL_H_ - -namespace crypto { -class OpenSSLErrStackTracer; -} - -namespace net { - -// Utility to construct the appropriate set & clear masks for use the OpenSSL -// options and mode configuration functions. (SSL_set_options etc) -struct SslSetClearMask { - SslSetClearMask(); - void ConfigureFlag(long flag, bool state); - - long set_mask; - long clear_mask; -}; - -// Converts an OpenSSL error code into a net error code, walking the OpenSSL -// error stack if needed. Note that |tracer| is not currently used in the -// implementation, but is passed in anyway as this ensures the caller will clear -// any residual codes left on the error stack. -int MapOpenSSLError(int err, const crypto::OpenSSLErrStackTracer& tracer); - -} // namespace net - -#endif // NET_SOCKET_OPENSSL_SSL_UTIL_H_ diff --git a/chromium/net/socket/server_socket.cc b/chromium/net/socket/server_socket.cc new file mode 100644 index 00000000000..da89b4645f8 --- /dev/null +++ b/chromium/net/socket/server_socket.cc @@ -0,0 +1,30 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/server_socket.h" + +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" + +namespace net { + +ServerSocket::ServerSocket() { +} + +ServerSocket::~ServerSocket() { +} + +int ServerSocket::ListenWithAddressAndPort(const std::string& address_string, + int port, + int backlog) { + IPAddressNumber address_number; + if (!ParseIPLiteralToNumber(address_string, &address_number)) { + return ERR_ADDRESS_INVALID; + } + + return Listen(IPEndPoint(address_number, port), backlog); +} + +} // namespace net diff --git a/chromium/net/socket/server_socket.h b/chromium/net/socket/server_socket.h index 11151eea153..4b9ca8e39cf 100644 --- a/chromium/net/socket/server_socket.h +++ b/chromium/net/socket/server_socket.h @@ -5,6 +5,8 @@ #ifndef NET_SOCKET_SERVER_SOCKET_H_ #define NET_SOCKET_SERVER_SOCKET_H_ +#include <string> + #include "base/memory/scoped_ptr.h" #include "net/base/completion_callback.h" #include "net/base/net_export.h" @@ -16,17 +18,25 @@ class StreamSocket; class NET_EXPORT ServerSocket { public: - ServerSocket() { } - virtual ~ServerSocket() { } + ServerSocket(); + virtual ~ServerSocket(); - // Bind the socket and start listening. Destroy the socket to stop + // Binds the socket and starts listening. Destroys the socket to stop // listening. - virtual int Listen(const net::IPEndPoint& address, int backlog) = 0; + virtual int Listen(const IPEndPoint& address, int backlog) = 0; + + // Binds the socket with address and port, and starts listening. It expects + // a valid IPv4 or IPv6 address. Otherwise, it returns ERR_ADDRESS_INVALID. + // Subclasses may override this function if |address_string| is in a different + // format, for example, unix domain socket path. + virtual int ListenWithAddressAndPort(const std::string& address_string, + int port, + int backlog); // Gets current address the socket is bound to. virtual int GetLocalAddress(IPEndPoint* address) const = 0; - // Accept connection. Callback is called when new connection is + // Accepts connection. Callback is called when new connection is // accepted. virtual int Accept(scoped_ptr<StreamSocket>* socket, const CompletionCallback& callback) = 0; diff --git a/chromium/net/socket/socket_libevent.cc b/chromium/net/socket/socket_libevent.cc new file mode 100644 index 00000000000..5f16c929ab2 --- /dev/null +++ b/chromium/net/socket/socket_libevent.cc @@ -0,0 +1,482 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/socket_libevent.h" + +#include <errno.h> +#include <netinet/in.h> +#include <sys/socket.h> + +#include "base/callback_helpers.h" +#include "base/logging.h" +#include "base/posix/eintr_wrapper.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" + +namespace net { + +namespace { + +int MapAcceptError(int os_error) { + switch (os_error) { + // If the client aborts the connection before the server calls accept, + // POSIX specifies accept should fail with ECONNABORTED. The server can + // ignore the error and just call accept again, so we map the error to + // ERR_IO_PENDING. See UNIX Network Programming, Vol. 1, 3rd Ed., Sec. + // 5.11, "Connection Abort before accept Returns". + case ECONNABORTED: + return ERR_IO_PENDING; + default: + return MapSystemError(os_error); + } +} + +int MapConnectError(int os_error) { + switch (os_error) { + case EINPROGRESS: + return ERR_IO_PENDING; + case EACCES: + return ERR_NETWORK_ACCESS_DENIED; + case ETIMEDOUT: + return ERR_CONNECTION_TIMED_OUT; + default: { + int net_error = MapSystemError(os_error); + if (net_error == ERR_FAILED) + return ERR_CONNECTION_FAILED; // More specific than ERR_FAILED. + return net_error; + } + } +} + +} // namespace + +SocketLibevent::SocketLibevent() + : socket_fd_(kInvalidSocket), + read_buf_len_(0), + write_buf_len_(0), + waiting_connect_(false) { +} + +SocketLibevent::~SocketLibevent() { + Close(); +} + +int SocketLibevent::Open(int address_family) { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_EQ(kInvalidSocket, socket_fd_); + DCHECK(address_family == AF_INET || + address_family == AF_INET6 || + address_family == AF_UNIX); + + socket_fd_ = CreatePlatformSocket( + address_family, + SOCK_STREAM, + address_family == AF_UNIX ? 0 : IPPROTO_TCP); + if (socket_fd_ < 0) { + PLOG(ERROR) << "CreatePlatformSocket() returned an error, errno=" << errno; + return MapSystemError(errno); + } + + if (SetNonBlocking(socket_fd_)) { + int rv = MapSystemError(errno); + Close(); + return rv; + } + + return OK; +} + +int SocketLibevent::AdoptConnectedSocket(SocketDescriptor socket, + const SockaddrStorage& address) { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_EQ(kInvalidSocket, socket_fd_); + + socket_fd_ = socket; + + if (SetNonBlocking(socket_fd_)) { + int rv = MapSystemError(errno); + Close(); + return rv; + } + + SetPeerAddress(address); + return OK; +} + +SocketDescriptor SocketLibevent::ReleaseConnectedSocket() { + StopWatchingAndCleanUp(); + SocketDescriptor socket_fd = socket_fd_; + socket_fd_ = kInvalidSocket; + return socket_fd; +} + +int SocketLibevent::Bind(const SockaddrStorage& address) { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_NE(kInvalidSocket, socket_fd_); + + int rv = bind(socket_fd_, address.addr, address.addr_len); + if (rv < 0) { + PLOG(ERROR) << "bind() returned an error, errno=" << errno; + return MapSystemError(errno); + } + + return OK; +} + +int SocketLibevent::Listen(int backlog) { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_NE(kInvalidSocket, socket_fd_); + DCHECK_LT(0, backlog); + + int rv = listen(socket_fd_, backlog); + if (rv < 0) { + PLOG(ERROR) << "listen() returned an error, errno=" << errno; + return MapSystemError(errno); + } + + return OK; +} + +int SocketLibevent::Accept(scoped_ptr<SocketLibevent>* socket, + const CompletionCallback& callback) { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_NE(kInvalidSocket, socket_fd_); + DCHECK(accept_callback_.is_null()); + DCHECK(socket); + DCHECK(!callback.is_null()); + + int rv = DoAccept(socket); + if (rv != ERR_IO_PENDING) + return rv; + + if (!base::MessageLoopForIO::current()->WatchFileDescriptor( + socket_fd_, true, base::MessageLoopForIO::WATCH_READ, + &accept_socket_watcher_, this)) { + PLOG(ERROR) << "WatchFileDescriptor failed on accept, errno " << errno; + return MapSystemError(errno); + } + + accept_socket_ = socket; + accept_callback_ = callback; + return ERR_IO_PENDING; +} + +int SocketLibevent::Connect(const SockaddrStorage& address, + const CompletionCallback& callback) { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_NE(kInvalidSocket, socket_fd_); + DCHECK(!waiting_connect_); + DCHECK(!callback.is_null()); + + SetPeerAddress(address); + + int rv = DoConnect(); + if (rv != ERR_IO_PENDING) + return rv; + + if (!base::MessageLoopForIO::current()->WatchFileDescriptor( + socket_fd_, true, base::MessageLoopForIO::WATCH_WRITE, + &write_socket_watcher_, this)) { + PLOG(ERROR) << "WatchFileDescriptor failed on connect, errno " << errno; + return MapSystemError(errno); + } + + write_callback_ = callback; + waiting_connect_ = true; + return ERR_IO_PENDING; +} + +bool SocketLibevent::IsConnected() const { + DCHECK(thread_checker_.CalledOnValidThread()); + + if (socket_fd_ == kInvalidSocket || waiting_connect_) + return false; + + // Checks if connection is alive. + char c; + int rv = HANDLE_EINTR(recv(socket_fd_, &c, 1, MSG_PEEK)); + if (rv == 0) + return false; + if (rv == -1 && errno != EAGAIN && errno != EWOULDBLOCK) + return false; + + return true; +} + +bool SocketLibevent::IsConnectedAndIdle() const { + DCHECK(thread_checker_.CalledOnValidThread()); + + if (socket_fd_ == kInvalidSocket || waiting_connect_) + return false; + + // Check if connection is alive and we haven't received any data + // unexpectedly. + char c; + int rv = HANDLE_EINTR(recv(socket_fd_, &c, 1, MSG_PEEK)); + if (rv >= 0) + return false; + if (errno != EAGAIN && errno != EWOULDBLOCK) + return false; + + return true; +} + +int SocketLibevent::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_NE(kInvalidSocket, socket_fd_); + DCHECK(!waiting_connect_); + CHECK(read_callback_.is_null()); + // Synchronous operation not supported + DCHECK(!callback.is_null()); + DCHECK_LT(0, buf_len); + + int rv = DoRead(buf, buf_len); + if (rv != ERR_IO_PENDING) + return rv; + + if (!base::MessageLoopForIO::current()->WatchFileDescriptor( + socket_fd_, true, base::MessageLoopForIO::WATCH_READ, + &read_socket_watcher_, this)) { + PLOG(ERROR) << "WatchFileDescriptor failed on read, errno " << errno; + return MapSystemError(errno); + } + + read_buf_ = buf; + read_buf_len_ = buf_len; + read_callback_ = callback; + return ERR_IO_PENDING; +} + +int SocketLibevent::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_NE(kInvalidSocket, socket_fd_); + DCHECK(!waiting_connect_); + CHECK(write_callback_.is_null()); + // Synchronous operation not supported + DCHECK(!callback.is_null()); + DCHECK_LT(0, buf_len); + + int rv = DoWrite(buf, buf_len); + if (rv == ERR_IO_PENDING) + rv = WaitForWrite(buf, buf_len, callback); + return rv; +} + +int SocketLibevent::WaitForWrite(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_NE(kInvalidSocket, socket_fd_); + DCHECK(write_callback_.is_null()); + // Synchronous operation not supported + DCHECK(!callback.is_null()); + DCHECK_LT(0, buf_len); + + if (!base::MessageLoopForIO::current()->WatchFileDescriptor( + socket_fd_, true, base::MessageLoopForIO::WATCH_WRITE, + &write_socket_watcher_, this)) { + PLOG(ERROR) << "WatchFileDescriptor failed on write, errno " << errno; + return MapSystemError(errno); + } + + write_buf_ = buf; + write_buf_len_ = buf_len; + write_callback_ = callback; + return ERR_IO_PENDING; +} + +int SocketLibevent::GetLocalAddress(SockaddrStorage* address) const { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK(address); + + if (getsockname(socket_fd_, address->addr, &address->addr_len) < 0) + return MapSystemError(errno); + return OK; +} + +int SocketLibevent::GetPeerAddress(SockaddrStorage* address) const { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK(address); + + if (!HasPeerAddress()) + return ERR_SOCKET_NOT_CONNECTED; + + *address = *peer_address_; + return OK; +} + +void SocketLibevent::SetPeerAddress(const SockaddrStorage& address) { + DCHECK(thread_checker_.CalledOnValidThread()); + // |peer_address_| will be non-NULL if Connect() has been called. Unless + // Close() is called to reset the internal state, a second call to Connect() + // is not allowed. + // Please note that we don't allow a second Connect() even if the previous + // Connect() has failed. Connecting the same |socket_| again after a + // connection attempt failed results in unspecified behavior according to + // POSIX. + DCHECK(!peer_address_); + peer_address_.reset(new SockaddrStorage(address)); +} + +bool SocketLibevent::HasPeerAddress() const { + DCHECK(thread_checker_.CalledOnValidThread()); + return peer_address_ != NULL; +} + +void SocketLibevent::Close() { + DCHECK(thread_checker_.CalledOnValidThread()); + + StopWatchingAndCleanUp(); + + if (socket_fd_ != kInvalidSocket) { + if (IGNORE_EINTR(close(socket_fd_)) < 0) + PLOG(ERROR) << "close() returned an error, errno=" << errno; + socket_fd_ = kInvalidSocket; + } +} + +void SocketLibevent::OnFileCanReadWithoutBlocking(int fd) { + DCHECK(!accept_callback_.is_null() || !read_callback_.is_null()); + if (!accept_callback_.is_null()) { + AcceptCompleted(); + } else { // !read_callback_.is_null() + ReadCompleted(); + } +} + +void SocketLibevent::OnFileCanWriteWithoutBlocking(int fd) { + DCHECK(!write_callback_.is_null()); + if (waiting_connect_) { + ConnectCompleted(); + } else { + WriteCompleted(); + } +} + +int SocketLibevent::DoAccept(scoped_ptr<SocketLibevent>* socket) { + SockaddrStorage new_peer_address; + int new_socket = HANDLE_EINTR(accept(socket_fd_, + new_peer_address.addr, + &new_peer_address.addr_len)); + if (new_socket < 0) + return MapAcceptError(errno); + + scoped_ptr<SocketLibevent> accepted_socket(new SocketLibevent); + int rv = accepted_socket->AdoptConnectedSocket(new_socket, new_peer_address); + if (rv != OK) + return rv; + + *socket = accepted_socket.Pass(); + return OK; +} + +void SocketLibevent::AcceptCompleted() { + DCHECK(accept_socket_); + int rv = DoAccept(accept_socket_); + if (rv == ERR_IO_PENDING) + return; + + bool ok = accept_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + accept_socket_ = NULL; + base::ResetAndReturn(&accept_callback_).Run(rv); +} + +int SocketLibevent::DoConnect() { + int rv = HANDLE_EINTR(connect(socket_fd_, + peer_address_->addr, + peer_address_->addr_len)); + DCHECK_GE(0, rv); + return rv == 0 ? OK : MapConnectError(errno); +} + +void SocketLibevent::ConnectCompleted() { + // Get the error that connect() completed with. + int os_error = 0; + socklen_t len = sizeof(os_error); + if (getsockopt(socket_fd_, SOL_SOCKET, SO_ERROR, &os_error, &len) == 0) { + // TCPSocketLibevent expects errno to be set. + errno = os_error; + } + + int rv = MapConnectError(errno); + if (rv == ERR_IO_PENDING) + return; + + bool ok = write_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + waiting_connect_ = false; + base::ResetAndReturn(&write_callback_).Run(rv); +} + +int SocketLibevent::DoRead(IOBuffer* buf, int buf_len) { + int rv = HANDLE_EINTR(read(socket_fd_, buf->data(), buf_len)); + return rv >= 0 ? rv : MapSystemError(errno); +} + +void SocketLibevent::ReadCompleted() { + int rv = DoRead(read_buf_.get(), read_buf_len_); + if (rv == ERR_IO_PENDING) + return; + + bool ok = read_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + read_buf_ = NULL; + read_buf_len_ = 0; + base::ResetAndReturn(&read_callback_).Run(rv); +} + +int SocketLibevent::DoWrite(IOBuffer* buf, int buf_len) { + int rv = HANDLE_EINTR(write(socket_fd_, buf->data(), buf_len)); + return rv >= 0 ? rv : MapSystemError(errno); +} + +void SocketLibevent::WriteCompleted() { + int rv = DoWrite(write_buf_.get(), write_buf_len_); + if (rv == ERR_IO_PENDING) + return; + + bool ok = write_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + write_buf_ = NULL; + write_buf_len_ = 0; + base::ResetAndReturn(&write_callback_).Run(rv); +} + +void SocketLibevent::StopWatchingAndCleanUp() { + bool ok = accept_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + ok = read_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + ok = write_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + + if (!accept_callback_.is_null()) { + accept_socket_ = NULL; + accept_callback_.Reset(); + } + + if (!read_callback_.is_null()) { + read_buf_ = NULL; + read_buf_len_ = 0; + read_callback_.Reset(); + } + + if (!write_callback_.is_null()) { + write_buf_ = NULL; + write_buf_len_ = 0; + write_callback_.Reset(); + } + + waiting_connect_ = false; + peer_address_.reset(); +} + +} // namespace net diff --git a/chromium/net/socket/socket_libevent.h b/chromium/net/socket/socket_libevent.h new file mode 100644 index 00000000000..00a0ca6fa7c --- /dev/null +++ b/chromium/net/socket/socket_libevent.h @@ -0,0 +1,132 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SOCKET_LIBEVENT_H_ +#define NET_SOCKET_SOCKET_LIBEVENT_H_ + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/macros.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/threading/thread_checker.h" +#include "net/base/completion_callback.h" +#include "net/base/net_util.h" +#include "net/socket/socket_descriptor.h" + +namespace net { + +class IOBuffer; +class IPEndPoint; + +// Socket class to provide asynchronous read/write operations on top of the +// posix socket api. It supports AF_INET, AF_INET6, and AF_UNIX addresses. +class NET_EXPORT_PRIVATE SocketLibevent + : public base::MessageLoopForIO::Watcher { + public: + SocketLibevent(); + ~SocketLibevent() override; + + // Opens a socket and returns net::OK if |address_family| is AF_INET, AF_INET6 + // or AF_UNIX. Otherwise, it does DCHECK() and returns a net error. + int Open(int address_family); + // Takes ownership of |socket|. + int AdoptConnectedSocket(SocketDescriptor socket, + const SockaddrStorage& peer_address); + // Releases ownership of |socket_fd_| to caller. + SocketDescriptor ReleaseConnectedSocket(); + + int Bind(const SockaddrStorage& address); + + int Listen(int backlog); + int Accept(scoped_ptr<SocketLibevent>* socket, + const CompletionCallback& callback); + + // Connects socket. On non-ERR_IO_PENDING error, sets errno and returns a net + // error code. On ERR_IO_PENDING, |callback| is called with a net error code, + // not errno, though errno is set if connect event happens with error. + // TODO(byungchul): Need more robust way to pass system errno. + int Connect(const SockaddrStorage& address, + const CompletionCallback& callback); + bool IsConnected() const; + bool IsConnectedAndIdle() const; + + // Multiple outstanding requests of the same type are not supported. + // Full duplex mode (reading and writing at the same time) is supported. + // On error which is not ERR_IO_PENDING, sets errno and returns a net error + // code. On ERR_IO_PENDING, |callback| is called with a net error code, not + // errno, though errno is set if read or write events happen with error. + // TODO(byungchul): Need more robust way to pass system errno. + int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + + // Waits for next write event. This is called by TCPsocketLibevent for TCP + // fastopen after sending first data. Returns ERR_IO_PENDING if it starts + // waiting for write event successfully. Otherwise, returns a net error code. + // It must not be called after Write() because Write() calls it internally. + int WaitForWrite(IOBuffer* buf, int buf_len, + const CompletionCallback& callback); + + int GetLocalAddress(SockaddrStorage* address) const; + int GetPeerAddress(SockaddrStorage* address) const; + void SetPeerAddress(const SockaddrStorage& address); + // Returns true if peer address has been set regardless of socket state. + bool HasPeerAddress() const; + + void Close(); + + SocketDescriptor socket_fd() const { return socket_fd_; } + + private: + // base::MessageLoopForIO::Watcher methods. + void OnFileCanReadWithoutBlocking(int fd) override; + void OnFileCanWriteWithoutBlocking(int fd) override; + + int DoAccept(scoped_ptr<SocketLibevent>* socket); + void AcceptCompleted(); + + int DoConnect(); + void ConnectCompleted(); + + int DoRead(IOBuffer* buf, int buf_len); + void ReadCompleted(); + + int DoWrite(IOBuffer* buf, int buf_len); + void WriteCompleted(); + + void StopWatchingAndCleanUp(); + + SocketDescriptor socket_fd_; + + base::MessageLoopForIO::FileDescriptorWatcher accept_socket_watcher_; + scoped_ptr<SocketLibevent>* accept_socket_; + CompletionCallback accept_callback_; + + base::MessageLoopForIO::FileDescriptorWatcher read_socket_watcher_; + scoped_refptr<IOBuffer> read_buf_; + int read_buf_len_; + // External callback; called when read is complete. + CompletionCallback read_callback_; + + base::MessageLoopForIO::FileDescriptorWatcher write_socket_watcher_; + scoped_refptr<IOBuffer> write_buf_; + int write_buf_len_; + // External callback; called when write or connect is complete. + CompletionCallback write_callback_; + + // A connect operation is pending. In this case, |write_callback_| needs to be + // called when connect is complete. + bool waiting_connect_; + + scoped_ptr<SockaddrStorage> peer_address_; + + base::ThreadChecker thread_checker_; + + DISALLOW_COPY_AND_ASSIGN(SocketLibevent); +}; + +} // namespace net + +#endif // NET_SOCKET_SOCKET_LIBEVENT_H_ diff --git a/chromium/net/socket/socket_test_util.cc b/chromium/net/socket/socket_test_util.cc index f993801adeb..5ae4eeeedb8 100644 --- a/chromium/net/socket/socket_test_util.cc +++ b/chromium/net/socket/socket_test_util.cc @@ -10,6 +10,7 @@ #include "base/basictypes.h" #include "base/bind.h" #include "base/bind_helpers.h" +#include "base/callback_helpers.h" #include "base/compiler_specific.h" #include "base/message_loop/message_loop.h" #include "base/run_loop.h" @@ -277,7 +278,9 @@ SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result) client_cert_sent(false), cert_request_info(NULL), channel_id_sent(false), - connection_status(0) { + connection_status(0), + should_pause_on_connect(false), + is_in_session_cache(false) { SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_2, &connection_status); // Set to TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 @@ -680,7 +683,7 @@ MockClientSocketFactory::CreateDatagramClientSocket( data_provider->set_socket(socket.get()); if (bind_type == DatagramSocket::RANDOM_BIND) socket->set_source_port(rand_int_cb.Run(1025, 65535)); - return socket.PassAs<DatagramClientSocket>(); + return socket.Pass(); } scoped_ptr<StreamSocket> MockClientSocketFactory::CreateTransportClientSocket( @@ -691,7 +694,7 @@ scoped_ptr<StreamSocket> MockClientSocketFactory::CreateTransportClientSocket( scoped_ptr<MockTCPClientSocket> socket( new MockTCPClientSocket(addresses, net_log, data_provider)); data_provider->set_socket(socket.get()); - return socket.PassAs<StreamSocket>(); + return socket.Pass(); } scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket( @@ -699,10 +702,13 @@ scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket( const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) { - return scoped_ptr<SSLClientSocket>( + scoped_ptr<MockSSLClientSocket> socket( new MockSSLClientSocket(transport_socket.Pass(), - host_and_port, ssl_config, + host_and_port, + ssl_config, mock_ssl_data_.GetNext())); + ssl_client_sockets_.push_back(socket.get()); + return socket.Pass(); } void MockClientSocketFactory::ClearSSLSessionCache() { @@ -758,6 +764,20 @@ const BoundNetLog& MockClientSocket::NetLog() const { return net_log_; } +std::string MockClientSocket::GetSessionCacheKey() const { + NOTIMPLEMENTED(); + return std::string(); +} + +bool MockClientSocket::InSessionCache() const { + NOTIMPLEMENTED(); + return false; +} + +void MockClientSocket::SetHandshakeCompletionCallback(const base::Closure& cb) { + NOTIMPLEMENTED(); +} + void MockClientSocket::GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) { } @@ -776,15 +796,14 @@ int MockClientSocket::GetTLSUniqueChannelBinding(std::string* out) { return OK; } -ServerBoundCertService* MockClientSocket::GetServerBoundCertService() const { +ChannelIDService* MockClientSocket::GetChannelIDService() const { NOTREACHED(); return NULL; } SSLClientSocket::NextProtoStatus -MockClientSocket::GetNextProto(std::string* proto, std::string* server_protos) { +MockClientSocket::GetNextProto(std::string* proto) { proto->clear(); - server_protos->clear(); return SSLClientSocket::kNextProtoUnsupported; } @@ -838,7 +857,7 @@ int MockTCPClientSocket::Read(IOBuffer* buf, int buf_len, return ERR_UNEXPECTED; // If the buffer is already in use, a read is already in progress! - DCHECK(pending_buf_ == NULL); + DCHECK(pending_buf_.get() == NULL); // Store our async IO data. pending_buf_ = buf; @@ -946,7 +965,7 @@ bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { void MockTCPClientSocket::OnReadComplete(const MockRead& data) { // There must be a read pending. - DCHECK(pending_buf_); + DCHECK(pending_buf_.get()); // You can't complete a read with another ERR_IO_PENDING status code. DCHECK_NE(ERR_IO_PENDING, data.result); // Since we've been waiting for data, need_read_data_ should be true. @@ -970,7 +989,7 @@ void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) { } int MockTCPClientSocket::CompleteRead() { - DCHECK(pending_buf_); + DCHECK(pending_buf_.get()); DCHECK(pending_buf_len_ > 0); was_used_to_convey_data_ = true; @@ -1298,31 +1317,25 @@ void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {} void DeterministicMockTCPClientSocket::OnConnectComplete( const MockConnect& data) {} -// static -void MockSSLClientSocket::ConnectCallback( - MockSSLClientSocket* ssl_client_socket, - const CompletionCallback& callback, - int rv) { - if (rv == OK) - ssl_client_socket->connected_ = true; - callback.Run(rv); -} - MockSSLClientSocket::MockSSLClientSocket( scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_port_pair, const SSLConfig& ssl_config, SSLSocketDataProvider* data) : MockClientSocket( - // Have to use the right BoundNetLog for LoadTimingInfo regression - // tests. - transport_socket->socket()->NetLog()), + // Have to use the right BoundNetLog for LoadTimingInfo regression + // tests. + transport_socket->socket()->NetLog()), transport_(transport_socket.Pass()), + host_port_pair_(host_port_pair), data_(data), is_npn_state_set_(false), new_npn_value_(false), is_protocol_negotiated_set_(false), - protocol_negotiated_(kProtoUnknown) { + protocol_negotiated_(kProtoUnknown), + next_connect_state_(STATE_NONE), + reached_connect_(false), + weak_factory_(this) { DCHECK(data_); peer_addr_ = data->connect.peer_addr; } @@ -1342,28 +1355,23 @@ int MockSSLClientSocket::Write(IOBuffer* buf, int buf_len, } int MockSSLClientSocket::Connect(const CompletionCallback& callback) { - int rv = transport_->socket()->Connect( - base::Bind(&ConnectCallback, base::Unretained(this), callback)); - if (rv == OK) { - if (data_->connect.result == OK) - connected_ = true; - if (data_->connect.mode == ASYNC) { - RunCallbackAsync(callback, data_->connect.result); - return ERR_IO_PENDING; - } - return data_->connect.result; - } + next_connect_state_ = STATE_SSL_CONNECT; + reached_connect_ = true; + int rv = DoConnectLoop(OK); + if (rv == ERR_IO_PENDING) + connect_callback_ = callback; return rv; } void MockSSLClientSocket::Disconnect() { + weak_factory_.InvalidateWeakPtrs(); MockClientSocket::Disconnect(); if (transport_->socket() != NULL) transport_->socket()->Disconnect(); } bool MockSSLClientSocket::IsConnected() const { - return transport_->socket()->IsConnected(); + return transport_->socket()->IsConnected() && connected_; } bool MockSSLClientSocket::WasEverUsed() const { @@ -1387,6 +1395,21 @@ bool MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) { return true; } +std::string MockSSLClientSocket::GetSessionCacheKey() const { + // For the purposes of these tests, |host_and_port| will serve as the + // cache key. + return host_port_pair_.ToString(); +} + +bool MockSSLClientSocket::InSessionCache() const { + return data_->is_in_session_cache; +} + +void MockSSLClientSocket::SetHandshakeCompletionCallback( + const base::Closure& cb) { + handshake_completion_callback_ = cb; +} + void MockSSLClientSocket::GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) { DCHECK(cert_request_info); @@ -1400,9 +1423,8 @@ void MockSSLClientSocket::GetSSLCertRequestInfo( } SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto( - std::string* proto, std::string* server_protos) { + std::string* proto) { *proto = data_->next_proto; - *server_protos = data_->server_protos; return data_->next_proto_status; } @@ -1437,8 +1459,8 @@ void MockSSLClientSocket::set_channel_id_sent(bool channel_id_sent) { data_->channel_id_sent = channel_id_sent; } -ServerBoundCertService* MockSSLClientSocket::GetServerBoundCertService() const { - return data_->server_bound_cert_service; +ChannelIDService* MockSSLClientSocket::GetChannelIDService() const { + return data_->channel_id_service; } void MockSSLClientSocket::OnReadComplete(const MockRead& data) { @@ -1449,6 +1471,69 @@ void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) { NOTIMPLEMENTED(); } +void MockSSLClientSocket::RestartPausedConnect() { + DCHECK(data_->should_pause_on_connect); + DCHECK_EQ(next_connect_state_, STATE_SSL_CONNECT_COMPLETE); + OnIOComplete(data_->connect.result); +} + +void MockSSLClientSocket::OnIOComplete(int result) { + int rv = DoConnectLoop(result); + if (rv != ERR_IO_PENDING) + base::ResetAndReturn(&connect_callback_).Run(rv); +} + +int MockSSLClientSocket::DoConnectLoop(int result) { + DCHECK_NE(next_connect_state_, STATE_NONE); + + int rv = result; + do { + ConnectState state = next_connect_state_; + next_connect_state_ = STATE_NONE; + switch (state) { + case STATE_SSL_CONNECT: + rv = DoSSLConnect(); + break; + case STATE_SSL_CONNECT_COMPLETE: + rv = DoSSLConnectComplete(rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_connect_state_ != STATE_NONE); + + return rv; +} + +int MockSSLClientSocket::DoSSLConnect() { + next_connect_state_ = STATE_SSL_CONNECT_COMPLETE; + + if (data_->should_pause_on_connect) + return ERR_IO_PENDING; + + if (data_->connect.mode == ASYNC) { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&MockSSLClientSocket::OnIOComplete, + weak_factory_.GetWeakPtr(), + data_->connect.result)); + return ERR_IO_PENDING; + } + + return data_->connect.result; +} + +int MockSSLClientSocket::DoSSLConnectComplete(int result) { + if (result == OK) + connected_ = true; + + if (!handshake_completion_callback_.is_null()) + base::ResetAndReturn(&handshake_completion_callback_).Run(); + return result; +} + MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data, net::NetLog* net_log) : connected_(false), @@ -1475,7 +1560,7 @@ int MockUDPClientSocket::Read(IOBuffer* buf, return ERR_UNEXPECTED; // If the buffer is already in use, a read is already in progress! - DCHECK(pending_buf_ == NULL); + DCHECK(pending_buf_.get() == NULL); // Store our async IO data. pending_buf_ = buf; @@ -1552,7 +1637,7 @@ int MockUDPClientSocket::Connect(const IPEndPoint& address) { void MockUDPClientSocket::OnReadComplete(const MockRead& data) { // There must be a read pending. - DCHECK(pending_buf_); + DCHECK(pending_buf_.get()); // You can't complete a read with another ERR_IO_PENDING status code. DCHECK_NE(ERR_IO_PENDING, data.result); // Since we've been waiting for data, need_read_data_ should be true. @@ -1575,7 +1660,7 @@ void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) { } int MockUDPClientSocket::CompleteRead() { - DCHECK(pending_buf_); + DCHECK(pending_buf_.get()); DCHECK(pending_buf_len_ > 0); // Save the pending async IO data and reset our |pending_| state. @@ -1833,7 +1918,7 @@ DeterministicMockClientSocketFactory::CreateDatagramClientSocket( udp_client_sockets().push_back(socket.get()); if (bind_type == DatagramSocket::RANDOM_BIND) socket->set_source_port(rand_int_cb.Run(1025, 65535)); - return socket.PassAs<DatagramClientSocket>(); + return socket.Pass(); } scoped_ptr<StreamSocket> @@ -1846,7 +1931,7 @@ DeterministicMockClientSocketFactory::CreateTransportClientSocket( new DeterministicMockTCPClientSocket(net_log, data_provider)); data_provider->set_delegate(socket->AsWeakPtr()); tcp_client_sockets().push_back(socket.get()); - return socket.PassAs<StreamSocket>(); + return socket.Pass(); } scoped_ptr<SSLClientSocket> @@ -1860,7 +1945,7 @@ DeterministicMockClientSocketFactory::CreateSSLClientSocket( host_and_port, ssl_config, mock_ssl_data_.GetNext())); ssl_client_sockets_.push_back(socket.get()); - return socket.PassAs<SSLClientSocket>(); + return socket.Pass(); } void DeterministicMockClientSocketFactory::ClearSSLSessionCache() { diff --git a/chromium/net/socket/socket_test_util.h b/chromium/net/socket/socket_test_util.h index 2918aad2dc5..7bccdaed727 100644 --- a/chromium/net/socket/socket_test_util.h +++ b/chromium/net/socket/socket_test_util.h @@ -47,8 +47,8 @@ enum { }; class AsyncSocket; +class ChannelIDService; class MockClientSocket; -class ServerBoundCertService; class SSLClientSocket; class StreamSocket; @@ -243,7 +243,7 @@ class StaticSocketDataProvider : public SocketDataProvider { size_t reads_count, MockWrite* writes, size_t writes_count); - virtual ~StaticSocketDataProvider(); + ~StaticSocketDataProvider() override; // These functions get access to the next available read and write data. const MockRead& PeekRead() const; @@ -262,9 +262,9 @@ class StaticSocketDataProvider : public SocketDataProvider { virtual void CompleteRead() {} // SocketDataProvider implementation. - virtual MockRead GetNextRead() OVERRIDE; - virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE; - virtual void Reset() OVERRIDE; + MockRead GetNextRead() override; + MockWriteResult OnWrite(const std::string& data) override; + void Reset() override; private: MockRead* reads_; @@ -284,7 +284,7 @@ class StaticSocketDataProvider : public SocketDataProvider { class DynamicSocketDataProvider : public SocketDataProvider { public: DynamicSocketDataProvider(); - virtual ~DynamicSocketDataProvider(); + ~DynamicSocketDataProvider() override; int short_read_limit() const { return short_read_limit_; } void set_short_read_limit(int limit) { short_read_limit_ = limit; } @@ -292,9 +292,9 @@ class DynamicSocketDataProvider : public SocketDataProvider { void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; } // SocketDataProvider implementation. - virtual MockRead GetNextRead() OVERRIDE; + MockRead GetNextRead() override; virtual MockWriteResult OnWrite(const std::string& data) = 0; - virtual void Reset() OVERRIDE; + void Reset() override; protected: // The next time there is a read from this socket, it will return |data|. @@ -326,15 +326,20 @@ struct SSLSocketDataProvider { MockConnect connect; SSLClientSocket::NextProtoStatus next_proto_status; std::string next_proto; - std::string server_protos; bool was_npn_negotiated; NextProto protocol_negotiated; bool client_cert_sent; SSLCertRequestInfo* cert_request_info; scoped_refptr<X509Certificate> cert; bool channel_id_sent; - ServerBoundCertService* server_bound_cert_service; + ChannelIDService* channel_id_service; int connection_status; + // Indicates that the socket should pause in the Connect method. + bool should_pause_on_connect; + // Whether or not the Socket should behave like there is a pre-existing + // session to resume. Whether or not such a session is reported as + // resumed is controlled by |connection_status|. + bool is_in_session_cache; }; // A DataProvider where the client must write a request before the reads (e.g. @@ -366,15 +371,15 @@ class DelayedSocketData : public StaticSocketDataProvider { size_t reads_count, MockWrite* writes, size_t writes_count); - virtual ~DelayedSocketData(); + ~DelayedSocketData() override; void ForceNextRead(); // StaticSocketDataProvider: - virtual MockRead GetNextRead() OVERRIDE; - virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE; - virtual void Reset() OVERRIDE; - virtual void CompleteRead() OVERRIDE; + MockRead GetNextRead() override; + MockWriteResult OnWrite(const std::string& data) override; + void Reset() override; + void CompleteRead() override; private: int write_delay_; @@ -407,7 +412,7 @@ class OrderedSocketData : public StaticSocketDataProvider { size_t reads_count, MockWrite* writes, size_t writes_count); - virtual ~OrderedSocketData(); + ~OrderedSocketData() override; // |connect| the result for the connect phase. // |reads| the list of MockRead completions. @@ -425,10 +430,10 @@ class OrderedSocketData : public StaticSocketDataProvider { void EndLoop(); // StaticSocketDataProvider: - virtual MockRead GetNextRead() OVERRIDE; - virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE; - virtual void Reset() OVERRIDE; - virtual void CompleteRead() OVERRIDE; + MockRead GetNextRead() override; + MockWriteResult OnWrite(const std::string& data) override; + void Reset() override; + void CompleteRead() override; private: int sequence_number_; @@ -531,7 +536,7 @@ class DeterministicSocketData : public StaticSocketDataProvider { size_t reads_count, MockWrite* writes, size_t writes_count); - virtual ~DeterministicSocketData(); + ~DeterministicSocketData() override; // Consume all the data up to the give stop point (via SetStop()). void Run(); @@ -555,14 +560,14 @@ class DeterministicSocketData : public StaticSocketDataProvider { // When the socket calls Read(), that calls GetNextRead(), and expects either // ERR_IO_PENDING or data. - virtual MockRead GetNextRead() OVERRIDE; + MockRead GetNextRead() override; // When the socket calls Write(), it always completes synchronously. OnWrite() // checks to make sure the written data matches the expected data. The // callback will not be invoked until its sequence number is reached. - virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE; - virtual void Reset() OVERRIDE; - virtual void CompleteRead() OVERRIDE {} + MockWriteResult OnWrite(const std::string& data) override; + void Reset() override; + void CompleteRead() override {} private: // Invoke the read and write callbacks, if the timing is appropriate. @@ -628,7 +633,7 @@ class MockSSLClientSocket; class MockClientSocketFactory : public ClientSocketFactory { public: MockClientSocketFactory(); - virtual ~MockClientSocketFactory(); + ~MockClientSocketFactory() override; void AddSocketDataProvider(SocketDataProvider* socket); void AddSSLSocketDataProvider(SSLSocketDataProvider* socket); @@ -638,26 +643,33 @@ class MockClientSocketFactory : public ClientSocketFactory { return mock_data_; } + // Note: this method is unsafe; the elements of the returned vector + // are not necessarily valid. + const std::vector<MockSSLClientSocket*>& ssl_client_sockets() const { + return ssl_client_sockets_; + } + // ClientSocketFactory - virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( + scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, - const NetLog::Source& source) OVERRIDE; - virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( + const NetLog::Source& source) override; + scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* net_log, - const NetLog::Source& source) OVERRIDE; - virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + const NetLog::Source& source) override; + scoped_ptr<SSLClientSocket> CreateSSLClientSocket( scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, - const SSLClientSocketContext& context) OVERRIDE; - virtual void ClearSSLSessionCache() OVERRIDE; + const SSLClientSocketContext& context) override; + void ClearSSLSessionCache() override; private: SocketDataProviderArray<SocketDataProvider> mock_data_; SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_; + std::vector<MockSSLClientSocket*> ssl_client_sockets_; }; class MockClientSocket : public SSLClientSocket { @@ -676,41 +688,42 @@ class MockClientSocket : public SSLClientSocket { virtual int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) = 0; - virtual int SetReceiveBufferSize(int32 size) OVERRIDE; - virtual int SetSendBufferSize(int32 size) OVERRIDE; + int SetReceiveBufferSize(int32 size) override; + int SetSendBufferSize(int32 size) override; // StreamSocket implementation. virtual int Connect(const CompletionCallback& callback) = 0; - virtual void Disconnect() OVERRIDE; - virtual bool IsConnected() const OVERRIDE; - virtual bool IsConnectedAndIdle() const OVERRIDE; - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; - virtual const BoundNetLog& NetLog() const OVERRIDE; - virtual void SetSubresourceSpeculation() OVERRIDE {} - virtual void SetOmniboxSpeculation() OVERRIDE {} + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + const BoundNetLog& NetLog() const override; + void SetSubresourceSpeculation() override {} + void SetOmniboxSpeculation() override {} // SSLClientSocket implementation. - virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) - OVERRIDE; - virtual int ExportKeyingMaterial(const base::StringPiece& label, - bool has_context, - const base::StringPiece& context, - unsigned char* out, - unsigned int outlen) OVERRIDE; - virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE; - virtual NextProtoStatus GetNextProto(std::string* proto, - std::string* server_protos) OVERRIDE; - virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE; + std::string GetSessionCacheKey() const override; + bool InSessionCache() const override; + void SetHandshakeCompletionCallback(const base::Closure& cb) override; + void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override; + int ExportKeyingMaterial(const base::StringPiece& label, + bool has_context, + const base::StringPiece& context, + unsigned char* out, + unsigned int outlen) override; + int GetTLSUniqueChannelBinding(std::string* out) override; + NextProtoStatus GetNextProto(std::string* proto) override; + ChannelIDService* GetChannelIDService() const override; protected: - virtual ~MockClientSocket(); + ~MockClientSocket() override; void RunCallbackAsync(const CompletionCallback& callback, int result); void RunCallback(const CompletionCallback& callback, int result); // SSLClientSocket implementation. - virtual scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain() - const OVERRIDE; + scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain() + const override; // True if Connect completed successfully and Disconnect hasn't been called. bool connected_; @@ -720,6 +733,7 @@ class MockClientSocket : public SSLClientSocket { BoundNetLog net_log_; + private: base::WeakPtrFactory<MockClientSocket> weak_factory_; DISALLOW_COPY_AND_ASSIGN(MockClientSocket); @@ -730,32 +744,32 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { MockTCPClientSocket(const AddressList& addresses, net::NetLog* net_log, SocketDataProvider* socket); - virtual ~MockTCPClientSocket(); + ~MockTCPClientSocket() override; const AddressList& addresses() const { return addresses_; } // Socket implementation. - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE; - virtual void Disconnect() OVERRIDE; - virtual bool IsConnected() const OVERRIDE; - virtual bool IsConnectedAndIdle() const OVERRIDE; - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; - virtual bool WasEverUsed() const OVERRIDE; - virtual bool UsingTCPFastOpen() const OVERRIDE; - virtual bool WasNpnNegotiated() const OVERRIDE; - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + int Connect(const CompletionCallback& callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + int GetPeerAddress(IPEndPoint* address) const override; + bool WasEverUsed() const override; + bool UsingTCPFastOpen() const override; + bool WasNpnNegotiated() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; // AsyncSocket: - virtual void OnReadComplete(const MockRead& data) OVERRIDE; - virtual void OnConnectComplete(const MockConnect& data) OVERRIDE; + void OnReadComplete(const MockRead& data) override; + void OnConnectComplete(const MockConnect& data) override; private: int CompleteRead(); @@ -836,36 +850,36 @@ class DeterministicMockUDPClientSocket public: DeterministicMockUDPClientSocket(net::NetLog* net_log, DeterministicSocketData* data); - virtual ~DeterministicMockUDPClientSocket(); + ~DeterministicMockUDPClientSocket() override; // DeterministicSocketData::Delegate: - virtual bool WritePending() const OVERRIDE; - virtual bool ReadPending() const OVERRIDE; - virtual void CompleteWrite() OVERRIDE; - virtual int CompleteRead() OVERRIDE; + bool WritePending() const override; + bool ReadPending() const override; + void CompleteWrite() override; + int CompleteRead() override; // Socket implementation. - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int SetReceiveBufferSize(int32 size) OVERRIDE; - virtual int SetSendBufferSize(int32 size) OVERRIDE; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int SetReceiveBufferSize(int32 size) override; + int SetSendBufferSize(int32 size) override; // DatagramSocket implementation. - virtual void Close() OVERRIDE; - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; - virtual const BoundNetLog& NetLog() const OVERRIDE; + void Close() override; + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + const BoundNetLog& NetLog() const override; // DatagramClientSocket implementation. - virtual int Connect(const IPEndPoint& address) OVERRIDE; + int Connect(const IPEndPoint& address) override; // AsyncSocket implementation. - virtual void OnReadComplete(const MockRead& data) OVERRIDE; - virtual void OnConnectComplete(const MockConnect& data) OVERRIDE; + void OnReadComplete(const MockRead& data) override; + void OnConnectComplete(const MockConnect& data) override; void set_source_port(int port) { source_port_ = port; } @@ -887,35 +901,35 @@ class DeterministicMockTCPClientSocket public: DeterministicMockTCPClientSocket(net::NetLog* net_log, DeterministicSocketData* data); - virtual ~DeterministicMockTCPClientSocket(); + ~DeterministicMockTCPClientSocket() override; // DeterministicSocketData::Delegate: - virtual bool WritePending() const OVERRIDE; - virtual bool ReadPending() const OVERRIDE; - virtual void CompleteWrite() OVERRIDE; - virtual int CompleteRead() OVERRIDE; + bool WritePending() const override; + bool ReadPending() const override; + void CompleteWrite() override; + int CompleteRead() override; // Socket: - virtual int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; // StreamSocket: - virtual int Connect(const CompletionCallback& callback) OVERRIDE; - virtual void Disconnect() OVERRIDE; - virtual bool IsConnected() const OVERRIDE; - virtual bool IsConnectedAndIdle() const OVERRIDE; - virtual bool WasEverUsed() const OVERRIDE; - virtual bool UsingTCPFastOpen() const OVERRIDE; - virtual bool WasNpnNegotiated() const OVERRIDE; - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + int Connect(const CompletionCallback& callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + bool WasEverUsed() const override; + bool UsingTCPFastOpen() const override; + bool WasNpnNegotiated() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; // AsyncSocket: - virtual void OnReadComplete(const MockRead& data) OVERRIDE; - virtual void OnConnectComplete(const MockConnect& data) OVERRIDE; + void OnReadComplete(const MockRead& data) override; + void OnConnectComplete(const MockConnect& data) override; private: DeterministicSocketHelper helper_; @@ -929,85 +943,113 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLSocketDataProvider* socket); - virtual ~MockSSLClientSocket(); + ~MockSSLClientSocket() override; // Socket implementation. - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE; - virtual void Disconnect() OVERRIDE; - virtual bool IsConnected() const OVERRIDE; - virtual bool WasEverUsed() const OVERRIDE; - virtual bool UsingTCPFastOpen() const OVERRIDE; - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; - virtual bool WasNpnNegotiated() const OVERRIDE; - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + int Connect(const CompletionCallback& callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool WasEverUsed() const override; + bool UsingTCPFastOpen() const override; + int GetPeerAddress(IPEndPoint* address) const override; + bool WasNpnNegotiated() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; // SSLClientSocket implementation. - virtual void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) - OVERRIDE; - virtual NextProtoStatus GetNextProto(std::string* proto, - std::string* server_protos) OVERRIDE; - virtual bool set_was_npn_negotiated(bool negotiated) OVERRIDE; - virtual void set_protocol_negotiated(NextProto protocol_negotiated) OVERRIDE; - virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + std::string GetSessionCacheKey() const override; + bool InSessionCache() const override; + void SetHandshakeCompletionCallback(const base::Closure& cb) override; + void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override; + NextProtoStatus GetNextProto(std::string* proto) override; + bool set_was_npn_negotiated(bool negotiated) override; + void set_protocol_negotiated(NextProto protocol_negotiated) override; + NextProto GetNegotiatedProtocol() const override; // This MockSocket does not implement the manual async IO feature. - virtual void OnReadComplete(const MockRead& data) OVERRIDE; - virtual void OnConnectComplete(const MockConnect& data) OVERRIDE; + void OnReadComplete(const MockRead& data) override; + void OnConnectComplete(const MockConnect& data) override; + + bool WasChannelIDSent() const override; + void set_channel_id_sent(bool channel_id_sent) override; + ChannelIDService* GetChannelIDService() const override; - virtual bool WasChannelIDSent() const OVERRIDE; - virtual void set_channel_id_sent(bool channel_id_sent) OVERRIDE; - virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE; + bool reached_connect() const { return reached_connect_; } + + // Resumes the connection of a socket that was paused for testing. + // |connect_callback_| should be set before invoking this method. + void RestartPausedConnect(); private: - static void ConnectCallback(MockSSLClientSocket* ssl_client_socket, - const CompletionCallback& callback, - int rv); + enum ConnectState { + STATE_NONE, + STATE_SSL_CONNECT, + STATE_SSL_CONNECT_COMPLETE, + }; + + void OnIOComplete(int result); + + // Runs the state transistion loop. + int DoConnectLoop(int result); + + int DoSSLConnect(); + int DoSSLConnectComplete(int result); scoped_ptr<ClientSocketHandle> transport_; + HostPortPair host_port_pair_; SSLSocketDataProvider* data_; bool is_npn_state_set_; bool new_npn_value_; bool is_protocol_negotiated_set_; NextProto protocol_negotiated_; + CompletionCallback connect_callback_; + // Indicates what state of Connect the socket should enter. + ConnectState next_connect_state_; + // True if the Connect method has been called on the socket. + bool reached_connect_; + + base::Closure handshake_completion_callback_; + + base::WeakPtrFactory<MockSSLClientSocket> weak_factory_; + DISALLOW_COPY_AND_ASSIGN(MockSSLClientSocket); }; class MockUDPClientSocket : public DatagramClientSocket, public AsyncSocket { public: MockUDPClientSocket(SocketDataProvider* data, net::NetLog* net_log); - virtual ~MockUDPClientSocket(); + ~MockUDPClientSocket() override; // Socket implementation. - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int SetReceiveBufferSize(int32 size) OVERRIDE; - virtual int SetSendBufferSize(int32 size) OVERRIDE; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int SetReceiveBufferSize(int32 size) override; + int SetSendBufferSize(int32 size) override; // DatagramSocket implementation. - virtual void Close() OVERRIDE; - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; - virtual const BoundNetLog& NetLog() const OVERRIDE; + void Close() override; + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + const BoundNetLog& NetLog() const override; // DatagramClientSocket implementation. - virtual int Connect(const IPEndPoint& address) OVERRIDE; + int Connect(const IPEndPoint& address) override; // AsyncSocket implementation. - virtual void OnReadComplete(const MockRead& data) OVERRIDE; - virtual void OnConnectComplete(const MockConnect& data) OVERRIDE; + void OnReadComplete(const MockRead& data) override; + void OnConnectComplete(const MockConnect& data) override; void set_source_port(int port) { source_port_ = port;} @@ -1043,7 +1085,7 @@ class TestSocketRequest : public TestCompletionCallbackBase { public: TestSocketRequest(std::vector<TestSocketRequest*>* request_order, size_t* completion_count); - virtual ~TestSocketRequest(); + ~TestSocketRequest() override; ClientSocketHandle* handle() { return &handle_; } @@ -1163,7 +1205,7 @@ class MockTransportClientSocketPool : public TransportClientSocketPool { ClientSocketPoolHistograms* histograms, ClientSocketFactory* socket_factory); - virtual ~MockTransportClientSocketPool(); + ~MockTransportClientSocketPool() override; RequestPriority last_request_priority() const { return last_request_priority_; @@ -1172,18 +1214,18 @@ class MockTransportClientSocketPool : public TransportClientSocketPool { int cancel_count() const { return cancel_count_; } // TransportClientSocketPool implementation. - virtual int RequestSocket(const std::string& group_name, - const void* socket_params, - RequestPriority priority, - ClientSocketHandle* handle, - const CompletionCallback& callback, - const BoundNetLog& net_log) OVERRIDE; - - virtual void CancelRequest(const std::string& group_name, - ClientSocketHandle* handle) OVERRIDE; - virtual void ReleaseSocket(const std::string& group_name, - scoped_ptr<StreamSocket> socket, - int id) OVERRIDE; + int RequestSocket(const std::string& group_name, + const void* socket_params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) override; + + void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) override; + void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) override; private: ClientSocketFactory* client_socket_factory_; @@ -1198,7 +1240,7 @@ class MockTransportClientSocketPool : public TransportClientSocketPool { class DeterministicMockClientSocketFactory : public ClientSocketFactory { public: DeterministicMockClientSocketFactory(); - virtual ~DeterministicMockClientSocketFactory(); + ~DeterministicMockClientSocketFactory() override; void AddSocketDataProvider(DeterministicSocketData* socket); void AddSSLSocketDataProvider(SSLSocketDataProvider* socket); @@ -1219,21 +1261,21 @@ class DeterministicMockClientSocketFactory : public ClientSocketFactory { } // ClientSocketFactory - virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( + scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, - const NetLog::Source& source) OVERRIDE; - virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( + const NetLog::Source& source) override; + scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* net_log, - const NetLog::Source& source) OVERRIDE; - virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + const NetLog::Source& source) override; + scoped_ptr<SSLClientSocket> CreateSSLClientSocket( scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, - const SSLClientSocketContext& context) OVERRIDE; - virtual void ClearSSLSessionCache() OVERRIDE; + const SSLClientSocketContext& context) override; + void ClearSSLSessionCache() override; private: SocketDataProviderArray<DeterministicSocketData> mock_data_; @@ -1254,21 +1296,21 @@ class MockSOCKSClientSocketPool : public SOCKSClientSocketPool { ClientSocketPoolHistograms* histograms, TransportClientSocketPool* transport_pool); - virtual ~MockSOCKSClientSocketPool(); + ~MockSOCKSClientSocketPool() override; // SOCKSClientSocketPool implementation. - virtual int RequestSocket(const std::string& group_name, - const void* socket_params, - RequestPriority priority, - ClientSocketHandle* handle, - const CompletionCallback& callback, - const BoundNetLog& net_log) OVERRIDE; - - virtual void CancelRequest(const std::string& group_name, - ClientSocketHandle* handle) OVERRIDE; - virtual void ReleaseSocket(const std::string& group_name, - scoped_ptr<StreamSocket> socket, - int id) OVERRIDE; + int RequestSocket(const std::string& group_name, + const void* socket_params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) override; + + void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) override; + void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) override; private: TransportClientSocketPool* const transport_pool_; diff --git a/chromium/net/socket/socks5_client_socket.h b/chromium/net/socket/socks5_client_socket.h index 8da0b4da5ce..a405212b56b 100644 --- a/chromium/net/socket/socks5_client_socket.h +++ b/chromium/net/socket/socks5_client_socket.h @@ -38,37 +38,37 @@ class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { const HostResolver::RequestInfo& req_info); // On destruction Disconnect() is called. - virtual ~SOCKS5ClientSocket(); + ~SOCKS5ClientSocket() override; // StreamSocket implementation. // Does the SOCKS handshake and completes the protocol. - virtual int Connect(const CompletionCallback& callback) OVERRIDE; - virtual void Disconnect() OVERRIDE; - virtual bool IsConnected() const OVERRIDE; - virtual bool IsConnectedAndIdle() const OVERRIDE; - virtual const BoundNetLog& NetLog() const OVERRIDE; - virtual void SetSubresourceSpeculation() OVERRIDE; - virtual void SetOmniboxSpeculation() OVERRIDE; - virtual bool WasEverUsed() const OVERRIDE; - virtual bool UsingTCPFastOpen() const OVERRIDE; - virtual bool WasNpnNegotiated() const OVERRIDE; - virtual NextProto GetNegotiatedProtocol() const OVERRIDE; - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + int Connect(const CompletionCallback& callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + const BoundNetLog& NetLog() const override; + void SetSubresourceSpeculation() override; + void SetOmniboxSpeculation() override; + bool WasEverUsed() const override; + bool UsingTCPFastOpen() const override; + bool WasNpnNegotiated() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; // Socket implementation. - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - - virtual int SetReceiveBufferSize(int32 size) OVERRIDE; - virtual int SetSendBufferSize(int32 size) OVERRIDE; - - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + + int SetReceiveBufferSize(int32 size) override; + int SetSendBufferSize(int32 size) override; + + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; private: enum State { diff --git a/chromium/net/socket/socks5_client_socket_unittest.cc b/chromium/net/socket/socks5_client_socket_unittest.cc index 78f2ac433c3..c474a0b4198 100644 --- a/chromium/net/socket/socks5_client_socket_unittest.cc +++ b/chromium/net/socket/socks5_client_socket_unittest.cc @@ -40,7 +40,7 @@ class SOCKS5ClientSocketTest : public PlatformTest { int port, NetLog* net_log); - virtual void SetUp(); + void SetUp() override; protected: const uint16 kNwPort; diff --git a/chromium/net/socket/socks_client_socket.h b/chromium/net/socket/socks_client_socket.h index 26da332b3ea..e792881cc7f 100644 --- a/chromium/net/socket/socks_client_socket.h +++ b/chromium/net/socket/socks_client_socket.h @@ -35,37 +35,37 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { HostResolver* host_resolver); // On destruction Disconnect() is called. - virtual ~SOCKSClientSocket(); + ~SOCKSClientSocket() override; // StreamSocket implementation. // Does the SOCKS handshake and completes the protocol. - virtual int Connect(const CompletionCallback& callback) OVERRIDE; - virtual void Disconnect() OVERRIDE; - virtual bool IsConnected() const OVERRIDE; - virtual bool IsConnectedAndIdle() const OVERRIDE; - virtual const BoundNetLog& NetLog() const OVERRIDE; - virtual void SetSubresourceSpeculation() OVERRIDE; - virtual void SetOmniboxSpeculation() OVERRIDE; - virtual bool WasEverUsed() const OVERRIDE; - virtual bool UsingTCPFastOpen() const OVERRIDE; - virtual bool WasNpnNegotiated() const OVERRIDE; - virtual NextProto GetNegotiatedProtocol() const OVERRIDE; - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + int Connect(const CompletionCallback& callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + const BoundNetLog& NetLog() const override; + void SetSubresourceSpeculation() override; + void SetOmniboxSpeculation() override; + bool WasEverUsed() const override; + bool UsingTCPFastOpen() const override; + bool WasNpnNegotiated() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; // Socket implementation. - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - - virtual int SetReceiveBufferSize(int32 size) OVERRIDE; - virtual int SetSendBufferSize(int32 size) OVERRIDE; - - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + + int SetReceiveBufferSize(int32 size) override; + int SetSendBufferSize(int32 size) override; + + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; private: FRIEND_TEST_ALL_PREFIXES(SOCKSClientSocketTest, CompleteHandshake); diff --git a/chromium/net/socket/socks_client_socket_pool.h b/chromium/net/socket/socks_client_socket_pool.h index c6d5c8d0883..35f7146f967 100644 --- a/chromium/net/socket/socks_client_socket_pool.h +++ b/chromium/net/socket/socks_client_socket_pool.h @@ -63,10 +63,10 @@ class SOCKSConnectJob : public ConnectJob { HostResolver* host_resolver, Delegate* delegate, NetLog* net_log); - virtual ~SOCKSConnectJob(); + ~SOCKSConnectJob() override; // ConnectJob methods. - virtual LoadState GetLoadState() const OVERRIDE; + LoadState GetLoadState() const override; private: enum State { @@ -90,7 +90,7 @@ class SOCKSConnectJob : public ConnectJob { // Begins the transport connection and the SOCKS handshake. Returns OK on // success and ERR_IO_PENDING if it cannot immediately service the request. // Otherwise, it returns a net error code. - virtual int ConnectInternal() OVERRIDE; + int ConnectInternal() override; scoped_refptr<SOCKSSocketParams> socks_params_; TransportClientSocketPool* const transport_pool_; @@ -117,59 +117,57 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool TransportClientSocketPool* transport_pool, NetLog* net_log); - virtual ~SOCKSClientSocketPool(); + ~SOCKSClientSocketPool() override; // ClientSocketPool implementation. - virtual int RequestSocket(const std::string& group_name, - const void* connect_params, - RequestPriority priority, - ClientSocketHandle* handle, - const CompletionCallback& callback, - const BoundNetLog& net_log) OVERRIDE; + int RequestSocket(const std::string& group_name, + const void* connect_params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) override; - virtual void RequestSockets(const std::string& group_name, - const void* params, - int num_sockets, - const BoundNetLog& net_log) OVERRIDE; + void RequestSockets(const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) override; - virtual void CancelRequest(const std::string& group_name, - ClientSocketHandle* handle) OVERRIDE; + void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) override; - virtual void ReleaseSocket(const std::string& group_name, - scoped_ptr<StreamSocket> socket, - int id) OVERRIDE; + void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) override; - virtual void FlushWithError(int error) OVERRIDE; + void FlushWithError(int error) override; - virtual void CloseIdleSockets() OVERRIDE; + void CloseIdleSockets() override; - virtual int IdleSocketCount() const OVERRIDE; + int IdleSocketCount() const override; - virtual int IdleSocketCountInGroup( - const std::string& group_name) const OVERRIDE; + int IdleSocketCountInGroup(const std::string& group_name) const override; - virtual LoadState GetLoadState( - const std::string& group_name, - const ClientSocketHandle* handle) const OVERRIDE; + LoadState GetLoadState(const std::string& group_name, + const ClientSocketHandle* handle) const override; - virtual base::DictionaryValue* GetInfoAsValue( + base::DictionaryValue* GetInfoAsValue( const std::string& name, const std::string& type, - bool include_nested_pools) const OVERRIDE; + bool include_nested_pools) const override; - virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; + base::TimeDelta ConnectionTimeout() const override; - virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; + ClientSocketPoolHistograms* histograms() const override; // LowerLayeredPool implementation. - virtual bool IsStalled() const OVERRIDE; + bool IsStalled() const override; - virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + void AddHigherLayeredPool(HigherLayeredPool* higher_pool) override; - virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) override; // HigherLayeredPool implementation. - virtual bool CloseOneIdleConnection() OVERRIDE; + bool CloseOneIdleConnection() override; private: typedef ClientSocketPoolBase<SOCKSSocketParams> PoolBase; @@ -183,15 +181,15 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool host_resolver_(host_resolver), net_log_(net_log) {} - virtual ~SOCKSConnectJobFactory() {} + ~SOCKSConnectJobFactory() override {} // ClientSocketPoolBase::ConnectJobFactory methods. - virtual scoped_ptr<ConnectJob> NewConnectJob( + scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const PoolBase::Request& request, - ConnectJob::Delegate* delegate) const OVERRIDE; + ConnectJob::Delegate* delegate) const override; - virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; + base::TimeDelta ConnectionTimeout() const override; private: TransportClientSocketPool* const transport_pool_; diff --git a/chromium/net/socket/socks_client_socket_pool_unittest.cc b/chromium/net/socket/socks_client_socket_pool_unittest.cc index b2b8655ee22..391d31beddb 100644 --- a/chromium/net/socket/socks_client_socket_pool_unittest.cc +++ b/chromium/net/socket/socks_client_socket_pool_unittest.cc @@ -44,8 +44,8 @@ void TestLoadTimingInfo(const ClientSocketHandle& handle) { scoped_refptr<TransportSocketParams> CreateProxyHostParams() { return new TransportSocketParams( - HostPortPair("proxy", 80), false, false, - OnHostResolutionCallback()); + HostPortPair("proxy", 80), false, false, OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT); } scoped_refptr<SOCKSSocketParams> CreateSOCKSv4Params() { @@ -103,7 +103,7 @@ class SOCKSClientSocketPoolTest : public testing::Test { NULL) { } - virtual ~SOCKSClientSocketPoolTest() {} + ~SOCKSClientSocketPoolTest() override {} int StartRequestV5(const std::string& group_name, RequestPriority priority) { return test_base_.StartRequestUsingPool( diff --git a/chromium/net/socket/socks_client_socket_unittest.cc b/chromium/net/socket/socks_client_socket_unittest.cc index f361244feff..fbb84f8f50a 100644 --- a/chromium/net/socket/socks_client_socket_unittest.cc +++ b/chromium/net/socket/socks_client_socket_unittest.cc @@ -10,6 +10,7 @@ #include "net/base/net_log_unittest.h" #include "net/base/test_completion_callback.h" #include "net/base/winsock_init.h" +#include "net/dns/host_resolver.h" #include "net/dns/mock_host_resolver.h" #include "net/socket/client_socket_factory.h" #include "net/socket/socket_test_util.h" @@ -34,7 +35,7 @@ class SOCKSClientSocketTest : public PlatformTest { HostResolver* host_resolver, const std::string& hostname, int port, NetLog* net_log); - virtual void SetUp(); + void SetUp() override; protected: scoped_ptr<SOCKSClientSocket> user_sock_; @@ -95,12 +96,12 @@ class HangingHostResolverWithCancel : public HostResolver { public: HangingHostResolverWithCancel() : outstanding_request_(NULL) {} - virtual int Resolve(const RequestInfo& info, - RequestPriority priority, - AddressList* addresses, - const CompletionCallback& callback, - RequestHandle* out_req, - const BoundNetLog& net_log) OVERRIDE { + int Resolve(const RequestInfo& info, + RequestPriority priority, + AddressList* addresses, + const CompletionCallback& callback, + RequestHandle* out_req, + const BoundNetLog& net_log) override { DCHECK(addresses); DCHECK_EQ(false, callback.is_null()); EXPECT_FALSE(HasOutstandingRequest()); @@ -109,14 +110,14 @@ class HangingHostResolverWithCancel : public HostResolver { return ERR_IO_PENDING; } - virtual int ResolveFromCache(const RequestInfo& info, - AddressList* addresses, - const BoundNetLog& net_log) OVERRIDE { + int ResolveFromCache(const RequestInfo& info, + AddressList* addresses, + const BoundNetLog& net_log) override { NOTIMPLEMENTED(); return ERR_UNEXPECTED; } - virtual void CancelRequest(RequestHandle req) OVERRIDE { + void CancelRequest(RequestHandle req) override { EXPECT_TRUE(HasOutstandingRequest()); EXPECT_EQ(outstanding_request_, req); outstanding_request_ = NULL; @@ -213,7 +214,7 @@ TEST_F(SOCKSClientSocketTest, HandshakeFailures) { //--------------------------------------- - for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) { + for (size_t i = 0; i < arraysize(tests); ++i) { MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; MockRead data_reads[] = { @@ -414,4 +415,36 @@ TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) { EXPECT_FALSE(user_sock_->IsConnectedAndIdle()); } +// Tries to connect to an IPv6 IP. Should fail, as SOCKS4 does not support +// IPv6. +TEST_F(SOCKSClientSocketTest, NoIPv6) { + const char kHostName[] = "::1"; + + user_sock_ = BuildMockSocket(NULL, 0, + NULL, 0, + host_resolver_.get(), + kHostName, 80, + NULL); + + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, + callback_.GetResult(user_sock_->Connect(callback_.callback()))); +} + +// Same as above, but with a real resolver, to protect against regressions. +TEST_F(SOCKSClientSocketTest, NoIPv6RealResolver) { + const char kHostName[] = "::1"; + + scoped_ptr<HostResolver> host_resolver( + HostResolver::CreateSystemResolver(HostResolver::Options(), NULL)); + + user_sock_ = BuildMockSocket(NULL, 0, + NULL, 0, + host_resolver.get(), + kHostName, 80, + NULL); + + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, + callback_.GetResult(user_sock_->Connect(callback_.callback()))); +} + } // namespace net diff --git a/chromium/net/socket/ssl_client_socket.cc b/chromium/net/socket/ssl_client_socket.cc index 1b2fe144304..3184e04e3f7 100644 --- a/chromium/net/socket/ssl_client_socket.cc +++ b/chromium/net/socket/ssl_client_socket.cc @@ -5,10 +5,14 @@ #include "net/socket/ssl_client_socket.h" #include "base/metrics/histogram.h" +#include "base/metrics/sparse_histogram.h" #include "base/strings/string_util.h" #include "crypto/ec_private_key.h" -#include "net/ssl/server_bound_cert_service.h" +#include "net/base/connection_type_histograms.h" +#include "net/base/host_port_pair.h" +#include "net/ssl/channel_id_service.h" #include "net/ssl/ssl_config_service.h" +#include "net/ssl/ssl_connection_status_flags.h" namespace net { @@ -18,7 +22,8 @@ SSLClientSocket::SSLClientSocket() protocol_negotiated_(kProtoUnknown), channel_id_sent_(false), signed_cert_timestamps_received_(false), - stapled_ocsp_response_received_(false) { + stapled_ocsp_response_received_(false), + negotiation_extension_(kExtensionUnknown) { } // static @@ -32,8 +37,8 @@ NextProto SSLClientSocket::NextProtoFromString( return kProtoSPDY3; } else if (proto_string == "spdy/3.1") { return kProtoSPDY31; - } else if (proto_string == "h2-12") { - // This is the HTTP/2 draft 12 identifier. For internal + } else if (proto_string == "h2-14") { + // This is the HTTP/2 draft 14 identifier. For internal // consistency, HTTP/2 is named SPDY4 within Chromium. return kProtoSPDY4; } else if (proto_string == "quic/1+spdy/3") { @@ -55,9 +60,9 @@ const char* SSLClientSocket::NextProtoToString(NextProto next_proto) { case kProtoSPDY31: return "spdy/3.1"; case kProtoSPDY4: - // This is the HTTP/2 draft 12 identifier. For internal + // This is the HTTP/2 draft 14 identifier. For internal // consistency, HTTP/2 is named SPDY4 within Chromium. - return "h2-12"; + return "h2-14"; case kProtoQUIC1SPDY3: return "quic/1+spdy/3"; case kProtoUnknown: @@ -80,21 +85,6 @@ const char* SSLClientSocket::NextProtoStatusToString( return NULL; } -// static -std::string SSLClientSocket::ServerProtosToString( - const std::string& server_protos) { - const char* protos = server_protos.c_str(); - size_t protos_len = server_protos.length(); - std::vector<std::string> server_protos_with_commas; - for (size_t i = 0; i < protos_len; ) { - const size_t len = protos[i]; - std::string proto_str(&protos[i + 1], len); - server_protos_with_commas.push_back(proto_str); - i += len + 1; - } - return JoinString(server_protos_with_commas, ','); -} - bool SSLClientSocket::WasNpnNegotiated() const { return was_npn_negotiated_; } @@ -138,6 +128,11 @@ void SSLClientSocket::set_protocol_negotiated(NextProto protocol_negotiated) { protocol_negotiated_ = protocol_negotiated; } +void SSLClientSocket::set_negotiation_extension( + SSLNegotiationExtension negotiation_extension) { + negotiation_extension_ = negotiation_extension; +} + bool SSLClientSocket::WasChannelIDSent() const { return channel_id_sent_; } @@ -158,7 +153,7 @@ void SSLClientSocket::set_stapled_ocsp_response_received( // static void SSLClientSocket::RecordChannelIDSupport( - ServerBoundCertService* server_bound_cert_service, + ChannelIDService* channel_id_service, bool negotiated_channel_id, bool channel_id_enabled, bool supports_ecc) { @@ -169,40 +164,62 @@ void SSLClientSocket::RecordChannelIDSupport( CLIENT_AND_SERVER = 2, CLIENT_NO_ECC = 3, CLIENT_BAD_SYSTEM_TIME = 4, - CLIENT_NO_SERVER_BOUND_CERT_SERVICE = 5, - DOMAIN_BOUND_CERT_USAGE_MAX + CLIENT_NO_CHANNEL_ID_SERVICE = 5, + CHANNEL_ID_USAGE_MAX } supported = DISABLED; if (negotiated_channel_id) { supported = CLIENT_AND_SERVER; } else if (channel_id_enabled) { - if (!server_bound_cert_service) - supported = CLIENT_NO_SERVER_BOUND_CERT_SERVICE; + if (!channel_id_service) + supported = CLIENT_NO_CHANNEL_ID_SERVICE; else if (!supports_ecc) supported = CLIENT_NO_ECC; - else if (!server_bound_cert_service->IsSystemTimeValid()) + else if (!channel_id_service->IsSystemTimeValid()) supported = CLIENT_BAD_SYSTEM_TIME; else supported = CLIENT_ONLY; } UMA_HISTOGRAM_ENUMERATION("DomainBoundCerts.Support", supported, - DOMAIN_BOUND_CERT_USAGE_MAX); + CHANNEL_ID_USAGE_MAX); +} + +// static +void SSLClientSocket::RecordConnectionTypeMetrics(int ssl_version) { + UpdateConnectionTypeHistograms(CONNECTION_SSL); + switch (ssl_version) { + case SSL_CONNECTION_VERSION_SSL2: + UpdateConnectionTypeHistograms(CONNECTION_SSL_SSL2); + break; + case SSL_CONNECTION_VERSION_SSL3: + UpdateConnectionTypeHistograms(CONNECTION_SSL_SSL3); + break; + case SSL_CONNECTION_VERSION_TLS1: + UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1); + break; + case SSL_CONNECTION_VERSION_TLS1_1: + UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1_1); + break; + case SSL_CONNECTION_VERSION_TLS1_2: + UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1_2); + break; + } } // static bool SSLClientSocket::IsChannelIDEnabled( const SSLConfig& ssl_config, - ServerBoundCertService* server_bound_cert_service) { + ChannelIDService* channel_id_service) { if (!ssl_config.channel_id_enabled) return false; - if (!server_bound_cert_service) { - DVLOG(1) << "NULL server_bound_cert_service_, not enabling channel ID."; + if (!channel_id_service) { + DVLOG(1) << "NULL channel_id_service_, not enabling channel ID."; return false; } if (!crypto::ECPrivateKey::IsSupported()) { DVLOG(1) << "Elliptic Curve not supported, not enabling channel ID."; return false; } - if (!server_bound_cert_service->IsSystemTimeValid()) { + if (!channel_id_service->IsSystemTimeValid()) { DVLOG(1) << "System time is not within the supported range for certificate " "generation, not enabling channel ID."; return false; @@ -210,4 +227,66 @@ bool SSLClientSocket::IsChannelIDEnabled( return true; } +// static +std::vector<uint8_t> SSLClientSocket::SerializeNextProtos( + const std::vector<std::string>& next_protos) { + // Do a first pass to determine the total length. + size_t wire_length = 0; + for (std::vector<std::string>::const_iterator i = next_protos.begin(); + i != next_protos.end(); ++i) { + if (i->size() > 255) { + LOG(WARNING) << "Ignoring overlong NPN/ALPN protocol: " << *i; + continue; + } + if (i->size() == 0) { + LOG(WARNING) << "Ignoring empty NPN/ALPN protocol"; + continue; + } + wire_length += i->size(); + wire_length++; + } + + // Allocate memory for the result and fill it in. + std::vector<uint8_t> wire_protos; + wire_protos.reserve(wire_length); + for (std::vector<std::string>::const_iterator i = next_protos.begin(); + i != next_protos.end(); i++) { + if (i->size() == 0 || i->size() > 255) + continue; + wire_protos.push_back(i->size()); + wire_protos.resize(wire_protos.size() + i->size()); + memcpy(&wire_protos[wire_protos.size() - i->size()], + i->data(), i->size()); + } + DCHECK_EQ(wire_protos.size(), wire_length); + + return wire_protos; +} + +void SSLClientSocket::RecordNegotiationExtension() { + if (negotiation_extension_ == kExtensionUnknown) + return; + std::string proto; + SSLClientSocket::NextProtoStatus status = GetNextProto(&proto); + if (status == kNextProtoUnsupported) + return; + // Convert protocol into numerical value for histogram. + NextProto protocol_negotiated = SSLClientSocket::NextProtoFromString(proto); + base::HistogramBase::Sample sample = + static_cast<base::HistogramBase::Sample>(protocol_negotiated); + // In addition to the protocol negotiated, we want to record which TLS + // extension was used, and in case of NPN, whether there was overlap between + // server and client list of supported protocols. + if (negotiation_extension_ == kExtensionNPN) { + if (status == kNextProtoNoOverlap) { + sample += 1000; + } else { + sample += 500; + } + } else { + DCHECK_EQ(kExtensionALPN, negotiation_extension_); + } + UMA_HISTOGRAM_SPARSE_SLOWLY("Net.SSLProtocolNegotiation", sample); +} + } // namespace net diff --git a/chromium/net/socket/ssl_client_socket.h b/chromium/net/socket/ssl_client_socket.h index a43e58cc26b..7adfa8c626a 100644 --- a/chromium/net/socket/ssl_client_socket.h +++ b/chromium/net/socket/ssl_client_socket.h @@ -17,7 +17,9 @@ namespace net { class CertVerifier; +class ChannelIDService; class CTVerifier; +class HostPortPair; class ServerBoundCertService; class SSLCertRequestInfo; struct SSLConfig; @@ -30,23 +32,23 @@ class X509Certificate; struct SSLClientSocketContext { SSLClientSocketContext() : cert_verifier(NULL), - server_bound_cert_service(NULL), + channel_id_service(NULL), transport_security_state(NULL), cert_transparency_verifier(NULL) {} SSLClientSocketContext(CertVerifier* cert_verifier_arg, - ServerBoundCertService* server_bound_cert_service_arg, + ChannelIDService* channel_id_service_arg, TransportSecurityState* transport_security_state_arg, CTVerifier* cert_transparency_verifier_arg, const std::string& ssl_session_cache_shard_arg) : cert_verifier(cert_verifier_arg), - server_bound_cert_service(server_bound_cert_service_arg), + channel_id_service(channel_id_service_arg), transport_security_state(transport_security_state_arg), cert_transparency_verifier(cert_transparency_verifier_arg), ssl_session_cache_shard(ssl_session_cache_shard_arg) {} CertVerifier* cert_verifier; - ServerBoundCertService* server_bound_cert_service; + ChannelIDService* channel_id_service; TransportSecurityState* transport_security_state; CTVerifier* cert_transparency_verifier; // ssl_session_cache_shard is an opaque string that identifies a shard of the @@ -77,9 +79,49 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { // the first protocol in our list. }; + // TLS extension used to negotiate protocol. + enum SSLNegotiationExtension { + kExtensionUnknown, + kExtensionALPN, + kExtensionNPN, + }; + // StreamSocket: - virtual bool WasNpnNegotiated() const OVERRIDE; - virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + bool WasNpnNegotiated() const override; + NextProto GetNegotiatedProtocol() const override; + + // Computes a unique key string for the SSL session cache. + virtual std::string GetSessionCacheKey() const = 0; + + // Returns true if there is a cache entry in the SSL session cache + // for the cache key of the SSL socket. + // + // The cache key consists of a host and port concatenated with a session + // cache shard. These two strings are passed to the constructor of most + // subclasses of SSLClientSocket. + virtual bool InSessionCache() const = 0; + + // Sets |callback| to be run when the handshake has fully completed. + // For example, in the case of False Start, Connect() will return + // early, before the peer's TLS Finished message has been verified, + // in order to allow the caller to call Write() and send application + // data with the client's Finished message. + // In such situations, |callback| will be invoked sometime after + // Connect() - either during a Write() or Read() call, and before + // invoking the Read() or Write() callback. + // Otherwise, during a traditional TLS connection (i.e. no False + // Start), this will be called right before the Connect() callback + // is called. + // + // Note that it's not valid to mutate this socket during such + // callbacks, including deleting the socket. + // + // TODO(mshelley): Provide additional details about whether or not + // the handshake actually succeeded or not. This can be inferred + // from the result to Connect()/Read()/Write(), but may be useful + // to inform here as well. + virtual void SetHandshakeCompletionCallback( + const base::Closure& callback) = 0; // Gets the SSL CertificateRequest info of the socket after Connect failed // with ERR_SSL_CLIENT_AUTH_CERT_NEEDED. @@ -93,9 +135,7 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { // kNextProtoNegotiated: *proto is set to the negotiated protocol. // kNextProtoNoOverlap: *proto is set to the first protocol in the // supported list. - // *server_protos is set to the server advertised protocols. - virtual NextProtoStatus GetNextProto(std::string* proto, - std::string* server_protos) = 0; + virtual NextProtoStatus GetNextProto(std::string* proto) = 0; static NextProto NextProtoFromString(const std::string& proto_string); @@ -103,10 +143,6 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { static const char* NextProtoStatusToString(const NextProtoStatus status); - // Can be used with the second argument(|server_protos|) of |GetNextProto| to - // construct a comma separated string of server advertised protocols. - static std::string ServerProtosToString(const std::string& server_protos); - static bool IgnoreCertError(int error, int load_flags); // ClearSessionCache clears the SSL session cache, used to resume SSL @@ -121,9 +157,11 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { virtual void set_protocol_negotiated(NextProto protocol_negotiated); - // Returns the ServerBoundCertService used by this socket, or NULL if - // server bound certificates are not supported. - virtual ServerBoundCertService* GetServerBoundCertService() const = 0; + void set_negotiation_extension(SSLNegotiationExtension negotiation_extension); + + // Returns the ChannelIDService used by this socket, or NULL if + // channel ids are not supported. + virtual ChannelIDService* GetChannelIDService() const = 0; // Returns true if a channel ID was sent on this connection. // This may be useful for protocols, like SPDY, which allow the same @@ -133,6 +171,10 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { // Public for ssl_client_socket_openssl_unittest.cc. virtual bool WasChannelIDSent() const; + // Record which TLS extension was used to negotiate protocol and protocol + // chosen in a UMA histogram. + void RecordNegotiationExtension(); + protected: virtual void set_channel_id_sent(bool channel_id_sent); @@ -145,15 +187,23 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { // Records histograms for channel id support during full handshakes - resumed // handshakes are ignored. static void RecordChannelIDSupport( - ServerBoundCertService* server_bound_cert_service, + ChannelIDService* channel_id_service, bool negotiated_channel_id, bool channel_id_enabled, bool supports_ecc); + // Records ConnectionType histograms for a successful SSL connection. + static void RecordConnectionTypeMetrics(int ssl_version); + // Returns whether TLS channel ID is enabled. static bool IsChannelIDEnabled( const SSLConfig& ssl_config, - ServerBoundCertService* server_bound_cert_service); + ChannelIDService* channel_id_service); + + // Serializes |next_protos| in the wire format for ALPN: protocols are listed + // in order, each prefixed by a one-byte length. + static std::vector<uint8_t> SerializeNextProtos( + const std::vector<std::string>& next_protos); // For unit testing only. // Returns the unverified certificate chain as presented by server. @@ -185,6 +235,8 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { bool signed_cert_timestamps_received_; // True if a stapled OCSP response was received. bool stapled_ocsp_response_received_; + // Protocol negotiation extension used. + SSLNegotiationExtension negotiation_extension_; }; } // namespace net diff --git a/chromium/net/socket/ssl_client_socket_nss.cc b/chromium/net/socket/ssl_client_socket_nss.cc index 9f40a78371e..08cf2c55f51 100644 --- a/chromium/net/socket/ssl_client_socket_nss.cc +++ b/chromium/net/socket/ssl_client_socket_nss.cc @@ -71,6 +71,7 @@ #include "base/logging.h" #include "base/memory/singleton.h" #include "base/metrics/histogram.h" +#include "base/profiler/scoped_tracker.h" #include "base/single_thread_task_runner.h" #include "base/stl_util.h" #include "base/strings/string_number_conversions.h" @@ -85,7 +86,6 @@ #include "crypto/rsa_private_key.h" #include "crypto/scoped_nss_types.h" #include "net/base/address_list.h" -#include "net/base/connection_type_histograms.h" #include "net/base/dns_util.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" @@ -93,7 +93,7 @@ #include "net/cert/asn1_util.h" #include "net/cert/cert_status_flags.h" #include "net/cert/cert_verifier.h" -#include "net/cert/ct_objects_extractor.h" +#include "net/cert/ct_ev_whitelist.h" #include "net/cert/ct_verifier.h" #include "net/cert/ct_verify_result.h" #include "net/cert/scoped_nss_types.h" @@ -105,7 +105,6 @@ #include "net/ocsp/nss_ocsp.h" #include "net/socket/client_socket_handle.h" #include "net/socket/nss_ssl_util.h" -#include "net/socket/ssl_error_params.h" #include "net/ssl/ssl_cert_request_info.h" #include "net/ssl/ssl_connection_status_flags.h" #include "net/ssl/ssl_info.h" @@ -408,7 +407,7 @@ struct HandshakeState { void Reset() { next_proto_status = SSLClientSocket::kNextProtoUnsupported; next_proto.clear(); - server_protos.clear(); + negotiation_extension_ = SSLClientSocket::kExtensionUnknown; channel_id_sent = false; server_cert_chain.Reset(NULL); server_cert = NULL; @@ -422,8 +421,9 @@ struct HandshakeState { // negotiated protocol stored in |next_proto|. SSLClientSocket::NextProtoStatus next_proto_status; std::string next_proto; - // If the server supports NPN, the protocols supported by the server. - std::string server_protos; + + // TLS extension used for protocol negotiation. + SSLClientSocket::SSLNegotiationExtension negotiation_extension_; // True if a channel ID was sent. bool channel_id_sent; @@ -474,18 +474,6 @@ int MapNSSClientError(PRErrorCode err) { } } -// Map NSS error code from the first SSL handshake to network error code. -int MapNSSClientHandshakeError(PRErrorCode err) { - switch (err) { - // If the server closed on us, it is a protocol error. - // Some TLS-intolerant servers do this when we request TLS. - case PR_END_OF_FILE_ERROR: - return ERR_SSL_PROTOCOL_ERROR; - default: - return MapNSSClientError(err); - } -} - } // namespace // SSLClientSocketNSS::Core provides a thread-safe, ref-counted core that is @@ -524,7 +512,7 @@ int MapNSSClientHandshakeError(PRErrorCode err) { // 2) NSS Task Runner: Prepare data to go from NSS to an IO function: // (BufferRecv, BufferSend) // 3) Network Task Runner: Perform IO on that data (DoBufferRecv, -// DoBufferSend, DoGetDomainBoundCert, OnGetDomainBoundCertComplete) +// DoBufferSend, DoGetChannelID, OnGetChannelIDComplete) // 4) Both Task Runners: Callback for asynchronous completion or to marshal // data from the network task runner back to NSS (BufferRecvComplete, // BufferSendComplete, OnHandshakeIOComplete) @@ -592,7 +580,7 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { // that their lifetimes match that of the owning SSLClientSocketNSS. // // The caller retains ownership of |transport|, |net_log|, and - // |server_bound_cert_service|, and they will not be accessed once Detach() + // |channel_id_service|, and they will not be accessed once Detach() // has been called. Core(base::SequencedTaskRunner* network_task_runner, base::SequencedTaskRunner* nss_task_runner, @@ -600,7 +588,7 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { const HostPortPair& host_and_port, const SSLConfig& ssl_config, BoundNetLog* net_log, - ServerBoundCertService* server_bound_cert_service); + ChannelIDService* channel_id_service); // Called on the network task runner. // Transfers ownership of |socket|, an NSS SSL socket, and |buffers|, the @@ -720,7 +708,7 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { // Handles an NSS error generated while handshaking or performing IO. // Returns a network error code mapped from the original NSS error. - int HandleNSSError(PRErrorCode error, bool handshake_error); + int HandleNSSError(PRErrorCode error); int DoHandshakeLoop(int last_io_result); int DoReadLoop(int result); @@ -754,7 +742,7 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { // key into a SECKEYPublicKey and SECKEYPrivateKey. Returns OK upon success // and an error code otherwise. // Requires |domain_bound_private_key_| and |domain_bound_cert_| to have been - // set by a call to ServerBoundCertService->GetDomainBoundCert. The caller + // set by a call to ChannelIDService->GetChannelID. The caller // takes ownership of the |*cert| and |*key|. int ImportChannelIDKeys(SECKEYPublicKey** public_key, SECKEYPrivateKey** key); @@ -775,15 +763,17 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { // UpdateNextProto gets any application-layer protocol that may have been // negotiated by the TLS connection. void UpdateNextProto(); + // Record TLS extension used for protocol negotiation (NPN or ALPN). + void UpdateExtensionUsed(); //////////////////////////////////////////////////////////////////////////// // Methods that are ONLY called on the network task runner: //////////////////////////////////////////////////////////////////////////// int DoBufferRecv(IOBuffer* buffer, int len); int DoBufferSend(IOBuffer* buffer, int len); - int DoGetDomainBoundCert(const std::string& host); + int DoGetChannelID(const std::string& host); - void OnGetDomainBoundCertComplete(int result); + void OnGetChannelIDComplete(int result); void OnHandshakeStateUpdated(const HandshakeState& state); void OnNSSBufferUpdated(int amount_in_read_buffer); void DidNSSRead(int result); @@ -832,8 +822,8 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { HandshakeState network_handshake_state_; // The service for retrieving Channel ID keys. May be NULL. - ServerBoundCertService* server_bound_cert_service_; - ServerBoundCertService::RequestHandle domain_bound_cert_request_handle_; + ChannelIDService* channel_id_service_; + ChannelIDService::RequestHandle domain_bound_cert_request_handle_; // The information about NSS task runner. int unhandled_buffer_size_; @@ -914,7 +904,7 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { // for the network task runner from the NSS task runner. base::WeakPtr<BoundNetLog> weak_net_log_; - // Written on the network task runner by the |server_bound_cert_service_|, + // Written on the network task runner by the |channel_id_service_|, // prior to invoking OnHandshakeIOComplete. // Read on the NSS task runner when once OnHandshakeIOComplete is invoked // on the NSS task runner. @@ -931,11 +921,11 @@ SSLClientSocketNSS::Core::Core( const HostPortPair& host_and_port, const SSLConfig& ssl_config, BoundNetLog* net_log, - ServerBoundCertService* server_bound_cert_service) + ChannelIDService* channel_id_service) : detached_(false), transport_(transport), weak_net_log_factory_(net_log), - server_bound_cert_service_(server_bound_cert_service), + channel_id_service_(channel_id_service), unhandled_buffer_size_(0), nss_waiting_read_(false), nss_waiting_write_(false), @@ -969,6 +959,7 @@ SSLClientSocketNSS::Core::~Core() { PR_Close(nss_fd_); nss_fd_ = NULL; } + nss_bufs_ = NULL; } bool SSLClientSocketNSS::Core::Init(PRFileDesc* socket, @@ -983,30 +974,11 @@ bool SSLClientSocketNSS::Core::Init(PRFileDesc* socket, SECStatus rv = SECSuccess; if (!ssl_config_.next_protos.empty()) { - size_t wire_length = 0; - for (std::vector<std::string>::const_iterator - i = ssl_config_.next_protos.begin(); - i != ssl_config_.next_protos.end(); ++i) { - if (i->size() > 255) { - LOG(WARNING) << "Ignoring overlong NPN/ALPN protocol: " << *i; - continue; - } - wire_length += i->size(); - wire_length++; - } - scoped_ptr<uint8[]> wire_protos(new uint8[wire_length]); - uint8* dst = wire_protos.get(); - for (std::vector<std::string>::const_iterator - i = ssl_config_.next_protos.begin(); - i != ssl_config_.next_protos.end(); i++) { - if (i->size() > 255) - continue; - *dst++ = i->size(); - memcpy(dst, i->data(), i->size()); - dst += i->size(); - } - DCHECK_EQ(dst, wire_protos.get() + wire_length); - rv = SSL_SetNextProtoNego(nss_fd_, wire_protos.get(), wire_length); + std::vector<uint8_t> wire_protos = + SerializeNextProtos(ssl_config_.next_protos); + rv = SSL_SetNextProtoNego( + nss_fd_, wire_protos.empty() ? NULL : &wire_protos[0], + wire_protos.size()); if (rv != SECSuccess) LogFailedNSSFunction(*weak_net_log_, "SSL_SetNextProtoNego", ""); rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_ALPN, PR_TRUE); @@ -1037,7 +1009,7 @@ bool SSLClientSocketNSS::Core::Init(PRFileDesc* socket, return false; } - if (IsChannelIDEnabled(ssl_config_, server_bound_cert_service_)) { + if (IsChannelIDEnabled(ssl_config_, channel_id_service_)) { rv = SSL_SetClientChannelIDCallback( nss_fd_, SSLClientSocketNSS::Core::ClientChannelIDHandler, this); if (rv != SECSuccess) { @@ -1658,6 +1630,11 @@ void SSLClientSocketNSS::Core::HandshakeCallback( } void SSLClientSocketNSS::Core::HandshakeSucceeded() { + // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "424386 SSLClientSocketNSS::Core::HandshakeSucceeded")); + DCHECK(OnNSSTaskRunner()); PRBool last_handshake_resumed; @@ -1674,6 +1651,7 @@ void SSLClientSocketNSS::Core::HandshakeSucceeded() { UpdateStapledOCSPResponse(); UpdateConnectionStatus(); UpdateNextProto(); + UpdateExtensionUsed(); // Update the network task runners view of the handshake state whenever // a handshake has completed. @@ -1682,12 +1660,15 @@ void SSLClientSocketNSS::Core::HandshakeSucceeded() { nss_handshake_state_)); } -int SSLClientSocketNSS::Core::HandleNSSError(PRErrorCode nss_error, - bool handshake_error) { +int SSLClientSocketNSS::Core::HandleNSSError(PRErrorCode nss_error) { + // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "424386 SSLClientSocketNSS::Core::HandleNSSError")); + DCHECK(OnNSSTaskRunner()); - int net_error = handshake_error ? MapNSSClientHandshakeError(nss_error) : - MapNSSClientError(nss_error); + int net_error = MapNSSClientError(nss_error); #if defined(OS_WIN) // On Windows, a handle to the HCRYPTPROV is cached in the X509Certificate @@ -1717,6 +1698,11 @@ int SSLClientSocketNSS::Core::HandleNSSError(PRErrorCode nss_error, } int SSLClientSocketNSS::Core::DoHandshakeLoop(int last_io_result) { + // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "424386 SSLClientSocketNSS::Core::DoHandshakeLoop")); + DCHECK(OnNSSTaskRunner()); int rv = last_io_result; @@ -1753,6 +1739,11 @@ int SSLClientSocketNSS::Core::DoHandshakeLoop(int last_io_result) { } int SSLClientSocketNSS::Core::DoReadLoop(int result) { + // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "424386 SSLClientSocketNSS::Core::DoReadLoop")); + DCHECK(OnNSSTaskRunner()); DCHECK(false_started_ || handshake_callback_called_); DCHECK_EQ(STATE_NONE, next_handshake_state_); @@ -1812,11 +1803,21 @@ int SSLClientSocketNSS::Core::DoWriteLoop(int result) { } int SSLClientSocketNSS::Core::DoHandshake() { + // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "424386 SSLClientSocketNSS::Core::DoHandshake")); + DCHECK(OnNSSTaskRunner()); - int net_error = net::OK; + int net_error = OK; SECStatus rv = SSL_ForceHandshake(nss_fd_); + // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. + tracked_objects::ScopedTracker tracking_profile1( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "424386 SSLClientSocketNSS::Core::DoHandshake 1")); + // Note: this function may be called multiple times during the handshake, so // even though channel id and client auth are separate else cases, they can // both be used during a single SSL handshake. @@ -1845,24 +1846,7 @@ int SSLClientSocketNSS::Core::DoHandshake() { } } else { PRErrorCode prerr = PR_GetError(); - net_error = HandleNSSError(prerr, true); - - // Some network devices that inspect application-layer packets seem to - // inject TCP reset packets to break the connections when they see - // TLS 1.1 in ClientHello or ServerHello. See http://crbug.com/130293. - // - // Only allow ERR_CONNECTION_RESET to trigger a fallback from TLS 1.1 or - // 1.2. We don't lose much in this fallback because the explicit IV for CBC - // mode in TLS 1.1 is approximated by record splitting in TLS 1.0. The - // fallback will be more painful for TLS 1.2 when we have GCM support. - // - // ERR_CONNECTION_RESET is a common network error, so we don't want it - // to trigger a version fallback in general, especially the TLS 1.0 -> - // SSL 3.0 fallback, which would drop TLS extensions. - if (prerr == PR_CONNECT_RESET_ERROR && - ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1_1) { - net_error = ERR_SSL_PROTOCOL_ERROR; - } + net_error = HandleNSSError(prerr); // If not done, stay in this state if (net_error == ERR_IO_PENDING) { @@ -1880,6 +1864,11 @@ int SSLClientSocketNSS::Core::DoHandshake() { } int SSLClientSocketNSS::Core::DoGetDBCertComplete(int result) { + // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "424386 SSLClientSocketNSS::Core::DoGetDBCertComplete")); + SECStatus rv; PostOrRunCallback( FROM_HERE, @@ -1989,7 +1978,7 @@ int SSLClientSocketNSS::Core::DoPayloadRead() { // If *next_result == 0, then that indicates EOF, and no special error // handling is needed. pending_read_nss_error_ = PR_GetError(); - *next_result = HandleNSSError(pending_read_nss_error_, false); + *next_result = HandleNSSError(pending_read_nss_error_); if (rv > 0 && *next_result == ERR_IO_PENDING) { // If at least some data was read from PR_Read(), do not treat // insufficient data as an error to return in the next call to @@ -2051,7 +2040,7 @@ int SSLClientSocketNSS::Core::DoPayloadWrite() { if (prerr == PR_WOULD_BLOCK_ERROR) return ERR_IO_PENDING; - rv = HandleNSSError(prerr, false); + rv = HandleNSSError(prerr); PostOrRunCallback( FROM_HERE, base::Bind(&AddLogEventWithCallback, weak_net_log_, @@ -2064,6 +2053,11 @@ int SSLClientSocketNSS::Core::DoPayloadWrite() { // transport socket. Return true if some I/O performed, false // otherwise (error or ERR_IO_PENDING). bool SSLClientSocketNSS::Core::DoTransportIO() { + // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "424386 SSLClientSocketNSS::Core::DoTransportIO")); + DCHECK(OnNSSTaskRunner()); bool network_moved = false; @@ -2240,6 +2234,11 @@ void SSLClientSocketNSS::Core::OnSendComplete(int result) { // callback. For Read() and Write(), that's what we want. But for Connect(), // the caller expects OK (i.e. 0) for success. void SSLClientSocketNSS::Core::DoConnectCallback(int rv) { + // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "424386 SSLClientSocketNSS::Core::DoConnectCallback")); + DCHECK(OnNSSTaskRunner()); DCHECK_NE(rv, ERR_IO_PENDING); DCHECK(!user_connect_callback_.is_null()); @@ -2251,6 +2250,11 @@ void SSLClientSocketNSS::Core::DoConnectCallback(int rv) { } void SSLClientSocketNSS::Core::DoReadCallback(int rv) { + // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "424386 SSLClientSocketNSS::Core::DoReadCallback")); + DCHECK(OnNSSTaskRunner()); DCHECK_NE(ERR_IO_PENDING, rv); DCHECK(!user_read_callback_.is_null()); @@ -2266,6 +2270,10 @@ void SSLClientSocketNSS::Core::DoReadCallback(int rv) { PostOrRunCallback( FROM_HERE, base::Bind(&Core::DidNSSRead, this, rv)); + // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. + tracked_objects::ScopedTracker tracking_profile1( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "SSLClientSocketNSS::Core::DoReadCallback")); PostOrRunCallback( FROM_HERE, base::Bind(base::ResetAndReturn(&user_read_callback_), rv)); @@ -2314,12 +2322,12 @@ SECStatus SSLClientSocketNSS::Core::ClientChannelIDHandler( std::string host = core->host_and_port_.host(); int error = ERR_UNEXPECTED; if (core->OnNetworkTaskRunner()) { - error = core->DoGetDomainBoundCert(host); + error = core->DoGetChannelID(host); } else { bool posted = core->network_task_runner_->PostTask( FROM_HERE, base::Bind( - IgnoreResult(&Core::DoGetDomainBoundCert), + IgnoreResult(&Core::DoGetChannelID), core, host)); error = posted ? ERR_IO_PENDING : ERR_ABORTED; } @@ -2367,7 +2375,7 @@ int SSLClientSocketNSS::Core::ImportChannelIDKeys(SECKEYPublicKey** public_key, // Set the private key. if (!crypto::ECPrivateKey::ImportFromEncryptedPrivateKeyInfo( slot.get(), - ServerBoundCertService::kEPKIPassword, + ChannelIDService::kEPKIPassword, reinterpret_cast<const unsigned char*>( domain_bound_private_key_.data()), domain_bound_private_key_.size(), @@ -2459,6 +2467,10 @@ void SSLClientSocketNSS::Core::UpdateStapledOCSPResponse() { } void SSLClientSocketNSS::Core::UpdateConnectionStatus() { + // Note: This function may be called multiple times for a single connection + // if renegotiations occur. + nss_handshake_state_.ssl_connection_status = 0; + SSLChannelInfo channel_info; SECStatus ok = SSL_GetChannelInfo(nss_fd_, &channel_info, sizeof(channel_info)); @@ -2475,8 +2487,6 @@ void SSLClientSocketNSS::Core::UpdateConnectionStatus() { SSL_CONNECTION_COMPRESSION_MASK) << SSL_CONNECTION_COMPRESSION_SHIFT; - // NSS 3.14.x doesn't have a version macro for TLS 1.2 (because NSS didn't - // support it yet), so use 0x0303 directly. int version = SSL_CONNECTION_VERSION_UNKNOWN; if (channel_info.protocolVersion < SSL_LIBRARY_VERSION_3_0) { // All versions less than SSL_LIBRARY_VERSION_3_0 are treated as SSL @@ -2484,11 +2494,11 @@ void SSLClientSocketNSS::Core::UpdateConnectionStatus() { version = SSL_CONNECTION_VERSION_SSL2; } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_3_0) { version = SSL_CONNECTION_VERSION_SSL3; - } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_3_1_TLS) { + } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_TLS_1_0) { version = SSL_CONNECTION_VERSION_TLS1; } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_TLS_1_1) { version = SSL_CONNECTION_VERSION_TLS1_1; - } else if (channel_info.protocolVersion == 0x0303) { + } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_TLS_1_2) { version = SSL_CONNECTION_VERSION_TLS1_2; } nss_handshake_state_.ssl_connection_status |= @@ -2508,26 +2518,6 @@ void SSLClientSocketNSS::Core::UpdateConnectionStatus() { VLOG(1) << "The server " << host_and_port_.ToString() << " does not support the TLS renegotiation_info extension."; } - UMA_HISTOGRAM_ENUMERATION("Net.RenegotiationExtensionSupported", - peer_supports_renego_ext, 2); - - // We would like to eliminate fallback to SSLv3 for non-buggy servers - // because of security concerns. For example, Google offers forward - // secrecy with ECDHE but that requires TLS 1.0. An attacker can block - // TLSv1 connections and force us to downgrade to SSLv3 and remove forward - // secrecy. - // - // Yngve from Opera has suggested using the renegotiation extension as an - // indicator that SSLv3 fallback was mistaken: - // tools.ietf.org/html/draft-pettersen-tls-version-rollback-removal-00 . - // - // As a first step, measure how often clients perform version fallback - // while the server advertises support secure renegotiation. - if (ssl_config_.version_fallback && - channel_info.protocolVersion == SSL_LIBRARY_VERSION_3_0) { - UMA_HISTOGRAM_BOOLEAN("Net.SSLv3FallbackToRenegoPatchedServer", - peer_supports_renego_ext == PR_TRUE); - } } if (ssl_config_.version_fallback) { @@ -2564,6 +2554,23 @@ void SSLClientSocketNSS::Core::UpdateNextProto() { } } +void SSLClientSocketNSS::Core::UpdateExtensionUsed() { + PRBool negotiated_extension; + SECStatus rv = SSL_HandshakeNegotiatedExtension(nss_fd_, + ssl_app_layer_protocol_xtn, + &negotiated_extension); + if (rv == SECSuccess && negotiated_extension) { + nss_handshake_state_.negotiation_extension_ = kExtensionALPN; + } else { + rv = SSL_HandshakeNegotiatedExtension(nss_fd_, + ssl_next_proto_nego_xtn, + &negotiated_extension); + if (rv == SECSuccess && negotiated_extension) { + nss_handshake_state_.negotiation_extension_ = kExtensionNPN; + } + } +} + void SSLClientSocketNSS::Core::RecordChannelIDSupportOnNSSTaskRunner() { DCHECK(OnNSSTaskRunner()); if (nss_handshake_state_.resumed_handshake) @@ -2587,7 +2594,7 @@ void SSLClientSocketNSS::Core::RecordChannelIDSupportOnNetworkTaskRunner( bool supports_ecc) const { DCHECK(OnNetworkTaskRunner()); - RecordChannelIDSupport(server_bound_cert_service_, + RecordChannelIDSupport(channel_id_service_, negotiated_channel_id, channel_id_enabled, supports_ecc); @@ -2637,19 +2644,19 @@ int SSLClientSocketNSS::Core::DoBufferSend(IOBuffer* send_buffer, int len) { return rv; } -int SSLClientSocketNSS::Core::DoGetDomainBoundCert(const std::string& host) { +int SSLClientSocketNSS::Core::DoGetChannelID(const std::string& host) { DCHECK(OnNetworkTaskRunner()); if (detached_) - return ERR_FAILED; + return ERR_ABORTED; weak_net_log_->BeginEvent(NetLog::TYPE_SSL_GET_DOMAIN_BOUND_CERT); - int rv = server_bound_cert_service_->GetOrCreateDomainBoundCert( + int rv = channel_id_service_->GetOrCreateChannelID( host, &domain_bound_private_key_, &domain_bound_cert_, - base::Bind(&Core::OnGetDomainBoundCertComplete, base::Unretained(this)), + base::Bind(&Core::OnGetChannelIDComplete, base::Unretained(this)), &domain_bound_cert_request_handle_); if (rv != ERR_IO_PENDING && !OnNSSTaskRunner()) { @@ -2696,6 +2703,11 @@ void SSLClientSocketNSS::Core::DidNSSWrite(int result) { } void SSLClientSocketNSS::Core::BufferSendComplete(int result) { + // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "418183 DidCompleteReadWrite => Core::BufferSendComplete")); + if (!OnNSSTaskRunner()) { if (detached_) return; @@ -2729,7 +2741,7 @@ void SSLClientSocketNSS::Core::OnHandshakeIOComplete(int result) { DoConnectCallback(rv); } -void SSLClientSocketNSS::Core::OnGetDomainBoundCertComplete(int result) { +void SSLClientSocketNSS::Core::OnGetChannelIDComplete(int result) { DVLOG(1) << __FUNCTION__ << " " << result; DCHECK(OnNetworkTaskRunner()); @@ -2739,6 +2751,11 @@ void SSLClientSocketNSS::Core::OnGetDomainBoundCertComplete(int result) { void SSLClientSocketNSS::Core::BufferRecvComplete( IOBuffer* read_buffer, int result) { + // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "418183 DidCompleteReadWrite => SSLClientSocketNSS::Core::...")); + DCHECK(read_buffer); if (!OnNSSTaskRunner()) { @@ -2814,7 +2831,7 @@ SSLClientSocketNSS::SSLClientSocketNSS( ssl_config_(ssl_config), cert_verifier_(context.cert_verifier), cert_transparency_verifier_(context.cert_transparency_verifier), - server_bound_cert_service_(context.server_bound_cert_service), + channel_id_service_(context.channel_id_service), ssl_session_cache_shard_(context.ssl_session_cache_shard), completed_handshake_(false), next_handshake_state_(STATE_NONE), @@ -2886,6 +2903,21 @@ bool SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { return true; } +std::string SSLClientSocketNSS::GetSessionCacheKey() const { + NOTIMPLEMENTED(); + return std::string(); +} + +bool SSLClientSocketNSS::InSessionCache() const { + // For now, always return true so that SSLConnectJobs are never held back. + return true; +} + +void SSLClientSocketNSS::SetHandshakeCompletionCallback( + const base::Closure& callback) { + NOTIMPLEMENTED(); +} + void SSLClientSocketNSS::GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) { EnterFunction(""); @@ -2932,10 +2964,8 @@ int SSLClientSocketNSS::GetTLSUniqueChannelBinding(std::string* out) { } SSLClientSocket::NextProtoStatus -SSLClientSocketNSS::GetNextProto(std::string* proto, - std::string* server_protos) { +SSLClientSocketNSS::GetNextProto(std::string* proto) { *proto = core_->state().next_proto; - *server_protos = core_->state().server_protos; return core_->state().next_proto_status; } @@ -3125,7 +3155,7 @@ void SSLClientSocketNSS::InitCore() { host_and_port_, ssl_config_, &net_log_, - server_bound_cert_service_); + channel_id_service_); } int SSLClientSocketNSS::InitializeSSLOptions() { @@ -3374,6 +3404,11 @@ int SSLClientSocketNSS::DoHandshakeComplete(int result) { EnterFunction(result); if (result == OK) { + if (ssl_config_.version_fallback && + ssl_config_.version_max < ssl_config_.version_fallback_min) { + return ERR_SSL_FALLBACK_BEYOND_MINIMUM_VERSION; + } + // SSL handshake is completed. Let's verify the certificate. GotoState(STATE_VERIFY_CERT); // Done! @@ -3383,6 +3418,7 @@ int SSLClientSocketNSS::DoHandshakeComplete(int result) { !core_->state().sct_list_from_tls_extension.empty()); set_stapled_ocsp_response_received( !core_->state().stapled_ocsp_response.empty()); + set_negotiation_extension(core_->state().negotiation_extension_); LeaveFunction(result); return result; @@ -3468,56 +3504,38 @@ int SSLClientSocketNSS::DoVerifyCertComplete(int result) { // TODO(hclam): Skip logging if server cert was expected to be bad because // |server_cert_verify_result_| doesn't contain all the information about // the cert. - if (result == OK) - LogConnectionTypeMetrics(); - -#if defined(OFFICIAL_BUILD) && !defined(OS_ANDROID) && !defined(OS_IOS) - // Take care of any mandates for public key pinning. - // - // Pinning is only enabled for official builds to make sure that others don't - // end up with pins that cannot be easily updated. - // - // TODO(agl): We might have an issue here where a request for foo.example.com - // merges into a SPDY connection to www.example.com, and gets a different - // certificate. + if (result == OK) { + int ssl_version = + SSLConnectionStatusToVersion(core_->state().ssl_connection_status); + RecordConnectionTypeMetrics(ssl_version); + } - // Perform pin validation if, and only if, all these conditions obtain: - // - // * a TransportSecurityState object is available; - // * the server's certificate chain is valid (or suffers from only a minor - // error); - // * the server's certificate chain chains up to a known root (i.e. not a - // user-installed trust anchor); and - // * the build is recent (very old builds should fail open so that users - // have some chance to recover). - // const CertStatus cert_status = server_cert_verify_result_.cert_status; if (transport_security_state_ && (result == OK || (IsCertificateError(result) && IsCertStatusMinorError(cert_status))) && - server_cert_verify_result_.is_issued_by_known_root && - TransportSecurityState::IsBuildTimely()) { - bool sni_available = - ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1 || - ssl_config_.version_fallback; - const std::string& host = host_and_port_.host(); - - if (transport_security_state_->HasPublicKeyPins(host, sni_available)) { - if (!transport_security_state_->CheckPublicKeyPins( - host, - sni_available, - server_cert_verify_result_.public_key_hashes, - &pinning_failure_log_)) { - LOG(ERROR) << pinning_failure_log_; - result = ERR_SSL_PINNED_KEY_NOT_IN_CERT_CHAIN; - UMA_HISTOGRAM_BOOLEAN("Net.PublicKeyPinSuccess", false); - TransportSecurityState::ReportUMAOnPinFailure(host); - } else { - UMA_HISTOGRAM_BOOLEAN("Net.PublicKeyPinSuccess", true); - } + !transport_security_state_->CheckPublicKeyPins( + host_and_port_.host(), + server_cert_verify_result_.is_issued_by_known_root, + server_cert_verify_result_.public_key_hashes, + &pinning_failure_log_)) { + result = ERR_SSL_PINNED_KEY_NOT_IN_CERT_CHAIN; + } + + scoped_refptr<ct::EVCertsWhitelist> ev_whitelist = + SSLConfigService::GetEVCertsWhitelist(); + if (server_cert_verify_result_.cert_status & CERT_STATUS_IS_EV) { + if (ev_whitelist.get() && ev_whitelist->IsValid()) { + const SHA256HashValue fingerprint( + X509Certificate::CalculateFingerprint256( + server_cert_verify_result_.verified_cert->os_cert_handle())); + + UMA_HISTOGRAM_BOOLEAN( + "Net.SSL_EVCertificateInWhitelist", + ev_whitelist->ContainsCertificateHash( + std::string(reinterpret_cast<const char*>(fingerprint.data), 8))); } } -#endif if (result == OK) { // Only check Certificate Transparency if there were no other errors with @@ -3543,7 +3561,7 @@ void SSLClientSocketNSS::VerifyCT() { // gets all the data it needs for SCT verification and does not do any // external communication. int result = cert_transparency_verifier_->Verify( - server_cert_verify_result_.verified_cert, + server_cert_verify_result_.verified_cert.get(), core_->state().stapled_ocsp_response, core_->state().sct_list_from_tls_extension, &ct_verify_result_, @@ -3558,29 +3576,6 @@ void SSLClientSocketNSS::VerifyCT() { << ct_verify_result_.unknown_logs_scts.size(); } -void SSLClientSocketNSS::LogConnectionTypeMetrics() const { - UpdateConnectionTypeHistograms(CONNECTION_SSL); - int ssl_version = SSLConnectionStatusToVersion( - core_->state().ssl_connection_status); - switch (ssl_version) { - case SSL_CONNECTION_VERSION_SSL2: - UpdateConnectionTypeHistograms(CONNECTION_SSL_SSL2); - break; - case SSL_CONNECTION_VERSION_SSL3: - UpdateConnectionTypeHistograms(CONNECTION_SSL_SSL3); - break; - case SSL_CONNECTION_VERSION_TLS1: - UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1); - break; - case SSL_CONNECTION_VERSION_TLS1_1: - UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1_1); - break; - case SSL_CONNECTION_VERSION_TLS1_2: - UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1_2); - break; - }; -} - void SSLClientSocketNSS::EnsureThreadIdAssigned() const { base::AutoLock auto_lock(lock_); if (valid_thread_id_ != base::kInvalidThreadId) @@ -3621,8 +3616,8 @@ SSLClientSocketNSS::GetUnverifiedServerCertificateChain() const { return core_->state().server_cert.get(); } -ServerBoundCertService* SSLClientSocketNSS::GetServerBoundCertService() const { - return server_bound_cert_service_; +ChannelIDService* SSLClientSocketNSS::GetChannelIDService() const { + return channel_id_service_; } } // namespace net diff --git a/chromium/net/socket/ssl_client_socket_nss.h b/chromium/net/socket/ssl_client_socket_nss.h index e8cce574b64..71f09c0b82b 100644 --- a/chromium/net/socket/ssl_client_socket_nss.h +++ b/chromium/net/socket/ssl_client_socket_nss.h @@ -17,7 +17,6 @@ #include "base/synchronization/lock.h" #include "base/threading/platform_thread.h" #include "base/time/time.h" -#include "base/timer/timer.h" #include "net/base/completion_callback.h" #include "net/base/host_port_pair.h" #include "net/base/net_export.h" @@ -27,7 +26,7 @@ #include "net/cert/ct_verify_result.h" #include "net/cert/x509_certificate.h" #include "net/socket/ssl_client_socket.h" -#include "net/ssl/server_bound_cert_service.h" +#include "net/ssl/channel_id_service.h" #include "net/ssl/ssl_config_service.h" namespace base { @@ -38,9 +37,9 @@ namespace net { class BoundNetLog; class CertVerifier; +class ChannelIDService; class CTVerifier; class ClientSocketHandle; -class ServerBoundCertService; class SingleRequestCertVerifier; class TransportSecurityState; class X509Certificate; @@ -65,51 +64,52 @@ class SSLClientSocketNSS : public SSLClientSocket { const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context); - virtual ~SSLClientSocketNSS(); + ~SSLClientSocketNSS() override; // SSLClientSocket implementation. - virtual void GetSSLCertRequestInfo( - SSLCertRequestInfo* cert_request_info) OVERRIDE; - virtual NextProtoStatus GetNextProto(std::string* proto, - std::string* server_protos) OVERRIDE; + std::string GetSessionCacheKey() const override; + bool InSessionCache() const override; + void SetHandshakeCompletionCallback(const base::Closure& callback) override; + void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override; + NextProtoStatus GetNextProto(std::string* proto) override; // SSLSocket implementation. - virtual int ExportKeyingMaterial(const base::StringPiece& label, - bool has_context, - const base::StringPiece& context, - unsigned char* out, - unsigned int outlen) OVERRIDE; - virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE; + int ExportKeyingMaterial(const base::StringPiece& label, + bool has_context, + const base::StringPiece& context, + unsigned char* out, + unsigned int outlen) override; + int GetTLSUniqueChannelBinding(std::string* out) override; // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE; - virtual void Disconnect() OVERRIDE; - virtual bool IsConnected() const OVERRIDE; - virtual bool IsConnectedAndIdle() const OVERRIDE; - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; - virtual const BoundNetLog& NetLog() const OVERRIDE; - virtual void SetSubresourceSpeculation() OVERRIDE; - virtual void SetOmniboxSpeculation() OVERRIDE; - virtual bool WasEverUsed() const OVERRIDE; - virtual bool UsingTCPFastOpen() const OVERRIDE; - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + int Connect(const CompletionCallback& callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + const BoundNetLog& NetLog() const override; + void SetSubresourceSpeculation() override; + void SetOmniboxSpeculation() override; + bool WasEverUsed() const override; + bool UsingTCPFastOpen() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; // Socket implementation. - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int SetReceiveBufferSize(int32 size) OVERRIDE; - virtual int SetSendBufferSize(int32 size) OVERRIDE; - virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int SetReceiveBufferSize(int32 size) override; + int SetSendBufferSize(int32 size) override; + ChannelIDService* GetChannelIDService() const override; protected: // SSLClientSocket implementation. - virtual scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain() - const OVERRIDE; + scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain() + const override; private: // Helper class to handle marshalling any NSS interaction to and from the @@ -144,8 +144,6 @@ class SSLClientSocketNSS : public SSLClientSocket { void VerifyCT(); - void LogConnectionTypeMetrics() const; - // The following methods are for debugging bug 65948. Will remove this code // after fixing bug 65948. void EnsureThreadIdAssigned() const; @@ -178,7 +176,7 @@ class SSLClientSocketNSS : public SSLClientSocket { CTVerifier* cert_transparency_verifier_; // The service for retrieving Channel ID keys. May be NULL. - ServerBoundCertService* server_bound_cert_service_; + ChannelIDService* channel_id_service_; // ssl_session_cache_shard_ is an opaque string that partitions the SSL // session cache. i.e. sessions created with one value will not attempt to @@ -202,7 +200,7 @@ class SSLClientSocketNSS : public SSLClientSocket { TransportSecurityState* transport_security_state_; // pinning_failure_log contains a message produced by - // TransportSecurityState::DomainState::CheckPublicKeyPins in the event of a + // TransportSecurityState::CheckPublicKeyPins in the event of a // pinning failure. It is a (somewhat) human-readable string. std::string pinning_failure_log_; diff --git a/chromium/net/socket/ssl_client_socket_openssl.cc b/chromium/net/socket/ssl_client_socket_openssl.cc index 4ff8d438e96..9fdfe38ccdd 100644 --- a/chromium/net/socket/ssl_client_socket_openssl.cc +++ b/chromium/net/socket/ssl_client_socket_openssl.cc @@ -7,29 +7,44 @@ #include "net/socket/ssl_client_socket_openssl.h" +#include <errno.h> +#include <openssl/bio.h> #include <openssl/err.h> -#include <openssl/opensslv.h> #include <openssl/ssl.h> #include "base/bind.h" #include "base/callback_helpers.h" +#include "base/environment.h" #include "base/memory/singleton.h" #include "base/metrics/histogram.h" +#include "base/strings/string_piece.h" #include "base/synchronization/lock.h" #include "crypto/ec_private_key.h" #include "crypto/openssl_util.h" +#include "crypto/scoped_openssl_types.h" #include "net/base/net_errors.h" #include "net/cert/cert_verifier.h" +#include "net/cert/ct_ev_whitelist.h" +#include "net/cert/ct_verifier.h" #include "net/cert/single_request_cert_verifier.h" #include "net/cert/x509_certificate_net_log_param.h" -#include "net/socket/openssl_ssl_util.h" -#include "net/socket/ssl_error_params.h" +#include "net/cert/x509_util_openssl.h" +#include "net/http/transport_security_state.h" #include "net/socket/ssl_session_cache_openssl.h" -#include "net/ssl/openssl_client_key_store.h" #include "net/ssl/ssl_cert_request_info.h" #include "net/ssl/ssl_connection_status_flags.h" #include "net/ssl/ssl_info.h" +#if defined(OS_WIN) +#include "base/win/windows_version.h" +#endif + +#if defined(USE_OPENSSL_CERTS) +#include "net/ssl/openssl_client_key_store.h" +#else +#include "net/ssl/openssl_platform_key.h" +#endif + namespace net { namespace { @@ -51,6 +66,14 @@ const int kNoPendingReadResult = 1; // the server supports NPN, choosing "http/1.1" is the best answer. const char kDefaultSupportedNPNProtocol[] = "http/1.1"; +void FreeX509Stack(STACK_OF(X509)* ptr) { + sk_X509_pop_free(ptr, X509_free); +} + +typedef crypto::ScopedOpenSSL<X509, X509_free>::Type ScopedX509; +typedef crypto::ScopedOpenSSL<STACK_OF(X509), FreeX509Stack>::Type + ScopedX509Stack; + #if OPENSSL_VERSION_NUMBER < 0x1000103fL // This method doesn't seem to have made it into the OpenSSL headers. unsigned long SSL_CIPHER_get_id(const SSL_CIPHER* cipher) { return cipher->id; } @@ -78,22 +101,43 @@ int GetNetSSLVersion(SSL* ssl) { return SSL_CONNECTION_VERSION_SSL3; case TLS1_VERSION: return SSL_CONNECTION_VERSION_TLS1; - case 0x0302: + case TLS1_1_VERSION: return SSL_CONNECTION_VERSION_TLS1_1; - case 0x0303: + case TLS1_2_VERSION: return SSL_CONNECTION_VERSION_TLS1_2; default: return SSL_CONNECTION_VERSION_UNKNOWN; } } -// Compute a unique key string for the SSL session cache. |socket| is an -// input socket object. Return a string. -std::string GetSocketSessionCacheKey(const SSLClientSocketOpenSSL& socket) { - std::string result = socket.host_and_port().ToString(); - result.append("/"); - result.append(socket.ssl_session_cache_shard()); - return result; +ScopedX509 OSCertHandleToOpenSSL( + X509Certificate::OSCertHandle os_handle) { +#if defined(USE_OPENSSL_CERTS) + return ScopedX509(X509Certificate::DupOSCertHandle(os_handle)); +#else // !defined(USE_OPENSSL_CERTS) + std::string der_encoded; + if (!X509Certificate::GetDEREncoded(os_handle, &der_encoded)) + return ScopedX509(); + const uint8_t* bytes = reinterpret_cast<const uint8_t*>(der_encoded.data()); + return ScopedX509(d2i_X509(NULL, &bytes, der_encoded.size())); +#endif // defined(USE_OPENSSL_CERTS) +} + +ScopedX509Stack OSCertHandlesToOpenSSL( + const X509Certificate::OSCertHandles& os_handles) { + ScopedX509Stack stack(sk_X509_new_null()); + for (size_t i = 0; i < os_handles.size(); i++) { + ScopedX509 x509 = OSCertHandleToOpenSSL(os_handles[i]); + if (!x509) + return ScopedX509Stack(); + sk_X509_push(stack.get(), x509.release()); + } + return stack.Pass(); +} + +int LogErrorCallback(const char* str, size_t len, void* context) { + LOG(ERROR) << base::StringPiece(str, len); + return 1; } } // namespace @@ -126,34 +170,42 @@ class SSLClientSocketOpenSSL::SSLContext { ssl_ctx_.reset(SSL_CTX_new(SSLv23_client_method())); session_cache_.Reset(ssl_ctx_.get(), kDefaultSessionCacheConfig); SSL_CTX_set_cert_verify_callback(ssl_ctx_.get(), CertVerifyCallback, NULL); - SSL_CTX_set_client_cert_cb(ssl_ctx_.get(), ClientCertCallback); - SSL_CTX_set_channel_id_cb(ssl_ctx_.get(), ChannelIDCallback); + SSL_CTX_set_cert_cb(ssl_ctx_.get(), ClientCertRequestCallback, NULL); SSL_CTX_set_verify(ssl_ctx_.get(), SSL_VERIFY_PEER, NULL); // TODO(kristianm): Only select this if ssl_config_.next_proto is not empty. // It would be better if the callback were not a global setting, // but that is an OpenSSL issue. SSL_CTX_set_next_proto_select_cb(ssl_ctx_.get(), SelectNextProtoCallback, NULL); + ssl_ctx_->tlsext_channel_id_enabled_new = 1; + + scoped_ptr<base::Environment> env(base::Environment::Create()); + std::string ssl_keylog_file; + if (env->GetVar("SSLKEYLOGFILE", &ssl_keylog_file) && + !ssl_keylog_file.empty()) { + crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); + BIO* bio = BIO_new_file(ssl_keylog_file.c_str(), "a"); + if (!bio) { + LOG(ERROR) << "Failed to open " << ssl_keylog_file; + ERR_print_errors_cb(&LogErrorCallback, NULL); + } else { + SSL_CTX_set_keylog_bio(ssl_ctx_.get(), bio); + } + } } static std::string GetSessionCacheKey(const SSL* ssl) { SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl); DCHECK(socket); - return GetSocketSessionCacheKey(*socket); + return socket->GetSessionCacheKey(); } static SSLSessionCacheOpenSSL::Config kDefaultSessionCacheConfig; - static int ClientCertCallback(SSL* ssl, X509** x509, EVP_PKEY** pkey) { - SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl); - CHECK(socket); - return socket->ClientCertRequestCallback(ssl, x509, pkey); - } - - static void ChannelIDCallback(SSL* ssl, EVP_PKEY** pkey) { + static int ClientCertRequestCallback(SSL* ssl, void* arg) { SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl); - CHECK(socket); - socket->ChannelIDRequestCallback(ssl, pkey); + DCHECK(socket); + return socket->ClientCertRequestCallback(ssl); } static int CertVerifyCallback(X509_STORE_CTX *store_ctx, void *arg) { @@ -177,7 +229,7 @@ class SSLClientSocketOpenSSL::SSLContext { // SSLClientSocketOpenSSL object from an SSL instance. int ssl_socket_data_index_; - crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free> ssl_ctx_; + crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free>::Type ssl_ctx_; // |session_cache_| must be destroyed before |ssl_ctx_|. SSLSessionCacheOpenSSL session_cache_; }; @@ -200,7 +252,7 @@ class SSLClientSocketOpenSSL::PeerCertificateChain { void Reset(STACK_OF(X509)* chain); // Note that when USE_OPENSSL is defined, OSCertHandle is X509* - const scoped_refptr<X509Certificate>& AsOSChain() const { return os_chain_; } + scoped_refptr<X509Certificate> AsOSChain() const; size_t size() const { if (!openssl_chain_.get()) @@ -208,23 +260,17 @@ class SSLClientSocketOpenSSL::PeerCertificateChain { return sk_X509_num(openssl_chain_.get()); } - X509* operator[](size_t index) const { + bool empty() const { + return size() == 0; + } + + X509* Get(size_t index) const { DCHECK_LT(index, size()); return sk_X509_value(openssl_chain_.get(), index); } - bool IsValid() { return os_chain_.get() && openssl_chain_.get(); } - private: - static void FreeX509Stack(STACK_OF(X509)* cert_chain) { - sk_X509_pop_free(cert_chain, X509_free); - } - - friend class crypto::ScopedOpenSSL<STACK_OF(X509), FreeX509Stack>; - - crypto::ScopedOpenSSL<STACK_OF(X509), FreeX509Stack> openssl_chain_; - - scoped_refptr<X509Certificate> os_chain_; + ScopedX509Stack openssl_chain_; }; SSLClientSocketOpenSSL::PeerCertificateChain& @@ -233,85 +279,41 @@ SSLClientSocketOpenSSL::PeerCertificateChain::operator=( if (this == &other) return *this; - // os_chain_ is reference counted by scoped_refptr; - os_chain_ = other.os_chain_; - - // Must increase the reference count manually for sk_X509_dup - openssl_chain_.reset(sk_X509_dup(other.openssl_chain_.get())); - for (int i = 0; i < sk_X509_num(openssl_chain_.get()); ++i) { - X509* x = sk_X509_value(openssl_chain_.get(), i); - CRYPTO_add(&x->references, 1, CRYPTO_LOCK_X509); - } + openssl_chain_.reset(X509_chain_up_ref(other.openssl_chain_.get())); return *this; } -#if defined(USE_OPENSSL_CERTS) -// When OSCertHandle is typedef'ed to X509, this implementation does a short cut -// to avoid converting back and forth between der and X509 struct. void SSLClientSocketOpenSSL::PeerCertificateChain::Reset( STACK_OF(X509)* chain) { - openssl_chain_.reset(NULL); - os_chain_ = NULL; - - if (!chain) - return; + openssl_chain_.reset(chain ? X509_chain_up_ref(chain) : NULL); +} +scoped_refptr<X509Certificate> +SSLClientSocketOpenSSL::PeerCertificateChain::AsOSChain() const { +#if defined(USE_OPENSSL_CERTS) + // When OSCertHandle is typedef'ed to X509, this implementation does a short + // cut to avoid converting back and forth between DER and the X509 struct. X509Certificate::OSCertHandles intermediates; - for (int i = 1; i < sk_X509_num(chain); ++i) - intermediates.push_back(sk_X509_value(chain, i)); - - os_chain_ = - X509Certificate::CreateFromHandle(sk_X509_value(chain, 0), intermediates); - - // sk_X509_dup does not increase reference count on the certs in the stack. - openssl_chain_.reset(sk_X509_dup(chain)); - - std::vector<base::StringPiece> der_chain; - for (int i = 0; i < sk_X509_num(openssl_chain_.get()); ++i) { - X509* x = sk_X509_value(openssl_chain_.get(), i); - // Increase the reference count for the certs in openssl_chain_. - CRYPTO_add(&x->references, 1, CRYPTO_LOCK_X509); + for (size_t i = 1; i < sk_X509_num(openssl_chain_.get()); ++i) { + intermediates.push_back(sk_X509_value(openssl_chain_.get(), i)); } -} -#else // !defined(USE_OPENSSL_CERTS) -void SSLClientSocketOpenSSL::PeerCertificateChain::Reset( - STACK_OF(X509)* chain) { - openssl_chain_.reset(NULL); - os_chain_ = NULL; - - if (!chain) - return; - - // sk_X509_dup does not increase reference count on the certs in the stack. - openssl_chain_.reset(sk_X509_dup(chain)); + return make_scoped_refptr(X509Certificate::CreateFromHandle( + sk_X509_value(openssl_chain_.get(), 0), intermediates)); +#else + // DER-encode the chain and convert to a platform certificate handle. std::vector<base::StringPiece> der_chain; - for (int i = 0; i < sk_X509_num(openssl_chain_.get()); ++i) { + for (size_t i = 0; i < sk_X509_num(openssl_chain_.get()); ++i) { X509* x = sk_X509_value(openssl_chain_.get(), i); - - // Increase the reference count for the certs in openssl_chain_. - CRYPTO_add(&x->references, 1, CRYPTO_LOCK_X509); - - unsigned char* cert_data = NULL; - int cert_data_length = i2d_X509(x, &cert_data); - if (cert_data_length && cert_data) - der_chain.push_back(base::StringPiece(reinterpret_cast<char*>(cert_data), - cert_data_length)); + base::StringPiece der; + if (!x509_util::GetDER(x, &der)) + return NULL; + der_chain.push_back(der); } - os_chain_ = X509Certificate::CreateFromDERCertChain(der_chain); - - for (size_t i = 0; i < der_chain.size(); ++i) { - OPENSSL_free(const_cast<char*>(der_chain[i].data())); - } - - if (der_chain.size() != - static_cast<size_t>(sk_X509_num(openssl_chain_.get()))) { - openssl_chain_.reset(NULL); - os_chain_ = NULL; - } + return make_scoped_refptr(X509Certificate::CreateFromDERCertChain(der_chain)); +#endif } -#endif // defined(USE_OPENSSL_CERTS) // static SSLSessionCacheOpenSSL::Config @@ -327,9 +329,6 @@ void SSLClientSocket::ClearSessionCache() { SSLClientSocketOpenSSL::SSLContext* context = SSLClientSocketOpenSSL::SSLContext::GetInstance(); context->session_cache()->Flush(); -#if defined(USE_OPENSSL_CERTS) - OpenSSLClientKeyStore::GetInstance()->Flush(); -#endif } SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( @@ -339,16 +338,17 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( const SSLClientSocketContext& context) : transport_send_busy_(false), transport_recv_busy_(false), - transport_recv_eof_(false), - weak_factory_(this), pending_read_error_(kNoPendingReadResult), + pending_read_ssl_error_(SSL_ERROR_NONE), + transport_read_error_(OK), transport_write_error_(OK), server_cert_chain_(new PeerCertificateChain(NULL)), - completed_handshake_(false), + completed_connect_(false), was_ever_used_(false), client_auth_cert_needed_(false), cert_verifier_(context.cert_verifier), - server_bound_cert_service_(context.server_bound_cert_service), + cert_transparency_verifier_(context.cert_transparency_verifier), + channel_id_service_(context.channel_id_service), ssl_(NULL), transport_bio_(NULL), transport_(transport_socket.Pass()), @@ -358,14 +358,36 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( trying_cached_session_(false), next_handshake_state_(STATE_NONE), npn_status_(kNextProtoUnsupported), - channel_id_request_return_value_(ERR_UNEXPECTED), channel_id_xtn_negotiated_(false), - net_log_(transport_->socket()->NetLog()) {} + handshake_succeeded_(false), + marked_session_as_good_(false), + transport_security_state_(context.transport_security_state), + net_log_(transport_->socket()->NetLog()), + weak_factory_(this) { +} SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() { Disconnect(); } +std::string SSLClientSocketOpenSSL::GetSessionCacheKey() const { + std::string result = host_and_port_.ToString(); + result.append("/"); + result.append(ssl_session_cache_shard_); + return result; +} + +bool SSLClientSocketOpenSSL::InSessionCache() const { + SSLContext* context = SSLContext::GetInstance(); + std::string cache_key = GetSessionCacheKey(); + return context->session_cache()->SSLSessionIsInCache(cache_key); +} + +void SSLClientSocketOpenSSL::SetHandshakeCompletionCallback( + const base::Closure& callback) { + handshake_completion_callback_ = callback; +} + void SSLClientSocketOpenSSL::GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) { cert_request_info->host_and_port = host_and_port_; @@ -374,15 +396,14 @@ void SSLClientSocketOpenSSL::GetSSLCertRequestInfo( } SSLClientSocket::NextProtoStatus SSLClientSocketOpenSSL::GetNextProto( - std::string* proto, std::string* server_protos) { + std::string* proto) { *proto = npn_proto_; - *server_protos = server_protos_; return npn_status_; } -ServerBoundCertService* -SSLClientSocketOpenSSL::GetServerBoundCertService() const { - return server_bound_cert_service_; +ChannelIDService* +SSLClientSocketOpenSSL::GetChannelIDService() const { + return channel_id_service_; } int SSLClientSocketOpenSSL::ExportKeyingMaterial( @@ -412,6 +433,10 @@ int SSLClientSocketOpenSSL::GetTLSUniqueChannelBinding(std::string* out) { } int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) { + // It is an error to create an SSLClientSocket whose context has no + // TransportSecurityState. + DCHECK(transport_security_state_); + net_log_.BeginEvent(NetLog::TYPE_SSL_CONNECT); // Set up new ssl object. @@ -430,12 +455,18 @@ int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) { user_connect_callback_ = callback; } else { net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); + if (rv < OK) + OnHandshakeCompletion(); } return rv > OK ? OK : rv; } void SSLClientSocketOpenSSL::Disconnect() { + // If a handshake was pending (Connect() had been called), notify interested + // parties that it's been aborted now. If the handshake had already + // completed, this is a no-op. + OnHandshakeCompletion(); if (ssl_) { // Calling SSL_shutdown prevents the session from being marked as // unresumable. @@ -456,7 +487,6 @@ void SSLClientSocketOpenSSL::Disconnect() { transport_send_busy_ = false; send_buffer_ = NULL; transport_recv_busy_ = false; - transport_recv_eof_ = false; recv_buffer_ = NULL; user_connect_callback_.Reset(); @@ -468,19 +498,31 @@ void SSLClientSocketOpenSSL::Disconnect() { user_write_buf_len_ = 0; pending_read_error_ = kNoPendingReadResult; + pending_read_ssl_error_ = SSL_ERROR_NONE; + pending_read_error_info_ = OpenSSLErrorInfo(); + + transport_read_error_ = OK; transport_write_error_ = OK; server_cert_verify_result_.Reset(); - completed_handshake_ = false; + completed_connect_ = false; cert_authorities_.clear(); cert_key_types_.clear(); client_auth_cert_needed_ = false; + + start_cert_verification_time_ = base::TimeTicks(); + + npn_status_ = kNextProtoUnsupported; + npn_proto_.clear(); + + channel_id_xtn_negotiated_ = false; + channel_id_request_handle_.Cancel(); } bool SSLClientSocketOpenSSL::IsConnected() const { // If the handshake has not yet completed. - if (!completed_handshake_) + if (!completed_connect_) return false; // If an asynchronous operation is still pending. if (user_read_buf_.get() || user_write_buf_.get()) @@ -491,15 +533,15 @@ bool SSLClientSocketOpenSSL::IsConnected() const { bool SSLClientSocketOpenSSL::IsConnectedAndIdle() const { // If the handshake has not yet completed. - if (!completed_handshake_) + if (!completed_connect_) return false; // If an asynchronous operation is still pending. if (user_read_buf_.get() || user_write_buf_.get()) return false; // If there is data waiting to be sent, or data read from the network that // has not yet been consumed. - if (BIO_ctrl_pending(transport_bio_) > 0 || - BIO_ctrl_wpending(transport_bio_) > 0) { + if (BIO_pending(transport_bio_) > 0 || + BIO_wpending(transport_bio_) > 0) { return false; } @@ -548,7 +590,7 @@ bool SSLClientSocketOpenSSL::UsingTCPFastOpen() const { bool SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { ssl_info->Reset(); - if (!server_cert_.get()) + if (server_cert_chain_->empty()) return false; ssl_info->cert = server_cert_verify_result_.verified_cert; @@ -560,27 +602,20 @@ bool SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { ssl_info->client_cert_sent = ssl_config_.send_client_cert && ssl_config_.client_cert.get(); ssl_info->channel_id_sent = WasChannelIDSent(); + ssl_info->pinning_failure_log = pinning_failure_log_; - RecordChannelIDSupport(server_bound_cert_service_, - channel_id_xtn_negotiated_, - ssl_config_.channel_id_enabled, - crypto::ECPrivateKey::IsSupported()); + AddSCTInfoToSSLInfo(ssl_info); const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl_); CHECK(cipher); ssl_info->security_bits = SSL_CIPHER_get_bits(cipher, NULL); - const COMP_METHOD* compression = SSL_get_current_compression(ssl_); ssl_info->connection_status = EncodeSSLConnectionStatus( - SSL_CIPHER_get_id(cipher), - compression ? compression->type : 0, + SSL_CIPHER_get_id(cipher), 0 /* no compression */, GetNetSSLVersion(ssl_)); - bool peer_supports_renego_ext = !!SSL_get_secure_renegotiation_support(ssl_); - if (!peer_supports_renego_ext) + if (!SSL_get_secure_renegotiation_support(ssl_)) ssl_info->connection_status |= SSL_CONNECTION_NO_RENEGOTIATION_EXTENSION; - UMA_HISTOGRAM_ENUMERATION("Net.RenegotiationExtensionSupported", - implicit_cast<int>(peer_supports_renego_ext), 2); if (ssl_config_.version_fallback) ssl_info->connection_status |= SSL_CONNECTION_VERSION_FALLBACK; @@ -601,7 +636,7 @@ int SSLClientSocketOpenSSL::Read(IOBuffer* buf, user_read_buf_ = buf; user_read_buf_len_ = buf_len; - int rv = DoReadLoop(OK); + int rv = DoReadLoop(); if (rv == ERR_IO_PENDING) { user_read_callback_ = callback; @@ -610,6 +645,11 @@ int SSLClientSocketOpenSSL::Read(IOBuffer* buf, was_ever_used_ = true; user_read_buf_ = NULL; user_read_buf_len_ = 0; + if (rv <= 0) { + // Failure of a read attempt may indicate a failed false start + // connection. + OnHandshakeCompletion(); + } } return rv; @@ -621,7 +661,7 @@ int SSLClientSocketOpenSSL::Write(IOBuffer* buf, user_write_buf_ = buf; user_write_buf_len_ = buf_len; - int rv = DoWriteLoop(OK); + int rv = DoWriteLoop(); if (rv == ERR_IO_PENDING) { user_write_callback_ = callback; @@ -630,6 +670,11 @@ int SSLClientSocketOpenSSL::Write(IOBuffer* buf, was_ever_used_ = true; user_write_buf_ = NULL; user_write_buf_len_ = 0; + if (rv < 0) { + // Failure of a write attempt may indicate a failed false start + // connection. + OnHandshakeCompletion(); + } } return rv; @@ -657,8 +702,11 @@ int SSLClientSocketOpenSSL::Init() { if (!SSL_set_tlsext_host_name(ssl_, host_and_port_.host().c_str())) return ERR_UNEXPECTED; + // Set an OpenSSL callback to monitor this SSL*'s connection. + SSL_set_info_callback(ssl_, &InfoCallback); + trying_cached_session_ = context->session_cache()->SetSSLSessionWithKey( - ssl_, GetSocketSessionCacheKey(*this)); + ssl_, GetSessionCacheKey()); BIO* ssl_bio = NULL; // 0 => use default buffer sizes. @@ -667,6 +715,10 @@ int SSLClientSocketOpenSSL::Init() { DCHECK(ssl_bio); DCHECK(transport_bio_); + // Install a callback on OpenSSL's end to plumb transport errors through. + BIO_set_callback(ssl_bio, BIOCallback); + BIO_set_callback_arg(ssl_bio, reinterpret_cast<char*>(this)); + SSL_set_bio(ssl_, ssl_bio, ssl_bio); // OpenSSL defaults some options to on, others to off. To avoid ambiguity, @@ -699,6 +751,7 @@ int SSLClientSocketOpenSSL::Init() { SslSetClearMask mode; mode.ConfigureFlag(SSL_MODE_RELEASE_BUFFERS, true); + mode.ConfigureFlag(SSL_MODE_CBC_RECORD_SPLITTING, true); mode.ConfigureFlag(SSL_MODE_HANDSHAKE_CUTTHROUGH, ssl_config_.false_start_enabled); @@ -719,7 +772,7 @@ int SSLClientSocketOpenSSL::Init() { "!aECDH:!AESGCM+AES256"); // Walk through all the installed ciphers, seeing if any need to be // appended to the cipher removal |command|. - for (int i = 0; i < sk_SSL_CIPHER_num(ciphers); ++i) { + for (size_t i = 0; i < sk_SSL_CIPHER_num(ciphers); ++i) { const SSL_CIPHER* cipher = sk_SSL_CIPHER_value(ciphers, i); const uint16 id = SSL_CIPHER_get_id(cipher); // Remove any ciphers with a strength of less than 80 bits. Note the NSS @@ -740,6 +793,15 @@ int SSLClientSocketOpenSSL::Init() { command.append(name); } } + + // Disable ECDSA cipher suites on platforms that do not support ECDSA + // signed certificates, as servers may use the presence of such + // ciphersuites as a hint to send an ECDSA certificate. +#if defined(OS_WIN) + if (base::win::GetVersion() < base::win::VERSION_VISTA) + command.append(":!ECDSA"); +#endif + int rv = SSL_set_cipher_list(ssl_, command.c_str()); // If this fails (rv = 0) it means there are no ciphers enabled on this SSL. // This will almost certainly result in the socket failing to complete the @@ -747,11 +809,29 @@ int SSLClientSocketOpenSSL::Init() { LOG_IF(WARNING, rv != 1) << "SSL_set_cipher_list('" << command << "') " "returned " << rv; + if (ssl_config_.version_fallback) + SSL_enable_fallback_scsv(ssl_); + // TLS channel ids. - if (IsChannelIDEnabled(ssl_config_, server_bound_cert_service_)) { + if (IsChannelIDEnabled(ssl_config_, channel_id_service_)) { SSL_enable_tls_channel_id(ssl_); } + if (!ssl_config_.next_protos.empty()) { + std::vector<uint8_t> wire_protos = + SerializeNextProtos(ssl_config_.next_protos); + SSL_set_alpn_protos(ssl_, wire_protos.empty() ? NULL : &wire_protos[0], + wire_protos.size()); + } + + if (ssl_config_.signed_cert_timestamps_enabled) { + SSL_enable_signed_cert_timestamps(ssl_); + SSL_enable_ocsp_stapling(ssl_); + } + + // TODO(davidben): Enable OCSP stapling on platforms which support it and pass + // into the certificate verifier. https://crbug.com/398677 + return OK; } @@ -762,6 +842,11 @@ void SSLClientSocketOpenSSL::DoReadCallback(int rv) { was_ever_used_ = true; user_read_buf_ = NULL; user_read_buf_len_ = 0; + if (rv <= 0) { + // Failure of a read attempt may indicate a failed false start + // connection. + OnHandshakeCompletion(); + } base::ResetAndReturn(&user_read_callback_).Run(rv); } @@ -772,9 +857,19 @@ void SSLClientSocketOpenSSL::DoWriteCallback(int rv) { was_ever_used_ = true; user_write_buf_ = NULL; user_write_buf_len_ = 0; + if (rv < 0) { + // Failure of a write attempt may indicate a failed false start + // connection. + OnHandshakeCompletion(); + } base::ResetAndReturn(&user_write_callback_).Run(rv); } +void SSLClientSocketOpenSSL::OnHandshakeCompletion() { + if (!handshake_completion_callback_.is_null()) + base::ResetAndReturn(&handshake_completion_callback_).Run(); +} + bool SSLClientSocketOpenSSL::DoTransportIO() { bool network_moved = false; int rv; @@ -785,7 +880,7 @@ bool SSLClientSocketOpenSSL::DoTransportIO() { if (rv != ERR_IO_PENDING && rv != 0) network_moved = true; } while (rv > 0); - if (!transport_recv_eof_ && BufferRecv() != ERR_IO_PENDING) + if (transport_read_error_ == OK && BufferRecv() != ERR_IO_PENDING) network_moved = true; return network_moved; } @@ -815,25 +910,56 @@ int SSLClientSocketOpenSSL::DoHandshake() { DVLOG(2) << "Result of session reuse for " << host_and_port_.ToString() << " is: " << (SSL_session_reused(ssl_) ? "Success" : "Fail"); } - // SSL handshake is completed. Let's verify the certificate. - const bool got_cert = !!UpdateServerCert(); - DCHECK(got_cert); - net_log_.AddEvent( - NetLog::TYPE_SSL_CERTIFICATES_RECEIVED, - base::Bind(&NetLogX509CertificateCallback, - base::Unretained(server_cert_.get()))); + + if (ssl_config_.version_fallback && + ssl_config_.version_max < ssl_config_.version_fallback_min) { + return ERR_SSL_FALLBACK_BEYOND_MINIMUM_VERSION; + } + + // SSL handshake is completed. If NPN wasn't negotiated, see if ALPN was. + if (npn_status_ == kNextProtoUnsupported) { + const uint8_t* alpn_proto = NULL; + unsigned alpn_len = 0; + SSL_get0_alpn_selected(ssl_, &alpn_proto, &alpn_len); + if (alpn_len > 0) { + npn_proto_.assign(reinterpret_cast<const char*>(alpn_proto), alpn_len); + npn_status_ = kNextProtoNegotiated; + set_negotiation_extension(kExtensionALPN); + } + } + + RecordChannelIDSupport(channel_id_service_, + channel_id_xtn_negotiated_, + ssl_config_.channel_id_enabled, + crypto::ECPrivateKey::IsSupported()); + + uint8_t* ocsp_response; + size_t ocsp_response_len; + SSL_get0_ocsp_response(ssl_, &ocsp_response, &ocsp_response_len); + set_stapled_ocsp_response_received(ocsp_response_len != 0); + + uint8_t* sct_list; + size_t sct_list_len; + SSL_get0_signed_cert_timestamp_list(ssl_, &sct_list, &sct_list_len); + set_signed_cert_timestamps_received(sct_list_len != 0); + + // Verify the certificate. + UpdateServerCert(); GotoState(STATE_VERIFY_CERT); } else { int ssl_error = SSL_get_error(ssl_, rv); if (ssl_error == SSL_ERROR_WANT_CHANNEL_ID_LOOKUP) { - // The server supports TLS channel id and the lookup is asynchronous. - // Retrieve the error from the call to |server_bound_cert_service_|. - net_error = channel_id_request_return_value_; - } else { - net_error = MapOpenSSLError(ssl_error, err_tracer); + // The server supports channel ID. Stop to look one up before returning to + // the handshake. + channel_id_xtn_negotiated_ = true; + GotoState(STATE_CHANNEL_ID_LOOKUP); + return OK; } + OpenSSLErrorInfo error_info; + net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer, &error_info); + // If not done, stay in this state if (net_error == ERR_IO_PENDING) { GotoState(STATE_HANDSHAKE); @@ -843,18 +969,78 @@ int SSLClientSocketOpenSSL::DoHandshake() { << ", net_error " << net_error; net_log_.AddEvent( NetLog::TYPE_SSL_HANDSHAKE_ERROR, - CreateNetLogSSLErrorCallback(net_error, ssl_error)); + CreateNetLogOpenSSLErrorCallback(net_error, ssl_error, error_info)); } } return net_error; } +int SSLClientSocketOpenSSL::DoChannelIDLookup() { + GotoState(STATE_CHANNEL_ID_LOOKUP_COMPLETE); + return channel_id_service_->GetOrCreateChannelID( + host_and_port_.host(), + &channel_id_private_key_, + &channel_id_cert_, + base::Bind(&SSLClientSocketOpenSSL::OnHandshakeIOComplete, + base::Unretained(this)), + &channel_id_request_handle_); +} + +int SSLClientSocketOpenSSL::DoChannelIDLookupComplete(int result) { + if (result < 0) + return result; + + DCHECK_LT(0u, channel_id_private_key_.size()); + // Decode key. + std::vector<uint8> encrypted_private_key_info; + std::vector<uint8> subject_public_key_info; + encrypted_private_key_info.assign( + channel_id_private_key_.data(), + channel_id_private_key_.data() + channel_id_private_key_.size()); + subject_public_key_info.assign( + channel_id_cert_.data(), + channel_id_cert_.data() + channel_id_cert_.size()); + scoped_ptr<crypto::ECPrivateKey> ec_private_key( + crypto::ECPrivateKey::CreateFromEncryptedPrivateKeyInfo( + ChannelIDService::kEPKIPassword, + encrypted_private_key_info, + subject_public_key_info)); + if (!ec_private_key) { + LOG(ERROR) << "Failed to import Channel ID."; + return ERR_CHANNEL_ID_IMPORT_FAILED; + } + + // Hand the key to OpenSSL. Check for error in case OpenSSL rejects the key + // type. + crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); + int rv = SSL_set1_tls_channel_id(ssl_, ec_private_key->key()); + if (!rv) { + LOG(ERROR) << "Failed to set Channel ID."; + int err = SSL_get_error(ssl_, rv); + return MapOpenSSLError(err, err_tracer); + } + + // Return to the handshake. + set_channel_id_sent(true); + GotoState(STATE_HANDSHAKE); + return OK; +} + int SSLClientSocketOpenSSL::DoVerifyCert(int result) { - DCHECK(server_cert_.get()); + DCHECK(!server_cert_chain_->empty()); + DCHECK(start_cert_verification_time_.is_null()); + GotoState(STATE_VERIFY_CERT_COMPLETE); + // If the certificate is bad and has been previously accepted, use + // the previous status and bypass the error. + base::StringPiece der_cert; + if (!x509_util::GetDER(server_cert_chain_->Get(0), &der_cert)) { + NOTREACHED(); + return ERR_CERT_INVALID; + } CertStatus cert_status; - if (ssl_config_.IsAllowedBadCert(server_cert_.get(), &cert_status)) { + if (ssl_config_.IsAllowedBadCert(der_cert, &cert_status)) { VLOG(1) << "Received an expected bad cert with status: " << cert_status; server_cert_verify_result_.Reset(); server_cert_verify_result_.cert_status = cert_status; @@ -862,6 +1048,17 @@ int SSLClientSocketOpenSSL::DoVerifyCert(int result) { return OK; } + // When running in a sandbox, it may not be possible to create an + // X509Certificate*, as that may depend on OS functionality blocked + // in the sandbox. + if (!server_cert_.get()) { + server_cert_verify_result_.Reset(); + server_cert_verify_result_.cert_status = CERT_STATUS_INVALID; + return ERR_CERT_INVALID; + } + + start_cert_verification_time_ = base::TimeTicks::Now(); + int flags = 0; if (ssl_config_.rev_checking_enabled) flags |= CertVerifier::VERIFY_REV_CHECKING_ENABLED; @@ -876,7 +1073,9 @@ int SSLClientSocketOpenSSL::DoVerifyCert(int result) { server_cert_.get(), host_and_port_.host(), flags, - NULL /* no CRL set */, + // TODO(davidben): Route the CRLSet through SSLConfig so + // SSLClientSocket doesn't depend on SSLConfigService. + SSLConfigService::GetCRLSet().get(), &server_cert_verify_result_, base::Bind(&SSLClientSocketOpenSSL::OnHandshakeIOComplete, base::Unretained(this)), @@ -886,22 +1085,71 @@ int SSLClientSocketOpenSSL::DoVerifyCert(int result) { int SSLClientSocketOpenSSL::DoVerifyCertComplete(int result) { verifier_.reset(); + if (!start_cert_verification_time_.is_null()) { + base::TimeDelta verify_time = + base::TimeTicks::Now() - start_cert_verification_time_; + if (result == OK) { + UMA_HISTOGRAM_TIMES("Net.SSLCertVerificationTime", verify_time); + } else { + UMA_HISTOGRAM_TIMES("Net.SSLCertVerificationTimeError", verify_time); + } + } + + if (result == OK) + RecordConnectionTypeMetrics(GetNetSSLVersion(ssl_)); + + const CertStatus cert_status = server_cert_verify_result_.cert_status; + if (transport_security_state_ && + (result == OK || + (IsCertificateError(result) && IsCertStatusMinorError(cert_status))) && + !transport_security_state_->CheckPublicKeyPins( + host_and_port_.host(), + server_cert_verify_result_.is_issued_by_known_root, + server_cert_verify_result_.public_key_hashes, + &pinning_failure_log_)) { + result = ERR_SSL_PINNED_KEY_NOT_IN_CERT_CHAIN; + } + + scoped_refptr<ct::EVCertsWhitelist> ev_whitelist = + SSLConfigService::GetEVCertsWhitelist(); + if (server_cert_verify_result_.cert_status & CERT_STATUS_IS_EV) { + if (ev_whitelist.get() && ev_whitelist->IsValid()) { + const SHA256HashValue fingerprint( + X509Certificate::CalculateFingerprint256( + server_cert_verify_result_.verified_cert->os_cert_handle())); + + UMA_HISTOGRAM_BOOLEAN( + "Net.SSL_EVCertificateInWhitelist", + ev_whitelist->ContainsCertificateHash( + std::string(reinterpret_cast<const char*>(fingerprint.data), 8))); + } + } + if (result == OK) { + // Only check Certificate Transparency if there were no other errors with + // the connection. + VerifyCT(); + // TODO(joth): Work out if we need to remember the intermediate CA certs // when the server sends them to us, and do so here. SSLContext::GetInstance()->session_cache()->MarkSSLSessionAsGood(ssl_); + marked_session_as_good_ = true; + CheckIfHandshakeFinished(); } else { DVLOG(1) << "DoVerifyCertComplete error " << ErrorToString(result) << " (" << result << ")"; } - completed_handshake_ = true; + completed_connect_ = true; + // Exit DoHandshakeLoop and return the result to the caller to Connect. DCHECK_EQ(STATE_NONE, next_handshake_state_); return result; } void SSLClientSocketOpenSSL::DoConnectCallback(int rv) { + if (rv < OK) + OnHandshakeCompletion(); if (!user_connect_callback_.is_null()) { CompletionCallback c = user_connect_callback_; user_connect_callback_.Reset(); @@ -909,14 +1157,50 @@ void SSLClientSocketOpenSSL::DoConnectCallback(int rv) { } } -X509Certificate* SSLClientSocketOpenSSL::UpdateServerCert() { +void SSLClientSocketOpenSSL::UpdateServerCert() { server_cert_chain_->Reset(SSL_get_peer_cert_chain(ssl_)); server_cert_ = server_cert_chain_->AsOSChain(); - if (!server_cert_chain_->IsValid()) - DVLOG(1) << "UpdateServerCert received invalid certificate chain from peer"; + if (server_cert_.get()) { + net_log_.AddEvent( + NetLog::TYPE_SSL_CERTIFICATES_RECEIVED, + base::Bind(&NetLogX509CertificateCallback, + base::Unretained(server_cert_.get()))); + } +} + +void SSLClientSocketOpenSSL::VerifyCT() { + if (!cert_transparency_verifier_) + return; + + uint8_t* ocsp_response_raw; + size_t ocsp_response_len; + SSL_get0_ocsp_response(ssl_, &ocsp_response_raw, &ocsp_response_len); + std::string ocsp_response; + if (ocsp_response_len > 0) { + ocsp_response.assign(reinterpret_cast<const char*>(ocsp_response_raw), + ocsp_response_len); + } - return server_cert_.get(); + uint8_t* sct_list_raw; + size_t sct_list_len; + SSL_get0_signed_cert_timestamp_list(ssl_, &sct_list_raw, &sct_list_len); + std::string sct_list; + if (sct_list_len > 0) + sct_list.assign(reinterpret_cast<const char*>(sct_list_raw), sct_list_len); + + // Note that this is a completely synchronous operation: The CT Log Verifier + // gets all the data it needs for SCT verification and does not do any + // external communication. + int result = cert_transparency_verifier_->Verify( + server_cert_verify_result_.verified_cert.get(), + ocsp_response, sct_list, &ct_verify_result_, net_log_); + + VLOG(1) << "CT Verification complete: result " << result + << " Invalid scts: " << ct_verify_result_.invalid_scts.size() + << " Verified scts: " << ct_verify_result_.verified_scts.size() + << " scts from unknown logs: " + << ct_verify_result_.unknown_logs_scts.size(); } void SSLClientSocketOpenSSL::OnHandshakeIOComplete(int result) { @@ -974,7 +1258,7 @@ void SSLClientSocketOpenSSL::OnRecvComplete(int result) { if (!user_read_buf_.get()) return; - int rv = DoReadLoop(result); + int rv = DoReadLoop(); if (rv != ERR_IO_PENDING) DoReadCallback(rv); } @@ -993,8 +1277,15 @@ int SSLClientSocketOpenSSL::DoHandshakeLoop(int last_io_result) { case STATE_HANDSHAKE: rv = DoHandshake(); break; + case STATE_CHANNEL_ID_LOOKUP: + DCHECK_EQ(OK, rv); + rv = DoChannelIDLookup(); + break; + case STATE_CHANNEL_ID_LOOKUP_COMPLETE: + rv = DoChannelIDLookupComplete(rv); + break; case STATE_VERIFY_CERT: - DCHECK(rv == OK); + DCHECK_EQ(OK, rv); rv = DoVerifyCert(rv); break; case STATE_VERIFY_CERT_COMPLETE: @@ -1015,13 +1306,11 @@ int SSLClientSocketOpenSSL::DoHandshakeLoop(int last_io_result) { rv = OK; // This causes us to stay in the loop. } } while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE); + return rv; } -int SSLClientSocketOpenSSL::DoReadLoop(int result) { - if (result < 0) - return result; - +int SSLClientSocketOpenSSL::DoReadLoop() { bool network_moved; int rv; do { @@ -1032,10 +1321,7 @@ int SSLClientSocketOpenSSL::DoReadLoop(int result) { return rv; } -int SSLClientSocketOpenSSL::DoWriteLoop(int result) { - if (result < 0) - return result; - +int SSLClientSocketOpenSSL::DoWriteLoop() { bool network_moved; int rv; do { @@ -1056,7 +1342,14 @@ int SSLClientSocketOpenSSL::DoPayloadRead() { if (rv == 0) { net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, rv, user_read_buf_->data()); + } else { + net_log_.AddEvent( + NetLog::TYPE_SSL_READ_ERROR, + CreateNetLogOpenSSLErrorCallback(rv, pending_read_ssl_error_, + pending_read_error_info_)); } + pending_read_ssl_error_ = SSL_ERROR_NONE; + pending_read_error_info_ = OpenSSLErrorInfo(); return rv; } @@ -1091,8 +1384,19 @@ int SSLClientSocketOpenSSL::DoPayloadRead() { if (client_auth_cert_needed_) { *next_result = ERR_SSL_CLIENT_AUTH_CERT_NEEDED; } else if (*next_result < 0) { - int err = SSL_get_error(ssl_, *next_result); - *next_result = MapOpenSSLError(err, err_tracer); + pending_read_ssl_error_ = SSL_get_error(ssl_, *next_result); + *next_result = MapOpenSSLErrorWithDetails(pending_read_ssl_error_, + err_tracer, + &pending_read_error_info_); + + // Many servers do not reliably send a close_notify alert when shutting + // down a connection, and instead terminate the TCP connection. This is + // reported as ERR_CONNECTION_CLOSED. Because of this, map the unclean + // shutdown to a graceful EOF, instead of treating it as an error as it + // should be. + if (*next_result == ERR_CONNECTION_CLOSED) + *next_result = 0; + if (rv > 0 && *next_result == ERR_IO_PENDING) { // If at least some data was read from SSL_read(), do not treat // insufficient data as an error to return in the next call to @@ -1109,6 +1413,13 @@ int SSLClientSocketOpenSSL::DoPayloadRead() { if (rv >= 0) { net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, rv, user_read_buf_->data()); + } else if (rv != ERR_IO_PENDING) { + net_log_.AddEvent( + NetLog::TYPE_SSL_READ_ERROR, + CreateNetLogOpenSSLErrorCallback(rv, pending_read_ssl_error_, + pending_read_error_info_)); + pending_read_ssl_error_ = SSL_ERROR_NONE; + pending_read_error_info_ = OpenSSLErrorInfo(); } return rv; } @@ -1116,15 +1427,23 @@ int SSLClientSocketOpenSSL::DoPayloadRead() { int SSLClientSocketOpenSSL::DoPayloadWrite() { crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); int rv = SSL_write(ssl_, user_write_buf_->data(), user_write_buf_len_); - if (rv >= 0) { net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_SENT, rv, user_write_buf_->data()); return rv; } - int err = SSL_get_error(ssl_, rv); - return MapOpenSSLError(err, err_tracer); + int ssl_error = SSL_get_error(ssl_, rv); + OpenSSLErrorInfo error_info; + int net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer, + &error_info); + + if (net_error != ERR_IO_PENDING) { + net_log_.AddEvent( + NetLog::TYPE_SSL_WRITE_ERROR, + CreateNetLogOpenSSLErrorCallback(net_error, ssl_error, error_info)); + } + return net_error; } int SSLClientSocketOpenSSL::BufferSend(void) { @@ -1133,7 +1452,7 @@ int SSLClientSocketOpenSSL::BufferSend(void) { if (!send_buffer_.get()) { // Get a fresh send buffer out of the send BIO. - size_t max_read = BIO_ctrl_pending(transport_bio_); + size_t max_read = BIO_pending(transport_bio_); if (!max_read) return 0; // Nothing pending in the OpenSSL write BIO. send_buffer_ = new DrainableIOBuffer(new IOBuffer(max_read), max_read); @@ -1209,20 +1528,9 @@ void SSLClientSocketOpenSSL::BufferRecvComplete(int result) { void SSLClientSocketOpenSSL::TransportWriteComplete(int result) { DCHECK(ERR_IO_PENDING != result); if (result < 0) { - // Got a socket write error; close the BIO to indicate this upward. - // - // TODO(davidben): The value of |result| gets lost. Feed the error back into - // the BIO so it gets (re-)detected in OnSendComplete. Perhaps with - // BIO_set_callback. - DVLOG(1) << "TransportWriteComplete error " << result; - (void)BIO_shutdown_wr(SSL_get_wbio(ssl_)); - - // Match the fix for http://crbug.com/249848 in NSS by erroring future reads - // from the socket after a write error. - // - // TODO(davidben): Avoid having read and write ends interact this way. + // Record the error. Save it to be reported in a future read or write on + // transport_bio_'s peer. transport_write_error_ = result; - (void)BIO_shutdown_wr(transport_bio_); send_buffer_ = NULL; } else { DCHECK(send_buffer_.get()); @@ -1235,19 +1543,15 @@ void SSLClientSocketOpenSSL::TransportWriteComplete(int result) { int SSLClientSocketOpenSSL::TransportReadComplete(int result) { DCHECK(ERR_IO_PENDING != result); - if (result <= 0) { + // If an EOF, canonicalize to ERR_CONNECTION_CLOSED here so MapOpenSSLError + // does not report success. + if (result == 0) + result = ERR_CONNECTION_CLOSED; + if (result < 0) { DVLOG(1) << "TransportReadComplete result " << result; - // Received 0 (end of file) or an error. Either way, bubble it up to the - // SSL layer via the BIO. TODO(joth): consider stashing the error code, to - // relay up to the SSL socket client (i.e. via DoReadCallback). - if (result == 0) - transport_recv_eof_ = true; - (void)BIO_shutdown_wr(transport_bio_); - } else if (transport_write_error_ < 0) { - // Mirror transport write errors as read failures; transport_bio_ has been - // shut down by TransportWriteComplete, so the BIO_write will fail, failing - // the CHECK. http://crbug.com/335557. - result = transport_write_error_; + // Received an error. Save it to be reported in a future read on + // transport_bio_'s peer. + transport_read_error_ = result; } else { DCHECK(recv_buffer_.get()); int ret = BIO_write(transport_bio_, recv_buffer_->data(), result); @@ -1259,19 +1563,23 @@ int SSLClientSocketOpenSSL::TransportReadComplete(int result) { return result; } -int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl, - X509** x509, - EVP_PKEY** pkey) { +int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl) { DVLOG(3) << "OpenSSL ClientCertRequestCallback called"; DCHECK(ssl == ssl_); - DCHECK(*x509 == NULL); - DCHECK(*pkey == NULL); + + // Clear any currently configured certificates. + SSL_certs_clear(ssl_); + +#if defined(OS_IOS) + // TODO(droger): Support client auth on iOS. See http://crbug.com/145954). + LOG(WARNING) << "Client auth is not supported"; +#else // !defined(OS_IOS) if (!ssl_config_.send_client_cert) { // First pass: we know that a client certificate is needed, but we do not // have one at hand. client_auth_cert_needed_ = true; STACK_OF(X509_NAME) *authorities = SSL_get_client_CA_list(ssl); - for (int i = 0; i < sk_X509_NAME_num(authorities); i++) { + for (size_t i = 0; i < sk_X509_NAME_num(authorities); i++) { X509_NAME *ca_name = (X509_NAME *)sk_X509_NAME_value(authorities, i); unsigned char* str = NULL; int length = i2d_X509_NAME(ca_name, &str); @@ -1282,9 +1590,8 @@ int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl, } const unsigned char* client_cert_types; - size_t num_client_cert_types; - SSL_get_client_certificate_types(ssl, &client_cert_types, - &num_client_cert_types); + size_t num_client_cert_types = + SSL_get0_certificate_types(ssl, &client_cert_types); for (size_t i = 0; i < num_client_cert_types; i++) { cert_key_types_.push_back( static_cast<SSLClientCertType>(client_cert_types[i])); @@ -1295,90 +1602,81 @@ int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl, // Second pass: a client certificate should have been selected. if (ssl_config_.client_cert.get()) { -#if defined(USE_OPENSSL_CERTS) - // A note about ownership: FetchClientCertPrivateKey() increments - // the reference count of the EVP_PKEY. Ownership of this reference - // is passed directly to OpenSSL, which will release the reference - // using EVP_PKEY_free() when the SSL object is destroyed. - OpenSSLClientKeyStore::ScopedEVP_PKEY privkey; - if (OpenSSLClientKeyStore::GetInstance()->FetchClientCertPrivateKey( - ssl_config_.client_cert.get(), &privkey)) { - // TODO(joth): (copied from NSS) We should wait for server certificate - // verification before sending our credentials. See http://crbug.com/13934 - *x509 = X509Certificate::DupOSCertHandle( - ssl_config_.client_cert->os_cert_handle()); - *pkey = privkey.release(); - return 1; + ScopedX509 leaf_x509 = + OSCertHandleToOpenSSL(ssl_config_.client_cert->os_cert_handle()); + if (!leaf_x509) { + LOG(WARNING) << "Failed to import certificate"; + OpenSSLPutNetError(FROM_HERE, ERR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT); + return -1; } - LOG(WARNING) << "Client cert found without private key"; -#else // !defined(USE_OPENSSL_CERTS) - // OS handling of client certificates is not yet implemented. - NOTIMPLEMENTED(); -#endif // defined(USE_OPENSSL_CERTS) - } - // Send no client certificate. - return 0; -} + ScopedX509Stack chain = OSCertHandlesToOpenSSL( + ssl_config_.client_cert->GetIntermediateCertificates()); + if (!chain) { + LOG(WARNING) << "Failed to import intermediate certificates"; + OpenSSLPutNetError(FROM_HERE, ERR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT); + return -1; + } -void SSLClientSocketOpenSSL::ChannelIDRequestCallback(SSL* ssl, - EVP_PKEY** pkey) { - DVLOG(3) << "OpenSSL ChannelIDRequestCallback called"; - DCHECK_EQ(ssl, ssl_); - DCHECK(!*pkey); + // TODO(davidben): With Linux client auth support, this should be + // conditioned on OS_ANDROID and then, with https://crbug.com/394131, + // removed altogether. OpenSSLClientKeyStore is mostly an artifact of the + // net/ client auth API lacking a private key handle. +#if defined(USE_OPENSSL_CERTS) + crypto::ScopedEVP_PKEY privkey = + OpenSSLClientKeyStore::GetInstance()->FetchClientCertPrivateKey( + ssl_config_.client_cert.get()); +#else // !defined(USE_OPENSSL_CERTS) + crypto::ScopedEVP_PKEY privkey = + FetchClientCertPrivateKey(ssl_config_.client_cert.get()); +#endif // defined(USE_OPENSSL_CERTS) + if (!privkey) { + // Could not find the private key. Fail the handshake and surface an + // appropriate error to the caller. + LOG(WARNING) << "Client cert found without private key"; + OpenSSLPutNetError(FROM_HERE, ERR_SSL_CLIENT_AUTH_CERT_NO_PRIVATE_KEY); + return -1; + } - channel_id_xtn_negotiated_ = true; - if (!channel_id_private_key_.size()) { - channel_id_request_return_value_ = - server_bound_cert_service_->GetOrCreateDomainBoundCert( - host_and_port_.host(), - &channel_id_private_key_, - &channel_id_cert_, - base::Bind(&SSLClientSocketOpenSSL::OnHandshakeIOComplete, - base::Unretained(this)), - &channel_id_request_handle_); - if (channel_id_request_return_value_ != OK) - return; + if (!SSL_use_certificate(ssl_, leaf_x509.get()) || + !SSL_use_PrivateKey(ssl_, privkey.get()) || + !SSL_set1_chain(ssl_, chain.get())) { + LOG(WARNING) << "Failed to set client certificate"; + return -1; + } + return 1; } +#endif // defined(OS_IOS) - // Decode key. - std::vector<uint8> encrypted_private_key_info; - std::vector<uint8> subject_public_key_info; - encrypted_private_key_info.assign( - channel_id_private_key_.data(), - channel_id_private_key_.data() + channel_id_private_key_.size()); - subject_public_key_info.assign( - channel_id_cert_.data(), - channel_id_cert_.data() + channel_id_cert_.size()); - scoped_ptr<crypto::ECPrivateKey> ec_private_key( - crypto::ECPrivateKey::CreateFromEncryptedPrivateKeyInfo( - ServerBoundCertService::kEPKIPassword, - encrypted_private_key_info, - subject_public_key_info)); - if (!ec_private_key) - return; - set_channel_id_sent(true); - *pkey = EVP_PKEY_dup(ec_private_key->key()); + // Send no client certificate. + return 1; } int SSLClientSocketOpenSSL::CertVerifyCallback(X509_STORE_CTX* store_ctx) { - if (!completed_handshake_) { + if (!completed_connect_) { // If the first handshake hasn't completed then we accept any certificates // because we verify after the handshake. return 1; } - CHECK(server_cert_.get()); - - PeerCertificateChain chain(store_ctx->untrusted); - if (chain.IsValid() && server_cert_->Equals(chain.AsOSChain())) - return 1; - - if (!chain.IsValid()) + // Disallow the server certificate to change in a renegotiation. + if (server_cert_chain_->empty()) { LOG(ERROR) << "Received invalid certificate chain between handshakes"; - else + return 0; + } + base::StringPiece old_der, new_der; + if (store_ctx->cert == NULL || + !x509_util::GetDER(server_cert_chain_->Get(0), &old_der) || + !x509_util::GetDER(store_ctx->cert, &new_der)) { + LOG(ERROR) << "Failed to encode certificates"; + return 0; + } + if (old_der != new_der) { LOG(ERROR) << "Server certificate changed between handshakes"; - return 0; + return 0; + } + + return 1; } // SelectNextProtoCallback is called by OpenSSL during the handshake. If the @@ -1426,11 +1724,104 @@ int SSLClientSocketOpenSSL::SelectNextProtoCallback(unsigned char** out, } npn_proto_.assign(reinterpret_cast<const char*>(*out), *outlen); - server_protos_.assign(reinterpret_cast<const char*>(in), inlen); DVLOG(2) << "next protocol: '" << npn_proto_ << "' status: " << npn_status_; + set_negotiation_extension(kExtensionNPN); return SSL_TLSEXT_ERR_OK; } +long SSLClientSocketOpenSSL::MaybeReplayTransportError( + BIO *bio, + int cmd, + const char *argp, int argi, long argl, + long retvalue) { + if (cmd == (BIO_CB_READ|BIO_CB_RETURN) && retvalue <= 0) { + // If there is no more data in the buffer, report any pending errors that + // were observed. Note that both the readbuf and the writebuf are checked + // for errors, since the application may have encountered a socket error + // while writing that would otherwise not be reported until the application + // attempted to write again - which it may never do. See + // https://crbug.com/249848. + if (transport_read_error_ != OK) { + OpenSSLPutNetError(FROM_HERE, transport_read_error_); + return -1; + } + if (transport_write_error_ != OK) { + OpenSSLPutNetError(FROM_HERE, transport_write_error_); + return -1; + } + } else if (cmd == BIO_CB_WRITE) { + // Because of the write buffer, this reports a failure from the previous + // write payload. If the current payload fails to write, the error will be + // reported in a future write or read to |bio|. + if (transport_write_error_ != OK) { + OpenSSLPutNetError(FROM_HERE, transport_write_error_); + return -1; + } + } + return retvalue; +} + +// static +long SSLClientSocketOpenSSL::BIOCallback( + BIO *bio, + int cmd, + const char *argp, int argi, long argl, + long retvalue) { + SSLClientSocketOpenSSL* socket = reinterpret_cast<SSLClientSocketOpenSSL*>( + BIO_get_callback_arg(bio)); + CHECK(socket); + return socket->MaybeReplayTransportError( + bio, cmd, argp, argi, argl, retvalue); +} + +// static +void SSLClientSocketOpenSSL::InfoCallback(const SSL* ssl, + int type, + int /*val*/) { + if (type == SSL_CB_HANDSHAKE_DONE) { + SSLClientSocketOpenSSL* ssl_socket = + SSLContext::GetInstance()->GetClientSocketFromSSL(ssl); + ssl_socket->handshake_succeeded_ = true; + ssl_socket->CheckIfHandshakeFinished(); + } +} + +// Determines if both the handshake and certificate verification have completed +// successfully, and calls the handshake completion callback if that is the +// case. +// +// CheckIfHandshakeFinished is called twice per connection: once after +// MarkSSLSessionAsGood, when the certificate has been verified, and +// once via an OpenSSL callback when the handshake has completed. On the +// second call, when the certificate has been verified and the handshake +// has completed, the connection's handshake completion callback is run. +void SSLClientSocketOpenSSL::CheckIfHandshakeFinished() { + if (handshake_succeeded_ && marked_session_as_good_) + OnHandshakeCompletion(); +} + +void SSLClientSocketOpenSSL::AddSCTInfoToSSLInfo(SSLInfo* ssl_info) const { + for (ct::SCTList::const_iterator iter = + ct_verify_result_.verified_scts.begin(); + iter != ct_verify_result_.verified_scts.end(); ++iter) { + ssl_info->signed_certificate_timestamps.push_back( + SignedCertificateTimestampAndStatus(*iter, ct::SCT_STATUS_OK)); + } + for (ct::SCTList::const_iterator iter = + ct_verify_result_.invalid_scts.begin(); + iter != ct_verify_result_.invalid_scts.end(); ++iter) { + ssl_info->signed_certificate_timestamps.push_back( + SignedCertificateTimestampAndStatus(*iter, ct::SCT_STATUS_INVALID)); + } + for (ct::SCTList::const_iterator iter = + ct_verify_result_.unknown_logs_scts.begin(); + iter != ct_verify_result_.unknown_logs_scts.end(); ++iter) { + ssl_info->signed_certificate_timestamps.push_back( + SignedCertificateTimestampAndStatus(*iter, + ct::SCT_STATUS_LOG_UNKNOWN)); + } +} + scoped_refptr<X509Certificate> SSLClientSocketOpenSSL::GetUnverifiedServerCertificateChain() const { return server_cert_; diff --git a/chromium/net/socket/ssl_client_socket_openssl.h b/chromium/net/socket/ssl_client_socket_openssl.h index 5d70c0523fa..53d33c4c8c7 100644 --- a/chromium/net/socket/ssl_client_socket_openssl.h +++ b/chromium/net/socket/ssl_client_socket_openssl.h @@ -13,9 +13,11 @@ #include "net/base/completion_callback.h" #include "net/base/io_buffer.h" #include "net/cert/cert_verify_result.h" +#include "net/cert/ct_verify_result.h" #include "net/socket/client_socket_handle.h" #include "net/socket/ssl_client_socket.h" -#include "net/ssl/server_bound_cert_service.h" +#include "net/ssl/channel_id_service.h" +#include "net/ssl/openssl_ssl_util.h" #include "net/ssl/ssl_client_cert_type.h" #include "net/ssl/ssl_config_service.h" @@ -34,6 +36,7 @@ typedef struct x509_store_ctx_st X509_STORE_CTX; namespace net { class CertVerifier; +class CTVerifier; class SingleRequestCertVerifier; class SSLCertRequestInfo; class SSLInfo; @@ -49,7 +52,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context); - virtual ~SSLClientSocketOpenSSL(); + ~SSLClientSocketOpenSSL() override; const HostPortPair& host_and_port() const { return host_and_port_; } const std::string& ssl_session_cache_shard() const { @@ -57,46 +60,49 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { } // SSLClientSocket implementation. - virtual void GetSSLCertRequestInfo( - SSLCertRequestInfo* cert_request_info) OVERRIDE; - virtual NextProtoStatus GetNextProto(std::string* proto, - std::string* server_protos) OVERRIDE; - virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE; + std::string GetSessionCacheKey() const override; + bool InSessionCache() const override; + void SetHandshakeCompletionCallback(const base::Closure& callback) override; + void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override; + NextProtoStatus GetNextProto(std::string* proto) override; + ChannelIDService* GetChannelIDService() const override; // SSLSocket implementation. - virtual int ExportKeyingMaterial(const base::StringPiece& label, - bool has_context, - const base::StringPiece& context, - unsigned char* out, - unsigned int outlen) OVERRIDE; - virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE; + int ExportKeyingMaterial(const base::StringPiece& label, + bool has_context, + const base::StringPiece& context, + unsigned char* out, + unsigned int outlen) override; + int GetTLSUniqueChannelBinding(std::string* out) override; // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE; - virtual void Disconnect() OVERRIDE; - virtual bool IsConnected() const OVERRIDE; - virtual bool IsConnectedAndIdle() const OVERRIDE; - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; - virtual const BoundNetLog& NetLog() const OVERRIDE; - virtual void SetSubresourceSpeculation() OVERRIDE; - virtual void SetOmniboxSpeculation() OVERRIDE; - virtual bool WasEverUsed() const OVERRIDE; - virtual bool UsingTCPFastOpen() const OVERRIDE; - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + int Connect(const CompletionCallback& callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + const BoundNetLog& NetLog() const override; + void SetSubresourceSpeculation() override; + void SetOmniboxSpeculation() override; + bool WasEverUsed() const override; + bool UsingTCPFastOpen() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; // Socket implementation. - virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int SetReceiveBufferSize(int32 size) OVERRIDE; - virtual int SetSendBufferSize(int32 size) OVERRIDE; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int SetReceiveBufferSize(int32 size) override; + int SetSendBufferSize(int32 size) override; protected: // SSLClientSocket implementation. - virtual scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain() - const OVERRIDE; + scoped_refptr<X509Certificate> GetUnverifiedServerCertificateChain() + const override; private: class PeerCertificateChain; @@ -108,20 +114,25 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { void DoReadCallback(int result); void DoWriteCallback(int result); + void OnHandshakeCompletion(); + bool DoTransportIO(); int DoHandshake(); + int DoChannelIDLookup(); + int DoChannelIDLookupComplete(int result); int DoVerifyCert(int result); int DoVerifyCertComplete(int result); void DoConnectCallback(int result); - X509Certificate* UpdateServerCert(); + void UpdateServerCert(); + void VerifyCT(); void OnHandshakeIOComplete(int result); void OnSendComplete(int result); void OnRecvComplete(int result); int DoHandshakeLoop(int last_io_result); - int DoReadLoop(int result); - int DoWriteLoop(int result); + int DoReadLoop(); + int DoWriteLoop(); int DoPayloadRead(); int DoPayloadWrite(); @@ -134,11 +145,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { // Callback from the SSL layer that indicates the remote server is requesting // a certificate for this client. - int ClientCertRequestCallback(SSL* ssl, X509** x509, EVP_PKEY** pkey); - - // Callback from the SSL layer that indicates the remote server supports TLS - // Channel IDs. - void ChannelIDRequestCallback(SSL* ssl, EVP_PKEY** pkey); + int ClientCertRequestCallback(SSL* ssl); // CertVerifyCallback is called to verify the server's certificates. We do // verification after the handshake so this function only enforces that the @@ -149,9 +156,36 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { int SelectNextProtoCallback(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen); + // Called during an operation on |transport_bio_|'s peer. Checks saved + // transport error state and, if appropriate, returns an error through + // OpenSSL's error system. + long MaybeReplayTransportError(BIO *bio, + int cmd, + const char *argp, int argi, long argl, + long retvalue); + + // Callback from the SSL layer when an operation is performed on + // |transport_bio_|'s peer. + static long BIOCallback(BIO *bio, + int cmd, + const char *argp, int argi, long argl, + long retvalue); + + // Callback that is used to obtain information about the state of the SSL + // handshake. + static void InfoCallback(const SSL* ssl, int type, int val); + + void CheckIfHandshakeFinished(); + + // Adds the SignedCertificateTimestamps from ct_verify_result_ to |ssl_info|. + // SCTs are held in three separate vectors in ct_verify_result, each + // vetor representing a particular verification state, this method associates + // each of the SCTs with the corresponding SCTVerifyStatus as it adds it to + // the |ssl_info|.signed_certificate_timestamps list. + void AddSCTInfoToSSLInfo(SSLInfo* ssl_info) const; + bool transport_send_busy_; bool transport_recv_busy_; - bool transport_recv_eof_; scoped_refptr<DrainableIOBuffer> send_buffer_; scoped_refptr<IOBuffer> recv_buffer_; @@ -160,8 +194,6 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { CompletionCallback user_read_callback_; CompletionCallback user_write_callback_; - base::WeakPtrFactory<SSLClientSocketOpenSSL> weak_factory_; - // Used by Read function. scoped_refptr<IOBuffer> user_read_buf_; int user_read_buf_len_; @@ -178,15 +210,27 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { // indicates an error. int pending_read_error_; + // If there is a pending read result, the OpenSSL result code (output of + // SSL_get_error) associated with it. + int pending_read_ssl_error_; + + // If there is a pending read result, the OpenSSLErrorInfo associated with it. + OpenSSLErrorInfo pending_read_error_info_; + + // Used by TransportReadComplete() to signify an error reading from the + // transport socket. A value of OK indicates the socket is still + // readable. EOFs are mapped to ERR_CONNECTION_CLOSED. + int transport_read_error_; + // Used by TransportWriteComplete() and TransportReadComplete() to signify an // error writing to the transport socket. A value of OK indicates no error. int transport_write_error_; - // Set when handshake finishes. + // Set when Connect finishes. scoped_ptr<PeerCertificateChain> server_cert_chain_; scoped_refptr<X509Certificate> server_cert_; CertVerifyResult server_cert_verify_result_; - bool completed_handshake_; + bool completed_connect_; // Set when Read() or Write() successfully reads or writes data to or from the // network. @@ -204,9 +248,20 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { CertVerifier* const cert_verifier_; scoped_ptr<SingleRequestCertVerifier> verifier_; + base::TimeTicks start_cert_verification_time_; + + // Certificate Transparency: Verifier and result holder. + ct::CTVerifyResult ct_verify_result_; + CTVerifier* cert_transparency_verifier_; // The service for retrieving Channel ID keys. May be NULL. - ServerBoundCertService* server_bound_cert_service_; + ChannelIDService* channel_id_service_; + + // Callback that is invoked when the connection finishes. + // + // Note: this callback will be run in Disconnect(). It will not alter + // any member variables of the SSLClientSocketOpenSSL. + base::Closure handshake_completion_callback_; // OpenSSL stuff SSL* ssl_; @@ -226,23 +281,36 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { enum State { STATE_NONE, STATE_HANDSHAKE, + STATE_CHANNEL_ID_LOOKUP, + STATE_CHANNEL_ID_LOOKUP_COMPLETE, STATE_VERIFY_CERT, STATE_VERIFY_CERT_COMPLETE, }; State next_handshake_state_; NextProtoStatus npn_status_; std::string npn_proto_; - std::string server_protos_; - // Written by the |server_bound_cert_service_|. + // Written by the |channel_id_service_|. std::string channel_id_private_key_; std::string channel_id_cert_; - // The return value of the last call to |server_bound_cert_service_|. - int channel_id_request_return_value_; // True if channel ID extension was negotiated. bool channel_id_xtn_negotiated_; - // The request handle for |server_bound_cert_service_|. - ServerBoundCertService::RequestHandle channel_id_request_handle_; + // True if InfoCallback has been run with result = SSL_CB_HANDSHAKE_DONE. + bool handshake_succeeded_; + // True if MarkSSLSessionAsGood has been called for this socket's + // SSL session. + bool marked_session_as_good_; + // The request handle for |channel_id_service_|. + ChannelIDService::RequestHandle channel_id_request_handle_; + + TransportSecurityState* transport_security_state_; + + // pinning_failure_log contains a message produced by + // TransportSecurityState::CheckPublicKeyPins in the event of a + // pinning failure. It is a (somewhat) human-readable string. + std::string pinning_failure_log_; + BoundNetLog net_log_; + base::WeakPtrFactory<SSLClientSocketOpenSSL> weak_factory_; }; } // namespace net diff --git a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc index d4e0685467e..8a6a8828810 100644 --- a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc @@ -13,12 +13,13 @@ #include <openssl/pem.h> #include <openssl/rsa.h> -#include "base/file_util.h" #include "base/files/file_path.h" +#include "base/files/file_util.h" #include "base/memory/ref_counted.h" #include "base/message_loop/message_loop_proxy.h" #include "base/values.h" #include "crypto/openssl_util.h" +#include "crypto/scoped_openssl_types.h" #include "net/base/address_list.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" @@ -48,16 +49,6 @@ namespace { // These client auth tests are currently dependent on OpenSSL's struct X509. #if defined(USE_OPENSSL_CERTS) -typedef OpenSSLClientKeyStore::ScopedEVP_PKEY ScopedEVP_PKEY; - -// BIO_free is a macro, it can't be used as a template parameter. -void BIO_free_func(BIO* bio) { - BIO_free(bio); -} - -typedef crypto::ScopedOpenSSL<BIO, BIO_free_func> ScopedBIO; -typedef crypto::ScopedOpenSSL<RSA, RSA_free> ScopedRSA; -typedef crypto::ScopedOpenSSL<BIGNUM, BN_free> ScopedBIGNUM; const SSLConfig kDefaultSSLConfig; @@ -67,17 +58,16 @@ const SSLConfig kDefaultSSLConfig; // Returns true on success, false on failure. bool LoadPrivateKeyOpenSSL( const base::FilePath& filepath, - OpenSSLClientKeyStore::ScopedEVP_PKEY* pkey) { + crypto::ScopedEVP_PKEY* pkey) { std::string data; if (!base::ReadFileToString(filepath, &data)) { LOG(ERROR) << "Could not read private key file: " << filepath.value() << ": " << strerror(errno); return false; } - ScopedBIO bio( - BIO_new_mem_buf( - const_cast<char*>(reinterpret_cast<const char*>(data.data())), - static_cast<int>(data.size()))); + crypto::ScopedBIO bio(BIO_new_mem_buf( + const_cast<char*>(reinterpret_cast<const char*>(data.data())), + static_cast<int>(data.size()))); if (!bio.get()) { LOG(ERROR) << "Could not allocate BIO for buffer?"; return false; @@ -95,13 +85,13 @@ bool LoadPrivateKeyOpenSSL( class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest { public: SSLClientSocketOpenSSLClientAuthTest() - : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()), - cert_verifier_(new net::MockCertVerifier), - transport_security_state_(new net::TransportSecurityState) { - cert_verifier_->set_default_result(net::OK); + : socket_factory_(ClientSocketFactory::GetDefaultFactory()), + cert_verifier_(new MockCertVerifier), + transport_security_state_(new TransportSecurityState) { + cert_verifier_->set_default_result(OK); context_.cert_verifier = cert_verifier_.get(); context_.transport_security_state = transport_security_state_.get(); - key_store_ = net::OpenSSLClientKeyStore::GetInstance(); + key_store_ = OpenSSLClientKeyStore::GetInstance(); } virtual ~SSLClientSocketOpenSSLClientAuthTest() { @@ -260,7 +250,7 @@ TEST_F(SSLClientSocketOpenSSLClientAuthTest, SendGoodCert) { // This is required to ensure that signing works with the client // certificate's private key. - OpenSSLClientKeyStore::ScopedEVP_PKEY client_private_key; + crypto::ScopedEVP_PKEY client_private_key; ASSERT_TRUE(LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_1.key"), &client_private_key)); EXPECT_TRUE(RecordPrivateKey(ssl_config, client_private_key.get())); diff --git a/chromium/net/socket/ssl_client_socket_pool.cc b/chromium/net/socket/ssl_client_socket_pool.cc index 2c704658820..56df1d85d09 100644 --- a/chromium/net/socket/ssl_client_socket_pool.cc +++ b/chromium/net/socket/ssl_client_socket_pool.cc @@ -9,6 +9,7 @@ #include "base/metrics/field_trial.h" #include "base/metrics/histogram.h" #include "base/metrics/sparse_histogram.h" +#include "base/stl_util.h" #include "base/values.h" #include "net/base/host_port_pair.h" #include "net/base/net_errors.h" @@ -45,15 +46,15 @@ SSLSocketParams::SSLSocketParams( force_spdy_over_ssl_(force_spdy_over_ssl), want_spdy_over_npn_(want_spdy_over_npn), ignore_limits_(false) { - if (direct_params_) { - DCHECK(!socks_proxy_params_); - DCHECK(!http_proxy_params_); + if (direct_params_.get()) { + DCHECK(!socks_proxy_params_.get()); + DCHECK(!http_proxy_params_.get()); ignore_limits_ = direct_params_->ignore_limits(); - } else if (socks_proxy_params_) { - DCHECK(!http_proxy_params_); + } else if (socks_proxy_params_.get()) { + DCHECK(!http_proxy_params_.get()); ignore_limits_ = socks_proxy_params_->ignore_limits(); } else { - DCHECK(http_proxy_params_); + DCHECK(http_proxy_params_.get()); ignore_limits_ = http_proxy_params_->ignore_limits(); } } @@ -61,18 +62,18 @@ SSLSocketParams::SSLSocketParams( SSLSocketParams::~SSLSocketParams() {} SSLSocketParams::ConnectionType SSLSocketParams::GetConnectionType() const { - if (direct_params_) { - DCHECK(!socks_proxy_params_); - DCHECK(!http_proxy_params_); + if (direct_params_.get()) { + DCHECK(!socks_proxy_params_.get()); + DCHECK(!http_proxy_params_.get()); return DIRECT; } - if (socks_proxy_params_) { - DCHECK(!http_proxy_params_); + if (socks_proxy_params_.get()) { + DCHECK(!http_proxy_params_.get()); return SOCKS_PROXY; } - DCHECK(http_proxy_params_); + DCHECK(http_proxy_params_.get()); return HTTP_PROXY; } @@ -94,6 +95,77 @@ SSLSocketParams::GetHttpProxyConnectionParams() const { return http_proxy_params_; } +SSLConnectJobMessenger::SocketAndCallback::SocketAndCallback( + SSLClientSocket* ssl_socket, + const base::Closure& job_resumption_callback) + : socket(ssl_socket), callback(job_resumption_callback) { +} + +SSLConnectJobMessenger::SocketAndCallback::~SocketAndCallback() { +} + +SSLConnectJobMessenger::SSLConnectJobMessenger( + const base::Closure& messenger_finished_callback) + : messenger_finished_callback_(messenger_finished_callback), + weak_factory_(this) { +} + +SSLConnectJobMessenger::~SSLConnectJobMessenger() { +} + +void SSLConnectJobMessenger::RemovePendingSocket(SSLClientSocket* ssl_socket) { + // Sockets do not need to be removed from connecting_sockets_ because + // OnSSLHandshakeCompleted will do this. + for (SSLPendingSocketsAndCallbacks::iterator it = + pending_sockets_and_callbacks_.begin(); + it != pending_sockets_and_callbacks_.end(); + ++it) { + if (it->socket == ssl_socket) { + pending_sockets_and_callbacks_.erase(it); + return; + } + } +} + +bool SSLConnectJobMessenger::CanProceed(SSLClientSocket* ssl_socket) { + // If there are no connecting sockets, allow the connection to proceed. + return connecting_sockets_.empty(); +} + +void SSLConnectJobMessenger::MonitorConnectionResult( + SSLClientSocket* ssl_socket) { + connecting_sockets_.push_back(ssl_socket); + ssl_socket->SetHandshakeCompletionCallback( + base::Bind(&SSLConnectJobMessenger::OnSSLHandshakeCompleted, + weak_factory_.GetWeakPtr())); +} + +void SSLConnectJobMessenger::AddPendingSocket(SSLClientSocket* ssl_socket, + const base::Closure& callback) { + DCHECK(!connecting_sockets_.empty()); + pending_sockets_and_callbacks_.push_back( + SocketAndCallback(ssl_socket, callback)); +} + +void SSLConnectJobMessenger::OnSSLHandshakeCompleted() { + connecting_sockets_.clear(); + SSLPendingSocketsAndCallbacks temp_list; + temp_list.swap(pending_sockets_and_callbacks_); + base::Closure messenger_finished_callback = messenger_finished_callback_; + messenger_finished_callback.Run(); + RunAllCallbacks(temp_list); +} + +void SSLConnectJobMessenger::RunAllCallbacks( + const SSLPendingSocketsAndCallbacks& pending_sockets_and_callbacks) { + for (std::vector<SocketAndCallback>::const_iterator it = + pending_sockets_and_callbacks.begin(); + it != pending_sockets_and_callbacks.end(); + ++it) { + it->callback.Run(); + } +} + // Timeout for the SSL handshake portion of the connect. static const int kSSLHandshakeTimeoutInSeconds = 30; @@ -107,6 +179,7 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name, ClientSocketFactory* client_socket_factory, HostResolver* host_resolver, const SSLClientSocketContext& context, + const GetMessengerCallback& get_messenger_callback, Delegate* delegate, NetLog* net_log) : ConnectJob(group_name, @@ -121,16 +194,23 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name, client_socket_factory_(client_socket_factory), host_resolver_(host_resolver), context_(context.cert_verifier, - context.server_bound_cert_service, + context.channel_id_service, context.transport_security_state, context.cert_transparency_verifier, (params->privacy_mode() == PRIVACY_MODE_ENABLED ? "pm/" + context.ssl_session_cache_shard : context.ssl_session_cache_shard)), - callback_(base::Bind(&SSLConnectJob::OnIOComplete, - base::Unretained(this))) {} + io_callback_( + base::Bind(&SSLConnectJob::OnIOComplete, base::Unretained(this))), + messenger_(NULL), + get_messenger_callback_(get_messenger_callback), + weak_factory_(this) { +} -SSLConnectJob::~SSLConnectJob() {} +SSLConnectJob::~SSLConnectJob() { + if (ssl_socket_.get() && messenger_) + messenger_->RemovePendingSocket(ssl_socket_.get()); +} LoadState SSLConnectJob::GetLoadState() const { switch (next_state_) { @@ -144,6 +224,8 @@ LoadState SSLConnectJob::GetLoadState() const { case STATE_SOCKS_CONNECT_COMPLETE: case STATE_TUNNEL_CONNECT: return transport_socket_handle_->GetLoadState(); + case STATE_CREATE_SSL_SOCKET: + case STATE_CHECK_FOR_RESUME: case STATE_SSL_CONNECT: case STATE_SSL_CONNECT_COMPLETE: return LOAD_STATE_SSL_HANDSHAKE; @@ -200,6 +282,12 @@ int SSLConnectJob::DoLoop(int result) { case STATE_TUNNEL_CONNECT_COMPLETE: rv = DoTunnelConnectComplete(rv); break; + case STATE_CREATE_SSL_SOCKET: + rv = DoCreateSSLSocket(); + break; + case STATE_CHECK_FOR_RESUME: + rv = DoCheckForResume(); + break; case STATE_SSL_CONNECT: DCHECK_EQ(OK, rv); rv = DoSSLConnect(); @@ -227,14 +315,14 @@ int SSLConnectJob::DoTransportConnect() { return transport_socket_handle_->Init(group_name(), direct_params, priority(), - callback_, + io_callback_, transport_pool_, net_log()); } int SSLConnectJob::DoTransportConnectComplete(int result) { if (result == OK) - next_state_ = STATE_SSL_CONNECT; + next_state_ = STATE_CREATE_SSL_SOCKET; return result; } @@ -248,14 +336,14 @@ int SSLConnectJob::DoSOCKSConnect() { return transport_socket_handle_->Init(group_name(), socks_proxy_params, priority(), - callback_, + io_callback_, socks_pool_, net_log()); } int SSLConnectJob::DoSOCKSConnectComplete(int result) { if (result == OK) - next_state_ = STATE_SSL_CONNECT; + next_state_ = STATE_CREATE_SSL_SOCKET; return result; } @@ -270,7 +358,7 @@ int SSLConnectJob::DoTunnelConnect() { return transport_socket_handle_->Init(group_name(), http_proxy_params, priority(), - callback_, + io_callback_, http_proxy_pool_, net_log()); } @@ -290,13 +378,13 @@ int SSLConnectJob::DoTunnelConnectComplete(int result) { } if (result < 0) return result; - - next_state_ = STATE_SSL_CONNECT; + next_state_ = STATE_CREATE_SSL_SOCKET; return result; } -int SSLConnectJob::DoSSLConnect() { - next_state_ = STATE_SSL_CONNECT_COMPLETE; +int SSLConnectJob::DoCreateSSLSocket() { + next_state_ = STATE_CHECK_FOR_RESUME; + // Reset the timeout to just the time allowed for the SSL handshake. ResetTimer(base::TimeDelta::FromSeconds(kSSLHandshakeTimeoutInSeconds)); @@ -314,14 +402,45 @@ int SSLConnectJob::DoSSLConnect() { connect_timing_.dns_end = socket_connect_timing.dns_end; } - connect_timing_.ssl_start = base::TimeTicks::Now(); - ssl_socket_ = client_socket_factory_->CreateSSLClientSocket( transport_socket_handle_.Pass(), params_->host_and_port(), params_->ssl_config(), context_); - return ssl_socket_->Connect(callback_); + + if (!ssl_socket_->InSessionCache()) + messenger_ = get_messenger_callback_.Run(ssl_socket_->GetSessionCacheKey()); + + return OK; +} + +int SSLConnectJob::DoCheckForResume() { + next_state_ = STATE_SSL_CONNECT; + + if (!messenger_) + return OK; + + if (messenger_->CanProceed(ssl_socket_.get())) { + messenger_->MonitorConnectionResult(ssl_socket_.get()); + // The SSLConnectJob no longer needs access to the messenger after this + // point. + messenger_ = NULL; + return OK; + } + + messenger_->AddPendingSocket(ssl_socket_.get(), + base::Bind(&SSLConnectJob::ResumeSSLConnection, + weak_factory_.GetWeakPtr())); + + return ERR_IO_PENDING; +} + +int SSLConnectJob::DoSSLConnect() { + next_state_ = STATE_SSL_CONNECT_COMPLETE; + + connect_timing_.ssl_start = base::TimeTicks::Now(); + + return ssl_socket_->Connect(io_callback_); } int SSLConnectJob::DoSSLConnectComplete(int result) { @@ -330,12 +449,13 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { SSLClientSocket::NextProtoStatus status = SSLClientSocket::kNextProtoUnsupported; std::string proto; - std::string server_protos; // GetNextProto will fail and and trigger a NOTREACHED if we pass in a socket // that hasn't had SSL_ImportFD called on it. If we get a certificate error // here, then we know that we called SSL_ImportFD. - if (result == OK || IsCertificateError(result)) - status = ssl_socket_->GetNextProto(&proto, &server_protos); + if (result == OK || IsCertificateError(result)) { + status = ssl_socket_->GetNextProto(&proto); + ssl_socket_->RecordNegotiationExtension(); + } // If we want spdy over npn, make sure it succeeded. if (status == SSLClientSocket::kNextProtoNegotiated) { @@ -370,18 +490,6 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { base::TimeDelta::FromMinutes(1), 100); } -#if defined(SPDY_PROXY_AUTH_ORIGIN) - bool using_data_reduction_proxy = params_->host_and_port().Equals( - HostPortPair::FromURL(GURL(SPDY_PROXY_AUTH_ORIGIN))); - if (using_data_reduction_proxy) { - UMA_HISTOGRAM_CUSTOM_TIMES( - "Net.SSL_Connection_Latency_DataReductionProxy", - connect_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(1), - 100); - } -#endif UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_2", connect_duration, @@ -396,6 +504,11 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { SSLConnectionStatusToCipherSuite( ssl_info.connection_status)); + UMA_HISTOGRAM_BOOLEAN( + "Net.RenegotiationExtensionSupported", + (ssl_info.connection_status & + SSL_CONNECTION_NO_RENEGOTIATION_EXTENSION) == 0); + if (ssl_info.handshake_type == SSLInfo::HANDSHAKE_RESUME) { UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_Resume_Handshake", connect_duration, @@ -411,9 +524,9 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { } const std::string& host = params_->host_and_port().host(); - bool is_google = host == "google.com" || - (host.size() > 11 && - host.rfind(".google.com") == host.size() - 11); + bool is_google = + host == "google.com" || + (host.size() > 11 && host.rfind(".google.com") == host.size() - 11); if (is_google) { UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_Google2", connect_duration, @@ -439,7 +552,7 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { } if (result == OK || IsCertificateError(result)) { - SetSocket(ssl_socket_.PassAs<StreamSocket>()); + SetSocket(ssl_socket_.Pass()); } else if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) { error_response_info_.cert_request_info = new SSLCertRequestInfo; ssl_socket_->GetSSLCertRequestInfo( @@ -449,6 +562,12 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { return result; } +void SSLConnectJob::ResumeSSLConnection() { + DCHECK_EQ(next_state_, STATE_SSL_CONNECT); + messenger_ = NULL; + OnIOComplete(OK); +} + SSLConnectJob::State SSLConnectJob::GetInitialState( SSLSocketParams::ConnectionType connection_type) { switch (connection_type) { @@ -475,6 +594,7 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( ClientSocketFactory* client_socket_factory, HostResolver* host_resolver, const SSLClientSocketContext& context, + const SSLConnectJob::GetMessengerCallback& get_messenger_callback, NetLog* net_log) : transport_pool_(transport_pool), socks_pool_(socks_pool), @@ -482,6 +602,7 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( client_socket_factory_(client_socket_factory), host_resolver_(host_resolver), context_(context), + get_messenger_callback_(get_messenger_callback), net_log_(net_log) { base::TimeDelta max_transport_timeout = base::TimeDelta(); base::TimeDelta pool_timeout; @@ -501,13 +622,16 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( base::TimeDelta::FromSeconds(kSSLHandshakeTimeoutInSeconds); } +SSLClientSocketPool::SSLConnectJobFactory::~SSLConnectJobFactory() { +} + SSLClientSocketPool::SSLClientSocketPool( int max_sockets, int max_sockets_per_group, ClientSocketPoolHistograms* histograms, HostResolver* host_resolver, CertVerifier* cert_verifier, - ServerBoundCertService* server_bound_cert_service, + ChannelIDService* channel_id_service, TransportSecurityState* transport_security_state, CTVerifier* cert_transparency_verifier, const std::string& ssl_session_cache_shard, @@ -516,26 +640,34 @@ SSLClientSocketPool::SSLClientSocketPool( SOCKSClientSocketPool* socks_pool, HttpProxyClientSocketPool* http_proxy_pool, SSLConfigService* ssl_config_service, + bool enable_ssl_connect_job_waiting, NetLog* net_log) : transport_pool_(transport_pool), socks_pool_(socks_pool), http_proxy_pool_(http_proxy_pool), - base_(this, max_sockets, max_sockets_per_group, histograms, + base_(this, + max_sockets, + max_sockets_per_group, + histograms, ClientSocketPool::unused_idle_socket_timeout(), ClientSocketPool::used_idle_socket_timeout(), - new SSLConnectJobFactory(transport_pool, - socks_pool, - http_proxy_pool, - client_socket_factory, - host_resolver, - SSLClientSocketContext( - cert_verifier, - server_bound_cert_service, - transport_security_state, - cert_transparency_verifier, - ssl_session_cache_shard), - net_log)), - ssl_config_service_(ssl_config_service) { + new SSLConnectJobFactory( + transport_pool, + socks_pool, + http_proxy_pool, + client_socket_factory, + host_resolver, + SSLClientSocketContext(cert_verifier, + channel_id_service, + transport_security_state, + cert_transparency_verifier, + ssl_session_cache_shard), + base::Bind( + &SSLClientSocketPool::GetOrCreateSSLConnectJobMessenger, + base::Unretained(this)), + net_log)), + ssl_config_service_(ssl_config_service), + enable_ssl_connect_job_waiting_(enable_ssl_connect_job_waiting) { if (ssl_config_service_.get()) ssl_config_service_->AddObserver(this); if (transport_pool_) @@ -547,24 +679,33 @@ SSLClientSocketPool::SSLClientSocketPool( } SSLClientSocketPool::~SSLClientSocketPool() { + STLDeleteContainerPairSecondPointers(messenger_map_.begin(), + messenger_map_.end()); if (ssl_config_service_.get()) ssl_config_service_->RemoveObserver(this); } -scoped_ptr<ConnectJob> -SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( +scoped_ptr<ConnectJob> SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { - return scoped_ptr<ConnectJob>( - new SSLConnectJob(group_name, request.priority(), request.params(), - ConnectionTimeout(), transport_pool_, socks_pool_, - http_proxy_pool_, client_socket_factory_, - host_resolver_, context_, delegate, net_log_)); -} - -base::TimeDelta -SSLClientSocketPool::SSLConnectJobFactory::ConnectionTimeout() const { + return scoped_ptr<ConnectJob>(new SSLConnectJob(group_name, + request.priority(), + request.params(), + ConnectionTimeout(), + transport_pool_, + socks_pool_, + http_proxy_pool_, + client_socket_factory_, + host_resolver_, + context_, + get_messenger_callback_, + delegate, + net_log_)); +} + +base::TimeDelta SSLClientSocketPool::SSLConnectJobFactory::ConnectionTimeout() + const { return timeout_; } @@ -679,6 +820,32 @@ bool SSLClientSocketPool::CloseOneIdleConnection() { return base_.CloseOneIdleConnectionInHigherLayeredPool(); } +SSLConnectJobMessenger* SSLClientSocketPool::GetOrCreateSSLConnectJobMessenger( + const std::string& cache_key) { + if (!enable_ssl_connect_job_waiting_) + return NULL; + MessengerMap::const_iterator it = messenger_map_.find(cache_key); + if (it == messenger_map_.end()) { + std::pair<MessengerMap::iterator, bool> iter = + messenger_map_.insert(MessengerMap::value_type( + cache_key, + new SSLConnectJobMessenger( + base::Bind(&SSLClientSocketPool::DeleteSSLConnectJobMessenger, + base::Unretained(this), + cache_key)))); + it = iter.first; + } + return it->second; +} + +void SSLClientSocketPool::DeleteSSLConnectJobMessenger( + const std::string& cache_key) { + MessengerMap::iterator it = messenger_map_.find(cache_key); + CHECK(it != messenger_map_.end()); + delete it->second; + messenger_map_.erase(it); +} + void SSLClientSocketPool::OnSSLConfigChanged() { FlushWithError(ERR_NETWORK_CHANGED); } diff --git a/chromium/net/socket/ssl_client_socket_pool.h b/chromium/net/socket/ssl_client_socket_pool.h index e03b76ade6a..c7f613e5149 100644 --- a/chromium/net/socket/ssl_client_socket_pool.h +++ b/chromium/net/socket/ssl_client_socket_pool.h @@ -5,7 +5,9 @@ #ifndef NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_ #define NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_ +#include <map> #include <string> +#include <vector> #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" @@ -94,29 +96,110 @@ class NET_EXPORT_PRIVATE SSLSocketParams DISALLOW_COPY_AND_ASSIGN(SSLSocketParams); }; +// SSLConnectJobMessenger handles communication between concurrent +// SSLConnectJobs that share the same SSL session cache key. +// +// SSLConnectJobMessengers tell the session cache when a certain +// connection should be monitored for success or failure, and +// tell SSLConnectJobs when to pause or resume their connections. +class SSLConnectJobMessenger { + public: + struct SocketAndCallback { + SocketAndCallback(SSLClientSocket* ssl_socket, + const base::Closure& job_resumption_callback); + ~SocketAndCallback(); + + SSLClientSocket* socket; + base::Closure callback; + }; + + typedef std::vector<SocketAndCallback> SSLPendingSocketsAndCallbacks; + + // |messenger_finished_callback| is run when a connection monitored by the + // SSLConnectJobMessenger has completed and we are finished with the + // SSLConnectJobMessenger. + explicit SSLConnectJobMessenger( + const base::Closure& messenger_finished_callback); + ~SSLConnectJobMessenger(); + + // Removes |socket| from the set of sockets being monitored. This + // guarantees that |job_resumption_callback| will not be called for + // the socket. + void RemovePendingSocket(SSLClientSocket* ssl_socket); + + // Returns true if |ssl_socket|'s Connect() method should be called. + bool CanProceed(SSLClientSocket* ssl_socket); + + // Configures the SSLConnectJobMessenger to begin monitoring |ssl_socket|'s + // connection status. After a successful connection, or an error, + // the messenger will determine which sockets that have been added + // via AddPendingSocket() to allow to proceed. + void MonitorConnectionResult(SSLClientSocket* ssl_socket); + + // Adds |socket| to the list of sockets waiting to Connect(). When + // the messenger has determined that it's an appropriate time for |socket| + // to connect, it will invoke |callback|. + // + // Note: It is an error to call AddPendingSocket() without having first + // called MonitorConnectionResult() and configuring a socket that WILL + // have Connect() called on it. + void AddPendingSocket(SSLClientSocket* ssl_socket, + const base::Closure& callback); + + private: + // Processes pending callbacks when a socket completes its SSL handshake -- + // either successfully or unsuccessfully. + void OnSSLHandshakeCompleted(); + + // Runs all callbacks stored in |pending_sockets_and_callbacks_|. + void RunAllCallbacks( + const SSLPendingSocketsAndCallbacks& pending_socket_and_callbacks); + + SSLPendingSocketsAndCallbacks pending_sockets_and_callbacks_; + // Note: this field is a vector to allow for future design changes. Currently, + // this vector should only ever have one entry. + std::vector<SSLClientSocket*> connecting_sockets_; + + base::Closure messenger_finished_callback_; + + base::WeakPtrFactory<SSLConnectJobMessenger> weak_factory_; +}; + // SSLConnectJob handles the SSL handshake after setting up the underlying // connection as specified in the params. class SSLConnectJob : public ConnectJob { public: - SSLConnectJob( - const std::string& group_name, - RequestPriority priority, - const scoped_refptr<SSLSocketParams>& params, - const base::TimeDelta& timeout_duration, - TransportClientSocketPool* transport_pool, - SOCKSClientSocketPool* socks_pool, - HttpProxyClientSocketPool* http_proxy_pool, - ClientSocketFactory* client_socket_factory, - HostResolver* host_resolver, - const SSLClientSocketContext& context, - Delegate* delegate, - NetLog* net_log); - virtual ~SSLConnectJob(); + // Callback to allow the SSLConnectJob to obtain an SSLConnectJobMessenger to + // coordinate connecting. The SSLConnectJob will supply a unique identifer + // (ex: the SSL session cache key), with the expectation that the same + // Messenger will be returned for all such ConnectJobs. + // + // Note: It will only be called for situations where the SSL session cache + // does not already have a candidate session to resume. + typedef base::Callback<SSLConnectJobMessenger*(const std::string&)> + GetMessengerCallback; + + // Note: the SSLConnectJob does not own |messenger| so it must outlive the + // job. + SSLConnectJob(const std::string& group_name, + RequestPriority priority, + const scoped_refptr<SSLSocketParams>& params, + const base::TimeDelta& timeout_duration, + TransportClientSocketPool* transport_pool, + SOCKSClientSocketPool* socks_pool, + HttpProxyClientSocketPool* http_proxy_pool, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + const SSLClientSocketContext& context, + const GetMessengerCallback& get_messenger_callback, + Delegate* delegate, + NetLog* net_log); + ~SSLConnectJob() override; // ConnectJob methods. - virtual LoadState GetLoadState() const OVERRIDE; + LoadState GetLoadState() const override; - virtual void GetAdditionalErrorState(ClientSocketHandle * handle) OVERRIDE; + void GetAdditionalErrorState(ClientSocketHandle* handle) override; private: enum State { @@ -126,6 +209,8 @@ class SSLConnectJob : public ConnectJob { STATE_SOCKS_CONNECT_COMPLETE, STATE_TUNNEL_CONNECT, STATE_TUNNEL_CONNECT_COMPLETE, + STATE_CREATE_SSL_SOCKET, + STATE_CHECK_FOR_RESUME, STATE_SSL_CONNECT, STATE_SSL_CONNECT_COMPLETE, STATE_NONE, @@ -142,9 +227,14 @@ class SSLConnectJob : public ConnectJob { int DoSOCKSConnectComplete(int result); int DoTunnelConnect(); int DoTunnelConnectComplete(int result); + int DoCreateSSLSocket(); + int DoCheckForResume(); int DoSSLConnect(); int DoSSLConnectComplete(int result); + // Tells a waiting SSLConnectJob to resume its SSL connection. + void ResumeSSLConnection(); + // Returns the initial state for the state machine based on the // |connection_type|. static State GetInitialState(SSLSocketParams::ConnectionType connection_type); @@ -152,7 +242,7 @@ class SSLConnectJob : public ConnectJob { // Starts the SSL connection process. Returns OK on success and // ERR_IO_PENDING if it cannot immediately service the request. // Otherwise, it returns a net error code. - virtual int ConnectInternal() OVERRIDE; + int ConnectInternal() override; scoped_refptr<SSLSocketParams> params_; TransportClientSocketPool* const transport_pool_; @@ -164,12 +254,17 @@ class SSLConnectJob : public ConnectJob { const SSLClientSocketContext context_; State next_state_; - CompletionCallback callback_; + CompletionCallback io_callback_; scoped_ptr<ClientSocketHandle> transport_socket_handle_; scoped_ptr<SSLClientSocket> ssl_socket_; + SSLConnectJobMessenger* messenger_; HttpResponseInfo error_response_info_; + GetMessengerCallback get_messenger_callback_; + + base::WeakPtrFactory<SSLConnectJob> weak_factory_; + DISALLOW_COPY_AND_ASSIGN(SSLConnectJob); }; @@ -182,85 +277,91 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool // Only the pools that will be used are required. i.e. if you never // try to create an SSL over SOCKS socket, |socks_pool| may be NULL. - SSLClientSocketPool( - int max_sockets, - int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, - HostResolver* host_resolver, - CertVerifier* cert_verifier, - ServerBoundCertService* server_bound_cert_service, - TransportSecurityState* transport_security_state, - CTVerifier* cert_transparency_verifier, - const std::string& ssl_session_cache_shard, - ClientSocketFactory* client_socket_factory, - TransportClientSocketPool* transport_pool, - SOCKSClientSocketPool* socks_pool, - HttpProxyClientSocketPool* http_proxy_pool, - SSLConfigService* ssl_config_service, - NetLog* net_log); - - virtual ~SSLClientSocketPool(); + SSLClientSocketPool(int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + HostResolver* host_resolver, + CertVerifier* cert_verifier, + ChannelIDService* channel_id_service, + TransportSecurityState* transport_security_state, + CTVerifier* cert_transparency_verifier, + const std::string& ssl_session_cache_shard, + ClientSocketFactory* client_socket_factory, + TransportClientSocketPool* transport_pool, + SOCKSClientSocketPool* socks_pool, + HttpProxyClientSocketPool* http_proxy_pool, + SSLConfigService* ssl_config_service, + bool enable_ssl_connect_job_waiting, + NetLog* net_log); + + ~SSLClientSocketPool() override; // ClientSocketPool implementation. - virtual int RequestSocket(const std::string& group_name, - const void* connect_params, - RequestPriority priority, - ClientSocketHandle* handle, - const CompletionCallback& callback, - const BoundNetLog& net_log) OVERRIDE; + int RequestSocket(const std::string& group_name, + const void* connect_params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) override; - virtual void RequestSockets(const std::string& group_name, - const void* params, - int num_sockets, - const BoundNetLog& net_log) OVERRIDE; + void RequestSockets(const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) override; - virtual void CancelRequest(const std::string& group_name, - ClientSocketHandle* handle) OVERRIDE; + void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) override; - virtual void ReleaseSocket(const std::string& group_name, - scoped_ptr<StreamSocket> socket, - int id) OVERRIDE; + void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) override; - virtual void FlushWithError(int error) OVERRIDE; + void FlushWithError(int error) override; - virtual void CloseIdleSockets() OVERRIDE; + void CloseIdleSockets() override; - virtual int IdleSocketCount() const OVERRIDE; + int IdleSocketCount() const override; - virtual int IdleSocketCountInGroup( - const std::string& group_name) const OVERRIDE; + int IdleSocketCountInGroup(const std::string& group_name) const override; - virtual LoadState GetLoadState( - const std::string& group_name, - const ClientSocketHandle* handle) const OVERRIDE; + LoadState GetLoadState(const std::string& group_name, + const ClientSocketHandle* handle) const override; - virtual base::DictionaryValue* GetInfoAsValue( + base::DictionaryValue* GetInfoAsValue( const std::string& name, const std::string& type, - bool include_nested_pools) const OVERRIDE; + bool include_nested_pools) const override; - virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; + base::TimeDelta ConnectionTimeout() const override; - virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; + ClientSocketPoolHistograms* histograms() const override; // LowerLayeredPool implementation. - virtual bool IsStalled() const OVERRIDE; + bool IsStalled() const override; - virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + void AddHigherLayeredPool(HigherLayeredPool* higher_pool) override; - virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) override; // HigherLayeredPool implementation. - virtual bool CloseOneIdleConnection() OVERRIDE; + bool CloseOneIdleConnection() override; + + // Gets the SSLConnectJobMessenger for the given ssl session |cache_key|. If + // none exits, it creates one and stores it in |messenger_map_|. + SSLConnectJobMessenger* GetOrCreateSSLConnectJobMessenger( + const std::string& cache_key); + void DeleteSSLConnectJobMessenger(const std::string& cache_key); private: typedef ClientSocketPoolBase<SSLSocketParams> PoolBase; + // Maps SSLConnectJob cache keys to SSLConnectJobMessenger objects. + typedef std::map<std::string, SSLConnectJobMessenger*> MessengerMap; // SSLConfigService::Observer implementation. // When the user changes the SSL config, we flush all idle sockets so they // won't get re-used. - virtual void OnSSLConfigChanged() OVERRIDE; + void OnSSLConfigChanged() override; class SSLConnectJobFactory : public PoolBase::ConnectJobFactory { public: @@ -271,17 +372,18 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool ClientSocketFactory* client_socket_factory, HostResolver* host_resolver, const SSLClientSocketContext& context, + const SSLConnectJob::GetMessengerCallback& get_messenger_callback, NetLog* net_log); - virtual ~SSLConnectJobFactory() {} + ~SSLConnectJobFactory() override; // ClientSocketPoolBase::ConnectJobFactory methods. - virtual scoped_ptr<ConnectJob> NewConnectJob( + scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const PoolBase::Request& request, - ConnectJob::Delegate* delegate) const OVERRIDE; + ConnectJob::Delegate* delegate) const override; - virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; + base::TimeDelta ConnectionTimeout() const override; private: TransportClientSocketPool* const transport_pool_; @@ -291,6 +393,7 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool HostResolver* const host_resolver_; const SSLClientSocketContext context_; base::TimeDelta timeout_; + SSLConnectJob::GetMessengerCallback get_messenger_callback_; NetLog* net_log_; DISALLOW_COPY_AND_ASSIGN(SSLConnectJobFactory); @@ -301,6 +404,8 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool HttpProxyClientSocketPool* const http_proxy_pool_; PoolBase base_; const scoped_refptr<SSLConfigService> ssl_config_service_; + MessengerMap messenger_map_; + bool enable_ssl_connect_job_waiting_; DISALLOW_COPY_AND_ASSIGN(SSLClientSocketPool); }; diff --git a/chromium/net/socket/ssl_client_socket_pool_unittest.cc b/chromium/net/socket/ssl_client_socket_pool_unittest.cc index 6ae07ed0bda..202cd8809e5 100644 --- a/chromium/net/socket/ssl_client_socket_pool_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_pool_unittest.cc @@ -6,6 +6,7 @@ #include "base/callback.h" #include "base/compiler_specific.h" +#include "base/run_loop.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" #include "base/time/time.h" @@ -21,6 +22,7 @@ #include "net/http/http_request_headers.h" #include "net/http/http_response_headers.h" #include "net/http/http_server_properties_impl.h" +#include "net/http/transport_security_state.h" #include "net/proxy/proxy_service.h" #include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_pool_histograms.h" @@ -78,26 +80,31 @@ class SSLClientSocketPoolTest public ::testing::WithParamInterface<NextProto> { protected: SSLClientSocketPoolTest() - : proxy_service_(ProxyService::CreateDirect()), + : transport_security_state_(new TransportSecurityState), + proxy_service_(ProxyService::CreateDirect()), ssl_config_service_(new SSLConfigServiceDefaults), http_auth_handler_factory_( HttpAuthHandlerFactory::CreateDefault(&host_resolver_)), session_(CreateNetworkSession()), direct_transport_socket_params_( - new TransportSocketParams(HostPortPair("host", 443), - false, - false, - OnHostResolutionCallback())), + new TransportSocketParams( + HostPortPair("host", 443), + false, + false, + OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), transport_histograms_("MockTCP"), transport_socket_pool_(kMaxSockets, kMaxSocketsPerGroup, &transport_histograms_, &socket_factory_), proxy_transport_socket_params_( - new TransportSocketParams(HostPortPair("proxy", 443), - false, - false, - OnHostResolutionCallback())), + new TransportSocketParams( + HostPortPair("proxy", 443), + false, + false, + OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), socks_socket_params_( new SOCKSSocketParams(proxy_transport_socket_params_, true, @@ -116,7 +123,8 @@ class SSLClientSocketPoolTest session_->http_auth_cache(), session_->http_auth_handler_factory(), session_->spdy_session_pool(), - true)), + true, + NULL)), http_proxy_histograms_("MockHttpProxy"), http_proxy_socket_pool_(kMaxSockets, kMaxSocketsPerGroup, @@ -124,7 +132,9 @@ class SSLClientSocketPoolTest &host_resolver_, &transport_socket_pool_, NULL, - NULL) { + NULL, + NULL), + enable_ssl_connect_job_waiting_(false) { scoped_refptr<SSLConfigService> ssl_config_service( new SSLConfigServiceDefaults); ssl_config_service->GetSSLConfig(&ssl_config_); @@ -138,7 +148,7 @@ class SSLClientSocketPoolTest ssl_histograms_.get(), NULL /* host_resolver */, NULL /* cert_verifier */, - NULL /* server_bound_cert_service */, + NULL /* channel_id_service */, NULL /* transport_security_state */, NULL /* cert_transparency_verifier */, std::string() /* ssl_session_cache_shard */, @@ -147,6 +157,7 @@ class SSLClientSocketPoolTest socks_pool ? &socks_socket_pool_ : NULL, http_proxy_pool ? &http_proxy_socket_pool_ : NULL, NULL, + enable_ssl_connect_job_waiting_, NULL)); } @@ -221,6 +232,8 @@ class SSLClientSocketPoolTest SSLConfig ssl_config_; scoped_ptr<ClientSocketPoolHistograms> ssl_histograms_; scoped_ptr<SSLClientSocketPool> pool_; + + bool enable_ssl_connect_job_waiting_; }; INSTANTIATE_TEST_CASE_P( @@ -229,6 +242,462 @@ INSTANTIATE_TEST_CASE_P( testing::Values(kProtoDeprecatedSPDY2, kProtoSPDY3, kProtoSPDY31, kProtoSPDY4)); +// Tests that the final socket will connect even if all sockets +// prior to it fail. +// +// All sockets should wait for the first socket to attempt to +// connect. Once it fails to connect, all other sockets should +// attempt to connect. All should fail, except the final socket. +TEST_P(SSLClientSocketPoolTest, AllSocketsFailButLast) { + // Although we request four sockets, the first three socket connect + // failures cause the socket pool to create three more sockets because + // there are pending requests. + StaticSocketDataProvider data1; + StaticSocketDataProvider data2; + StaticSocketDataProvider data3; + StaticSocketDataProvider data4; + StaticSocketDataProvider data5; + StaticSocketDataProvider data6; + StaticSocketDataProvider data7; + socket_factory_.AddSocketDataProvider(&data1); + socket_factory_.AddSocketDataProvider(&data2); + socket_factory_.AddSocketDataProvider(&data3); + socket_factory_.AddSocketDataProvider(&data4); + socket_factory_.AddSocketDataProvider(&data5); + socket_factory_.AddSocketDataProvider(&data6); + socket_factory_.AddSocketDataProvider(&data7); + SSLSocketDataProvider ssl(ASYNC, ERR_SSL_PROTOCOL_ERROR); + ssl.is_in_session_cache = false; + SSLSocketDataProvider ssl2(ASYNC, ERR_SSL_PROTOCOL_ERROR); + ssl2.is_in_session_cache = false; + SSLSocketDataProvider ssl3(ASYNC, ERR_SSL_PROTOCOL_ERROR); + ssl3.is_in_session_cache = false; + SSLSocketDataProvider ssl4(ASYNC, OK); + ssl4.is_in_session_cache = false; + SSLSocketDataProvider ssl5(ASYNC, OK); + ssl5.is_in_session_cache = false; + SSLSocketDataProvider ssl6(ASYNC, OK); + ssl6.is_in_session_cache = false; + SSLSocketDataProvider ssl7(ASYNC, OK); + ssl7.is_in_session_cache = false; + + socket_factory_.AddSSLSocketDataProvider(&ssl); + socket_factory_.AddSSLSocketDataProvider(&ssl2); + socket_factory_.AddSSLSocketDataProvider(&ssl3); + socket_factory_.AddSSLSocketDataProvider(&ssl4); + socket_factory_.AddSSLSocketDataProvider(&ssl5); + socket_factory_.AddSSLSocketDataProvider(&ssl6); + socket_factory_.AddSSLSocketDataProvider(&ssl7); + + enable_ssl_connect_job_waiting_ = true; + CreatePool(true, false, false); + + scoped_refptr<SSLSocketParams> params1 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + scoped_refptr<SSLSocketParams> params2 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + scoped_refptr<SSLSocketParams> params3 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + scoped_refptr<SSLSocketParams> params4 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + ClientSocketHandle handle1; + ClientSocketHandle handle2; + ClientSocketHandle handle3; + ClientSocketHandle handle4; + TestCompletionCallback callback1; + TestCompletionCallback callback2; + TestCompletionCallback callback3; + TestCompletionCallback callback4; + + handle1.Init( + "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog()); + handle2.Init( + "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog()); + handle3.Init( + "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog()); + handle4.Init( + "b", params4, MEDIUM, callback4.callback(), pool_.get(), BoundNetLog()); + + base::RunLoop().RunUntilIdle(); + + // Only the last socket should have connected. + EXPECT_FALSE(handle1.socket()); + EXPECT_FALSE(handle2.socket()); + EXPECT_FALSE(handle3.socket()); + EXPECT_TRUE(handle4.socket()->IsConnected()); +} + +// Tests that sockets will still connect in parallel if the +// EnableSSLConnectJobWaiting flag is not enabled. +TEST_P(SSLClientSocketPoolTest, SocketsConnectWithoutFlag) { + StaticSocketDataProvider data1; + StaticSocketDataProvider data2; + StaticSocketDataProvider data3; + socket_factory_.AddSocketDataProvider(&data1); + socket_factory_.AddSocketDataProvider(&data2); + socket_factory_.AddSocketDataProvider(&data3); + + SSLSocketDataProvider ssl(ASYNC, OK); + ssl.is_in_session_cache = false; + ssl.should_pause_on_connect = true; + SSLSocketDataProvider ssl2(ASYNC, OK); + ssl2.is_in_session_cache = false; + ssl2.should_pause_on_connect = true; + SSLSocketDataProvider ssl3(ASYNC, OK); + ssl3.is_in_session_cache = false; + ssl3.should_pause_on_connect = true; + socket_factory_.AddSSLSocketDataProvider(&ssl); + socket_factory_.AddSSLSocketDataProvider(&ssl2); + socket_factory_.AddSSLSocketDataProvider(&ssl3); + + CreatePool(true, false, false); + + scoped_refptr<SSLSocketParams> params1 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + scoped_refptr<SSLSocketParams> params2 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + scoped_refptr<SSLSocketParams> params3 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + ClientSocketHandle handle1; + ClientSocketHandle handle2; + ClientSocketHandle handle3; + TestCompletionCallback callback1; + TestCompletionCallback callback2; + TestCompletionCallback callback3; + + handle1.Init( + "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog()); + handle2.Init( + "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog()); + handle3.Init( + "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog()); + + base::RunLoop().RunUntilIdle(); + + std::vector<MockSSLClientSocket*> sockets = + socket_factory_.ssl_client_sockets(); + + // All sockets should have started their connections. + for (std::vector<MockSSLClientSocket*>::iterator it = sockets.begin(); + it != sockets.end(); + ++it) { + EXPECT_TRUE((*it)->reached_connect()); + } + + // Resume connecting all of the sockets. + for (std::vector<MockSSLClientSocket*>::iterator it = sockets.begin(); + it != sockets.end(); + ++it) { + (*it)->RestartPausedConnect(); + } + + callback1.WaitForResult(); + callback2.WaitForResult(); + callback3.WaitForResult(); + + EXPECT_TRUE(handle1.socket()->IsConnected()); + EXPECT_TRUE(handle2.socket()->IsConnected()); + EXPECT_TRUE(handle3.socket()->IsConnected()); +} + +// Tests that the pool deleting an SSLConnectJob will not cause a crash, +// or prevent pending sockets from connecting. +TEST_P(SSLClientSocketPoolTest, DeletedSSLConnectJob) { + StaticSocketDataProvider data1; + StaticSocketDataProvider data2; + StaticSocketDataProvider data3; + socket_factory_.AddSocketDataProvider(&data1); + socket_factory_.AddSocketDataProvider(&data2); + socket_factory_.AddSocketDataProvider(&data3); + + SSLSocketDataProvider ssl(ASYNC, OK); + ssl.is_in_session_cache = false; + ssl.should_pause_on_connect = true; + SSLSocketDataProvider ssl2(ASYNC, OK); + ssl2.is_in_session_cache = false; + SSLSocketDataProvider ssl3(ASYNC, OK); + ssl3.is_in_session_cache = false; + socket_factory_.AddSSLSocketDataProvider(&ssl); + socket_factory_.AddSSLSocketDataProvider(&ssl2); + socket_factory_.AddSSLSocketDataProvider(&ssl3); + + enable_ssl_connect_job_waiting_ = true; + CreatePool(true, false, false); + + scoped_refptr<SSLSocketParams> params1 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + scoped_refptr<SSLSocketParams> params2 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + scoped_refptr<SSLSocketParams> params3 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + ClientSocketHandle handle1; + ClientSocketHandle handle2; + ClientSocketHandle handle3; + TestCompletionCallback callback1; + TestCompletionCallback callback2; + TestCompletionCallback callback3; + + handle1.Init( + "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog()); + handle2.Init( + "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog()); + handle3.Init( + "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog()); + + // Allow the connections to proceed until the first socket has started + // connecting. + base::RunLoop().RunUntilIdle(); + + std::vector<MockSSLClientSocket*> sockets = + socket_factory_.ssl_client_sockets(); + + pool_->CancelRequest("b", &handle2); + + sockets[0]->RestartPausedConnect(); + + callback1.WaitForResult(); + callback3.WaitForResult(); + + EXPECT_TRUE(handle1.socket()->IsConnected()); + EXPECT_FALSE(handle2.socket()); + EXPECT_TRUE(handle3.socket()->IsConnected()); +} + +// Tests that all pending sockets still connect when the pool deletes a pending +// SSLConnectJob which immediately followed a failed leading connection. +TEST_P(SSLClientSocketPoolTest, DeletedSocketAfterFail) { + StaticSocketDataProvider data1; + StaticSocketDataProvider data2; + StaticSocketDataProvider data3; + StaticSocketDataProvider data4; + socket_factory_.AddSocketDataProvider(&data1); + socket_factory_.AddSocketDataProvider(&data2); + socket_factory_.AddSocketDataProvider(&data3); + socket_factory_.AddSocketDataProvider(&data4); + + SSLSocketDataProvider ssl(ASYNC, ERR_SSL_PROTOCOL_ERROR); + ssl.is_in_session_cache = false; + ssl.should_pause_on_connect = true; + SSLSocketDataProvider ssl2(ASYNC, OK); + ssl2.is_in_session_cache = false; + SSLSocketDataProvider ssl3(ASYNC, OK); + ssl3.is_in_session_cache = false; + SSLSocketDataProvider ssl4(ASYNC, OK); + ssl4.is_in_session_cache = false; + socket_factory_.AddSSLSocketDataProvider(&ssl); + socket_factory_.AddSSLSocketDataProvider(&ssl2); + socket_factory_.AddSSLSocketDataProvider(&ssl3); + socket_factory_.AddSSLSocketDataProvider(&ssl4); + + enable_ssl_connect_job_waiting_ = true; + CreatePool(true, false, false); + + scoped_refptr<SSLSocketParams> params1 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + scoped_refptr<SSLSocketParams> params2 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + scoped_refptr<SSLSocketParams> params3 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + ClientSocketHandle handle1; + ClientSocketHandle handle2; + ClientSocketHandle handle3; + TestCompletionCallback callback1; + TestCompletionCallback callback2; + TestCompletionCallback callback3; + + handle1.Init( + "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog()); + handle2.Init( + "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog()); + handle3.Init( + "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog()); + + // Allow the connections to proceed until the first socket has started + // connecting. + base::RunLoop().RunUntilIdle(); + + std::vector<MockSSLClientSocket*> sockets = + socket_factory_.ssl_client_sockets(); + + EXPECT_EQ(3u, sockets.size()); + EXPECT_TRUE(sockets[0]->reached_connect()); + EXPECT_FALSE(handle1.socket()); + + pool_->CancelRequest("b", &handle2); + + sockets[0]->RestartPausedConnect(); + + callback1.WaitForResult(); + callback3.WaitForResult(); + + EXPECT_FALSE(handle1.socket()); + EXPECT_FALSE(handle2.socket()); + EXPECT_TRUE(handle3.socket()->IsConnected()); +} + +// Make sure that sockets still connect after the leader socket's +// connection fails. +TEST_P(SSLClientSocketPoolTest, SimultaneousConnectJobsFail) { + StaticSocketDataProvider data1; + StaticSocketDataProvider data2; + StaticSocketDataProvider data3; + StaticSocketDataProvider data4; + StaticSocketDataProvider data5; + socket_factory_.AddSocketDataProvider(&data1); + socket_factory_.AddSocketDataProvider(&data2); + socket_factory_.AddSocketDataProvider(&data3); + socket_factory_.AddSocketDataProvider(&data4); + socket_factory_.AddSocketDataProvider(&data5); + SSLSocketDataProvider ssl(ASYNC, ERR_SSL_PROTOCOL_ERROR); + ssl.is_in_session_cache = false; + ssl.should_pause_on_connect = true; + SSLSocketDataProvider ssl2(ASYNC, OK); + ssl2.is_in_session_cache = false; + ssl2.should_pause_on_connect = true; + SSLSocketDataProvider ssl3(ASYNC, OK); + ssl3.is_in_session_cache = false; + SSLSocketDataProvider ssl4(ASYNC, OK); + ssl4.is_in_session_cache = false; + SSLSocketDataProvider ssl5(ASYNC, OK); + ssl5.is_in_session_cache = false; + + socket_factory_.AddSSLSocketDataProvider(&ssl); + socket_factory_.AddSSLSocketDataProvider(&ssl2); + socket_factory_.AddSSLSocketDataProvider(&ssl3); + socket_factory_.AddSSLSocketDataProvider(&ssl4); + socket_factory_.AddSSLSocketDataProvider(&ssl5); + + enable_ssl_connect_job_waiting_ = true; + CreatePool(true, false, false); + scoped_refptr<SSLSocketParams> params1 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + scoped_refptr<SSLSocketParams> params2 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + scoped_refptr<SSLSocketParams> params3 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + scoped_refptr<SSLSocketParams> params4 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + ClientSocketHandle handle1; + ClientSocketHandle handle2; + ClientSocketHandle handle3; + ClientSocketHandle handle4; + TestCompletionCallback callback1; + TestCompletionCallback callback2; + TestCompletionCallback callback3; + TestCompletionCallback callback4; + handle1.Init( + "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog()); + handle2.Init( + "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog()); + handle3.Init( + "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog()); + handle4.Init( + "b", params4, MEDIUM, callback4.callback(), pool_.get(), BoundNetLog()); + + base::RunLoop().RunUntilIdle(); + + std::vector<MockSSLClientSocket*> sockets = + socket_factory_.ssl_client_sockets(); + + std::vector<MockSSLClientSocket*>::const_iterator it = sockets.begin(); + + // The first socket should have had Connect called on it. + EXPECT_TRUE((*it)->reached_connect()); + ++it; + + // No other socket should have reached connect yet. + for (; it != sockets.end(); ++it) + EXPECT_FALSE((*it)->reached_connect()); + + // Allow the first socket to resume it's connection process. + sockets[0]->RestartPausedConnect(); + + base::RunLoop().RunUntilIdle(); + + // The second socket should have reached connect. + EXPECT_TRUE(sockets[1]->reached_connect()); + + // Allow the second socket to continue its connection. + sockets[1]->RestartPausedConnect(); + + base::RunLoop().RunUntilIdle(); + + EXPECT_FALSE(handle1.socket()); + EXPECT_TRUE(handle2.socket()->IsConnected()); + EXPECT_TRUE(handle3.socket()->IsConnected()); + EXPECT_TRUE(handle4.socket()->IsConnected()); +} + +// Make sure that no sockets connect before the "leader" socket, +// given that the leader has a successful connection. +TEST_P(SSLClientSocketPoolTest, SimultaneousConnectJobsSuccess) { + StaticSocketDataProvider data1; + StaticSocketDataProvider data2; + StaticSocketDataProvider data3; + socket_factory_.AddSocketDataProvider(&data1); + socket_factory_.AddSocketDataProvider(&data2); + socket_factory_.AddSocketDataProvider(&data3); + + SSLSocketDataProvider ssl(ASYNC, OK); + ssl.is_in_session_cache = false; + ssl.should_pause_on_connect = true; + SSLSocketDataProvider ssl2(ASYNC, OK); + ssl2.is_in_session_cache = false; + SSLSocketDataProvider ssl3(ASYNC, OK); + ssl3.is_in_session_cache = false; + socket_factory_.AddSSLSocketDataProvider(&ssl); + socket_factory_.AddSSLSocketDataProvider(&ssl2); + socket_factory_.AddSSLSocketDataProvider(&ssl3); + + enable_ssl_connect_job_waiting_ = true; + CreatePool(true, false, false); + + scoped_refptr<SSLSocketParams> params1 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + scoped_refptr<SSLSocketParams> params2 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + scoped_refptr<SSLSocketParams> params3 = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + ClientSocketHandle handle1; + ClientSocketHandle handle2; + ClientSocketHandle handle3; + TestCompletionCallback callback1; + TestCompletionCallback callback2; + TestCompletionCallback callback3; + + handle1.Init( + "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog()); + handle2.Init( + "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog()); + handle3.Init( + "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog()); + + // Allow the connections to proceed until the first socket has finished + // connecting. + base::RunLoop().RunUntilIdle(); + + std::vector<MockSSLClientSocket*> sockets = + socket_factory_.ssl_client_sockets(); + + std::vector<MockSSLClientSocket*>::const_iterator it = sockets.begin(); + // The first socket should have reached connect. + EXPECT_TRUE((*it)->reached_connect()); + ++it; + // No other socket should have reached connect yet. + for (; it != sockets.end(); ++it) + EXPECT_FALSE((*it)->reached_connect()); + + sockets[0]->RestartPausedConnect(); + + callback1.WaitForResult(); + callback2.WaitForResult(); + callback3.WaitForResult(); + + EXPECT_TRUE(handle1.socket()->IsConnected()); + EXPECT_TRUE(handle2.socket()->IsConnected()); + EXPECT_TRUE(handle3.socket()->IsConnected()); +} + TEST_P(SSLClientSocketPoolTest, TCPFail) { StaticSocketDataProvider data; data.set_connect_data(MockConnect(SYNCHRONOUS, ERR_CONNECTION_FAILED)); @@ -466,8 +935,7 @@ TEST_P(SSLClientSocketPoolTest, DirectGotSPDY) { SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket()); EXPECT_TRUE(ssl_socket->WasNpnNegotiated()); std::string proto; - std::string server_protos; - ssl_socket->GetNextProto(&proto, &server_protos); + ssl_socket->GetNextProto(&proto); EXPECT_EQ(GetParam(), SSLClientSocket::NextProtoFromString(proto)); } @@ -498,8 +966,7 @@ TEST_P(SSLClientSocketPoolTest, DirectGotBonusSPDY) { SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(handle.socket()); EXPECT_TRUE(ssl_socket->WasNpnNegotiated()); std::string proto; - std::string server_protos; - ssl_socket->GetNextProto(&proto, &server_protos); + ssl_socket->GetNextProto(&proto); EXPECT_EQ(GetParam(), SSLClientSocket::NextProtoFromString(proto)); } @@ -798,8 +1265,7 @@ TEST_P(SSLClientSocketPoolTest, NeedProxyAuth) { EXPECT_FALSE(tunnel_handle->socket()->IsConnected()); } -// TODO(rch): re-enable this. -TEST_P(SSLClientSocketPoolTest, DISABLED_IPPooling) { +TEST_P(SSLClientSocketPoolTest, IPPooling) { const int kTestPort = 80; struct TestHosts { std::string name; @@ -813,7 +1279,7 @@ TEST_P(SSLClientSocketPoolTest, DISABLED_IPPooling) { }; host_resolver_.set_synchronous_mode(true); - for (size_t i = 0; i < ARRAYSIZE_UNSAFE(test_hosts); i++) { + for (size_t i = 0; i < arraysize(test_hosts); i++) { host_resolver_.rules()->AddIPLiteralRule( test_hosts[i].name, test_hosts[i].iplist, std::string()); @@ -873,7 +1339,7 @@ void SSLClientSocketPoolTest::TestIPPoolingDisabled( TestCompletionCallback callback; int rv; - for (size_t i = 0; i < ARRAYSIZE_UNSAFE(test_hosts); i++) { + for (size_t i = 0; i < arraysize(test_hosts); i++) { host_resolver_.rules()->AddIPLiteralRule( test_hosts[i].name, test_hosts[i].iplist, std::string()); diff --git a/chromium/net/socket/ssl_client_socket_unittest.cc b/chromium/net/socket/ssl_client_socket_unittest.cc index 51c67565dad..16e03f7eb8b 100644 --- a/chromium/net/socket/ssl_client_socket_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_unittest.cc @@ -15,6 +15,8 @@ #include "net/base/net_log_unittest.h" #include "net/base/test_completion_callback.h" #include "net/base/test_data_directory.h" +#include "net/cert/asn1_util.h" +#include "net/cert/ct_verifier.h" #include "net/cert/mock_cert_verifier.h" #include "net/cert/test_root_certs.h" #include "net/dns/host_resolver.h" @@ -23,16 +25,22 @@ #include "net/socket/client_socket_handle.h" #include "net/socket/socket_test_util.h" #include "net/socket/tcp_client_socket.h" -#include "net/ssl/default_server_bound_cert_store.h" +#include "net/ssl/channel_id_service.h" +#include "net/ssl/default_channel_id_store.h" #include "net/ssl/ssl_cert_request_info.h" #include "net/ssl/ssl_config_service.h" #include "net/test/cert_test_util.h" #include "net/test/spawned_test_server/spawned_test_server.h" +#include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" #include "testing/platform_test.h" //----------------------------------------------------------------------------- +using testing::_; +using testing::Return; +using testing::Truly; + namespace net { namespace { @@ -49,65 +57,57 @@ class WrappedStreamSocket : public StreamSocket { public: explicit WrappedStreamSocket(scoped_ptr<StreamSocket> transport) : transport_(transport.Pass()) {} - virtual ~WrappedStreamSocket() {} + ~WrappedStreamSocket() override {} // StreamSocket implementation: - virtual int Connect(const CompletionCallback& callback) OVERRIDE { + int Connect(const CompletionCallback& callback) override { return transport_->Connect(callback); } - virtual void Disconnect() OVERRIDE { transport_->Disconnect(); } - virtual bool IsConnected() const OVERRIDE { - return transport_->IsConnected(); - } - virtual bool IsConnectedAndIdle() const OVERRIDE { + void Disconnect() override { transport_->Disconnect(); } + bool IsConnected() const override { return transport_->IsConnected(); } + bool IsConnectedAndIdle() const override { return transport_->IsConnectedAndIdle(); } - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { + int GetPeerAddress(IPEndPoint* address) const override { return transport_->GetPeerAddress(address); } - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { + int GetLocalAddress(IPEndPoint* address) const override { return transport_->GetLocalAddress(address); } - virtual const BoundNetLog& NetLog() const OVERRIDE { - return transport_->NetLog(); - } - virtual void SetSubresourceSpeculation() OVERRIDE { + const BoundNetLog& NetLog() const override { return transport_->NetLog(); } + void SetSubresourceSpeculation() override { transport_->SetSubresourceSpeculation(); } - virtual void SetOmniboxSpeculation() OVERRIDE { - transport_->SetOmniboxSpeculation(); - } - virtual bool WasEverUsed() const OVERRIDE { - return transport_->WasEverUsed(); - } - virtual bool UsingTCPFastOpen() const OVERRIDE { + void SetOmniboxSpeculation() override { transport_->SetOmniboxSpeculation(); } + bool WasEverUsed() const override { return transport_->WasEverUsed(); } + bool UsingTCPFastOpen() const override { return transport_->UsingTCPFastOpen(); } - virtual bool WasNpnNegotiated() const OVERRIDE { + bool WasNpnNegotiated() const override { return transport_->WasNpnNegotiated(); } - virtual NextProto GetNegotiatedProtocol() const OVERRIDE { + NextProto GetNegotiatedProtocol() const override { return transport_->GetNegotiatedProtocol(); } - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { + bool GetSSLInfo(SSLInfo* ssl_info) override { return transport_->GetSSLInfo(ssl_info); } // Socket implementation: - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE { + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override { return transport_->Read(buf, buf_len, callback); } - virtual int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE { + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override { return transport_->Write(buf, buf_len, callback); } - virtual int SetReceiveBufferSize(int32 size) OVERRIDE { + int SetReceiveBufferSize(int32 size) override { return transport_->SetReceiveBufferSize(size); } - virtual int SetSendBufferSize(int32 size) OVERRIDE { + int SetSendBufferSize(int32 size) override { return transport_->SetSendBufferSize(size); } @@ -124,12 +124,12 @@ class WrappedStreamSocket : public StreamSocket { class ReadBufferingStreamSocket : public WrappedStreamSocket { public: explicit ReadBufferingStreamSocket(scoped_ptr<StreamSocket> transport); - virtual ~ReadBufferingStreamSocket() {} + ~ReadBufferingStreamSocket() override {} // Socket implementation: - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; // Sets the internal buffer to |size|. This must not be greater than // the largest value supplied to Read() - that is, it does not handle @@ -254,21 +254,21 @@ void ReadBufferingStreamSocket::OnReadCompleted(int result) { class SynchronousErrorStreamSocket : public WrappedStreamSocket { public: explicit SynchronousErrorStreamSocket(scoped_ptr<StreamSocket> transport); - virtual ~SynchronousErrorStreamSocket() {} + ~SynchronousErrorStreamSocket() override {} // Socket implementation: - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; // Sets the next Read() call and all future calls to return |error|. // If there is already a pending asynchronous read, the configured error // will not be returned until that asynchronous read has completed and Read() // is called again. - void SetNextReadError(Error error) { + void SetNextReadError(int error) { DCHECK_GE(0, error); have_read_error_ = true; pending_read_error_ = error; @@ -278,7 +278,7 @@ class SynchronousErrorStreamSocket : public WrappedStreamSocket { // If there is already a pending asynchronous write, the configured error // will not be returned until that asynchronous write has completed and // Write() is called again. - void SetNextWriteError(Error error) { + void SetNextWriteError(int error) { DCHECK_GE(0, error); have_write_error_ = true; pending_write_error_ = error; @@ -325,15 +325,15 @@ int SynchronousErrorStreamSocket::Write(IOBuffer* buf, class FakeBlockingStreamSocket : public WrappedStreamSocket { public: explicit FakeBlockingStreamSocket(scoped_ptr<StreamSocket> transport); - virtual ~FakeBlockingStreamSocket() {} + ~FakeBlockingStreamSocket() override {} // Socket implementation: - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; // Blocks read results on the socket. Reads will not complete until // UnblockReadResult() has been called and a result is ready from the @@ -357,6 +357,9 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket { // Waits for the blocked Write() call to be scheduled. void WaitForWrite(); + // Returns the wrapped stream socket. + StreamSocket* transport() { return transport_.get(); } + private: // Handles completion from the underlying transport read. void OnReadCompleted(int result); @@ -429,7 +432,7 @@ int FakeBlockingStreamSocket::Write(IOBuffer* buf, return transport_->Write(buf, len, callback); // Schedule the write, but do nothing. - DCHECK(!pending_write_buf_); + DCHECK(!pending_write_buf_.get()); DCHECK_EQ(-1, pending_write_len_); DCHECK(pending_write_callback_.is_null()); DCHECK(!callback.is_null()); @@ -485,11 +488,11 @@ void FakeBlockingStreamSocket::UnblockWrite() { // Do nothing if UnblockWrite() was called after BlockWrite(), // without a Write() in between. - if (!pending_write_buf_) + if (!pending_write_buf_.get()) return; - int rv = transport_->Write(pending_write_buf_, pending_write_len_, - pending_write_callback_); + int rv = transport_->Write( + pending_write_buf_.get(), pending_write_len_, pending_write_callback_); pending_write_buf_ = NULL; pending_write_len_ = -1; if (rv == ERR_IO_PENDING) { @@ -503,12 +506,12 @@ void FakeBlockingStreamSocket::WaitForWrite() { DCHECK(should_block_write_); DCHECK(!write_loop_); - if (pending_write_buf_) + if (pending_write_buf_.get()) return; write_loop_.reset(new base::RunLoop); write_loop_->Run(); write_loop_.reset(); - DCHECK(pending_write_buf_); + DCHECK(pending_write_buf_.get()); } void FakeBlockingStreamSocket::OnReadCompleted(int result) { @@ -538,18 +541,18 @@ class CountingStreamSocket : public WrappedStreamSocket { : WrappedStreamSocket(transport.Pass()), read_count_(0), write_count_(0) {} - virtual ~CountingStreamSocket() {} + ~CountingStreamSocket() override {} // Socket implementation: - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE { + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override { read_count_++; return transport_->Read(buf, buf_len, callback); } - virtual int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE { + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override { write_count_++; return transport_->Write(buf, buf_len, callback); } @@ -570,7 +573,7 @@ class DeleteSocketCallback : public TestCompletionCallbackBase { : socket_(socket), callback_(base::Bind(&DeleteSocketCallback::OnComplete, base::Unretained(this))) {} - virtual ~DeleteSocketCallback() {} + ~DeleteSocketCallback() override {} const CompletionCallback& callback() const { return callback_; } @@ -591,33 +594,72 @@ class DeleteSocketCallback : public TestCompletionCallbackBase { DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback); }; -// A ServerBoundCertStore that always returns an error when asked for a -// certificate. -class FailingServerBoundCertStore : public ServerBoundCertStore { - virtual int GetServerBoundCert(const std::string& server_identifier, - base::Time* expiration_time, - std::string* private_key_result, - std::string* cert_result, - const GetCertCallback& callback) OVERRIDE { +// A ChannelIDStore that always returns an error when asked for a +// channel id. +class FailingChannelIDStore : public ChannelIDStore { + int GetChannelID(const std::string& server_identifier, + base::Time* expiration_time, + std::string* private_key_result, + std::string* cert_result, + const GetChannelIDCallback& callback) override { return ERR_UNEXPECTED; } - virtual void SetServerBoundCert(const std::string& server_identifier, - base::Time creation_time, - base::Time expiration_time, - const std::string& private_key, - const std::string& cert) OVERRIDE {} - virtual void DeleteServerBoundCert(const std::string& server_identifier, - const base::Closure& completion_callback) - OVERRIDE {} - virtual void DeleteAllCreatedBetween(base::Time delete_begin, - base::Time delete_end, - const base::Closure& completion_callback) - OVERRIDE {} - virtual void DeleteAll(const base::Closure& completion_callback) OVERRIDE {} - virtual void GetAllServerBoundCerts(const GetCertListCallback& callback) - OVERRIDE {} - virtual int GetCertCount() OVERRIDE { return 0; } - virtual void SetForceKeepSessionState() OVERRIDE {} + void SetChannelID(const std::string& server_identifier, + base::Time creation_time, + base::Time expiration_time, + const std::string& private_key, + const std::string& cert) override {} + void DeleteChannelID(const std::string& server_identifier, + const base::Closure& completion_callback) override {} + void DeleteAllCreatedBetween( + base::Time delete_begin, + base::Time delete_end, + const base::Closure& completion_callback) override {} + void DeleteAll(const base::Closure& completion_callback) override {} + void GetAllChannelIDs(const GetChannelIDListCallback& callback) override {} + int GetChannelIDCount() override { return 0; } + void SetForceKeepSessionState() override {} +}; + +// A ChannelIDStore that asynchronously returns an error when asked for a +// channel id. +class AsyncFailingChannelIDStore : public ChannelIDStore { + int GetChannelID(const std::string& server_identifier, + base::Time* expiration_time, + std::string* private_key_result, + std::string* cert_result, + const GetChannelIDCallback& callback) override { + base::MessageLoop::current()->PostTask( + FROM_HERE, base::Bind(callback, ERR_UNEXPECTED, + server_identifier, base::Time(), "", "")); + return ERR_IO_PENDING; + } + void SetChannelID(const std::string& server_identifier, + base::Time creation_time, + base::Time expiration_time, + const std::string& private_key, + const std::string& cert) override {} + void DeleteChannelID(const std::string& server_identifier, + const base::Closure& completion_callback) override {} + void DeleteAllCreatedBetween( + base::Time delete_begin, + base::Time delete_end, + const base::Closure& completion_callback) override {} + void DeleteAll(const base::Closure& completion_callback) override {} + void GetAllChannelIDs(const GetChannelIDListCallback& callback) override {} + int GetChannelIDCount() override { return 0; } + void SetForceKeepSessionState() override {} +}; + +// A mock CTVerifier that records every call to Verify but doesn't verify +// anything. +class MockCTVerifier : public CTVerifier { + public: + MOCK_METHOD5(Verify, int(X509Certificate*, + const std::string&, + const std::string&, + ct::CTVerifyResult*, + const BoundNetLog&)); }; class SSLClientSocketTest : public PlatformTest { @@ -625,12 +667,15 @@ class SSLClientSocketTest : public PlatformTest { SSLClientSocketTest() : socket_factory_(ClientSocketFactory::GetDefaultFactory()), cert_verifier_(new MockCertVerifier), - transport_security_state_(new TransportSecurityState) { + transport_security_state_(new TransportSecurityState), + ran_handshake_completion_callback_(false) { cert_verifier_->set_default_result(OK); context_.cert_verifier = cert_verifier_.get(); context_.transport_security_state = transport_security_state_.get(); } + void RecordCompletedHandshake() { ran_handshake_completion_callback_ = true; } + protected: // The address of the spawned test server, after calling StartTestServer(). const AddressList& addr() const { return addr_; } @@ -638,6 +683,10 @@ class SSLClientSocketTest : public PlatformTest { // The SpawnedTestServer object, after calling StartTestServer(). const SpawnedTestServer* test_server() const { return test_server_.get(); } + void SetCTVerifier(CTVerifier* ct_verifier) { + context_.cert_transparency_verifier = ct_verifier; + } + // Starts the test server with SSL configuration |ssl_options|. Returns true // on success. bool StartTestServer(const SpawnedTestServer::SSLOptions& ssl_options) { @@ -707,6 +756,7 @@ class SSLClientSocketTest : public PlatformTest { SSLClientSocketContext context_; scoped_ptr<SSLClientSocket> sock_; CapturingNetLog log_; + bool ran_handshake_completion_callback_; private: scoped_ptr<StreamSocket> transport_; @@ -759,6 +809,11 @@ class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest { }; class SSLClientSocketFalseStartTest : public SSLClientSocketTest { + public: + SSLClientSocketFalseStartTest() + : monitor_handshake_callback_(false), + fail_handshake_after_false_start_(false) {} + protected: // Creates an SSLClientSocket with |client_config| attached to a // FakeBlockingStreamSocket, returning both in |*out_raw_transport| and @@ -780,18 +835,25 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest { scoped_ptr<SSLClientSocket>* out_sock) { CHECK(test_server()); - scoped_ptr<StreamSocket> real_transport( - new TCPClientSocket(addr(), NULL, NetLog::Source())); + scoped_ptr<StreamSocket> real_transport(scoped_ptr<StreamSocket>( + new TCPClientSocket(addr(), NULL, NetLog::Source()))); + real_transport.reset( + new SynchronousErrorStreamSocket(real_transport.Pass())); + scoped_ptr<FakeBlockingStreamSocket> transport( new FakeBlockingStreamSocket(real_transport.Pass())); int rv = callback->GetResult(transport->Connect(callback->callback())); EXPECT_EQ(OK, rv); FakeBlockingStreamSocket* raw_transport = transport.get(); - scoped_ptr<SSLClientSocket> sock = - CreateSSLClientSocket(transport.PassAs<StreamSocket>(), - test_server()->host_port_pair(), - client_config); + scoped_ptr<SSLClientSocket> sock = CreateSSLClientSocket( + transport.Pass(), test_server()->host_port_pair(), client_config); + + if (monitor_handshake_callback_) { + sock->SetHandshakeCompletionCallback( + base::Bind(&SSLClientSocketTest::RecordCompletedHandshake, + base::Unretained(this))); + } // Connect. Stop before the client processes the first server leg // (ServerHello, etc.) @@ -808,6 +870,12 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest { raw_transport->UnblockReadResult(); raw_transport->WaitForWrite(); + if (fail_handshake_after_false_start_) { + SynchronousErrorStreamSocket* error_socket = + static_cast<SynchronousErrorStreamSocket*>( + raw_transport->transport()); + error_socket->SetNextReadError(ERR_CONNECTION_RESET); + } // And, finally, release that and block the next server leg // (ChangeCipherSpec, Finished). raw_transport->BlockReadResult(); @@ -825,6 +893,7 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest { TestCompletionCallback callback; FakeBlockingStreamSocket* raw_transport = NULL; scoped_ptr<SSLClientSocket> sock; + ASSERT_NO_FATAL_FAILURE(CreateAndConnectUntilServerFinishedReceived( client_config, &callback, &raw_transport, &sock)); @@ -859,7 +928,10 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest { // After releasing reads, the connection proceeds. raw_transport->UnblockReadResult(); rv = callback.GetResult(rv); - EXPECT_LT(0, rv); + if (fail_handshake_after_false_start_) + EXPECT_EQ(ERR_CONNECTION_RESET, rv); + else + EXPECT_LT(0, rv); } else { // False Start is not enabled, so the handshake will not complete because // the server second leg is blocked. @@ -867,25 +939,39 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest { EXPECT_FALSE(callback.have_result()); } } + + // Indicates that the socket's handshake completion callback should + // be monitored. + bool monitor_handshake_callback_; + // Indicates that this test's handshake should fail after the client + // "finished" message is sent. + bool fail_handshake_after_false_start_; }; class SSLClientSocketChannelIDTest : public SSLClientSocketTest { protected: void EnableChannelID() { - cert_service_.reset( - new ServerBoundCertService(new DefaultServerBoundCertStore(NULL), - base::MessageLoopProxy::current())); - context_.server_bound_cert_service = cert_service_.get(); + channel_id_service_.reset( + new ChannelIDService(new DefaultChannelIDStore(NULL), + base::MessageLoopProxy::current())); + context_.channel_id_service = channel_id_service_.get(); } void EnableFailingChannelID() { - cert_service_.reset(new ServerBoundCertService( - new FailingServerBoundCertStore(), base::MessageLoopProxy::current())); - context_.server_bound_cert_service = cert_service_.get(); + channel_id_service_.reset(new ChannelIDService( + new FailingChannelIDStore(), base::MessageLoopProxy::current())); + context_.channel_id_service = channel_id_service_.get(); + } + + void EnableAsyncFailingChannelID() { + channel_id_service_.reset(new ChannelIDService( + new AsyncFailingChannelIDStore(), + base::MessageLoopProxy::current())); + context_.channel_id_service = channel_id_service_.get(); } private: - scoped_ptr<ServerBoundCertService> cert_service_; + scoped_ptr<ChannelIDService> channel_id_service_; }; //----------------------------------------------------------------------------- @@ -1207,6 +1293,41 @@ TEST_F(SSLClientSocketTest, Read) { } } +// Tests that SSLClientSocket properly handles when the underlying transport +// synchronously fails a transport read in during the handshake. The error code +// should be preserved so SSLv3 fallback logic can condition on it. +TEST_F(SSLClientSocketTest, Connect_WithSynchronousError) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + scoped_ptr<SynchronousErrorStreamSocket> transport( + new SynchronousErrorStreamSocket(real_transport.Pass())); + int rv = callback.GetResult(transport->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + + // Disable TLS False Start to avoid handshake non-determinism. + SSLConfig ssl_config; + ssl_config.false_start_enabled = false; + + SynchronousErrorStreamSocket* raw_transport = transport.get(); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); + + raw_transport->SetNextWriteError(ERR_CONNECTION_RESET); + + rv = callback.GetResult(sock->Connect(callback.callback())); + EXPECT_EQ(ERR_CONNECTION_RESET, rv); + EXPECT_FALSE(sock->IsConnected()); +} + // Tests that the SSLClientSocket properly handles when the underlying transport // synchronously returns an error code - such as if an intermediary terminates // the socket connection uncleanly. @@ -1233,10 +1354,8 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) { ssl_config.false_start_enabled = false; SynchronousErrorStreamSocket* raw_transport = transport.get(); - scoped_ptr<SSLClientSocket> sock( - CreateSSLClientSocket(transport.PassAs<StreamSocket>(), - test_server.host_port_pair(), - ssl_config)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); rv = callback.GetResult(sock->Connect(callback.callback())); EXPECT_EQ(OK, rv); @@ -1261,14 +1380,7 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) { // rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING is a legitimate // result when using a dedicated task runner for NSS. rv = callback.GetResult(sock->Read(buf.get(), 4096, callback.callback())); - -#if !defined(USE_OPENSSL) - // SSLClientSocketNSS records the error exactly EXPECT_EQ(ERR_CONNECTION_RESET, rv); -#else - // SSLClientSocketOpenSSL treats any errors as a simple EOF. - EXPECT_EQ(0, rv); -#endif } // Tests that the SSLClientSocket properly handles when the underlying transport @@ -1293,7 +1405,7 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { new SynchronousErrorStreamSocket(real_transport.Pass())); SynchronousErrorStreamSocket* raw_error_socket = error_socket.get(); scoped_ptr<FakeBlockingStreamSocket> transport( - new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>())); + new FakeBlockingStreamSocket(error_socket.Pass())); FakeBlockingStreamSocket* raw_transport = transport.get(); int rv = callback.GetResult(transport->Connect(callback.callback())); EXPECT_EQ(OK, rv); @@ -1302,10 +1414,8 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { SSLConfig ssl_config; ssl_config.false_start_enabled = false; - scoped_ptr<SSLClientSocket> sock( - CreateSSLClientSocket(transport.PassAs<StreamSocket>(), - test_server.host_port_pair(), - ssl_config)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); rv = callback.GetResult(sock->Connect(callback.callback())); EXPECT_EQ(OK, rv); @@ -1342,14 +1452,7 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { // checking that rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING // is a legitimate result when using a dedicated task runner for NSS. rv = callback.GetResult(rv); - -#if !defined(USE_OPENSSL) - // SSLClientSocketNSS records the error exactly EXPECT_EQ(ERR_CONNECTION_RESET, rv); -#else - // SSLClientSocketOpenSSL treats any errors as a simple EOF. - EXPECT_EQ(0, rv); -#endif } // If there is a Write failure at the transport with no follow-up Read, although @@ -1374,7 +1477,7 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousErrorNoRead) { new SynchronousErrorStreamSocket(real_transport.Pass())); SynchronousErrorStreamSocket* raw_error_socket = error_socket.get(); scoped_ptr<CountingStreamSocket> counting_socket( - new CountingStreamSocket(error_socket.PassAs<StreamSocket>())); + new CountingStreamSocket(error_socket.Pass())); CountingStreamSocket* raw_counting_socket = counting_socket.get(); int rv = callback.GetResult(counting_socket->Connect(callback.callback())); ASSERT_EQ(OK, rv); @@ -1383,10 +1486,8 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousErrorNoRead) { SSLConfig ssl_config; ssl_config.false_start_enabled = false; - scoped_ptr<SSLClientSocket> sock( - CreateSSLClientSocket(counting_socket.PassAs<StreamSocket>(), - test_server.host_port_pair(), - ssl_config)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + counting_socket.Pass(), test_server.host_port_pair(), ssl_config)); rv = callback.GetResult(sock->Connect(callback.callback())); ASSERT_EQ(OK, rv); @@ -1504,7 +1605,7 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { new SynchronousErrorStreamSocket(real_transport.Pass())); SynchronousErrorStreamSocket* raw_error_socket = error_socket.get(); scoped_ptr<FakeBlockingStreamSocket> transport( - new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>())); + new FakeBlockingStreamSocket(error_socket.Pass())); FakeBlockingStreamSocket* raw_transport = transport.get(); int rv = callback.GetResult(transport->Connect(callback.callback())); @@ -1514,10 +1615,8 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { SSLConfig ssl_config; ssl_config.false_start_enabled = false; - scoped_ptr<SSLClientSocket> sock = - CreateSSLClientSocket(transport.PassAs<StreamSocket>(), - test_server.host_port_pair(), - ssl_config); + scoped_ptr<SSLClientSocket> sock = CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config); rv = callback.GetResult(sock->Connect(callback.callback())); EXPECT_EQ(OK, rv); @@ -1592,14 +1691,7 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { raw_transport->UnblockWrite(); rv = read_callback.WaitForResult(); - -#if !defined(USE_OPENSSL) - // NSS records the error exactly. EXPECT_EQ(ERR_CONNECTION_RESET, rv); -#else - // OpenSSL treats any errors as a simple EOF. - EXPECT_EQ(0, rv); -#endif // The Write callback should not have been called. EXPECT_FALSE(callback.have_result()); @@ -1627,7 +1719,7 @@ TEST_F(SSLClientSocketTest, Read_WithWriteError) { new SynchronousErrorStreamSocket(real_transport.Pass())); SynchronousErrorStreamSocket* raw_error_socket = error_socket.get(); scoped_ptr<FakeBlockingStreamSocket> transport( - new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>())); + new FakeBlockingStreamSocket(error_socket.Pass())); FakeBlockingStreamSocket* raw_transport = transport.get(); int rv = callback.GetResult(transport->Connect(callback.callback())); @@ -1637,10 +1729,8 @@ TEST_F(SSLClientSocketTest, Read_WithWriteError) { SSLConfig ssl_config; ssl_config.false_start_enabled = false; - scoped_ptr<SSLClientSocket> sock( - CreateSSLClientSocket(transport.PassAs<StreamSocket>(), - test_server.host_port_pair(), - ssl_config)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); rv = callback.GetResult(sock->Connect(callback.callback())); EXPECT_EQ(OK, rv); @@ -1694,21 +1784,141 @@ TEST_F(SSLClientSocketTest, Read_WithWriteError) { } } while (rv > 0); -#if !defined(USE_OPENSSL) - // NSS records the error exactly. EXPECT_EQ(ERR_CONNECTION_RESET, rv); -#else - // OpenSSL treats the reset as a generic protocol error. - EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, rv); -#endif - // Release the read. Some bytes should go through. + // Release the read. raw_transport->UnblockReadResult(); rv = read_callback.WaitForResult(); - // Per the fix for http://crbug.com/249848, write failures currently break - // reads. Change this assertion if they're changed to not collide. +#if defined(USE_OPENSSL) + // Should still read bytes despite the write error. + EXPECT_LT(0, rv); +#else + // NSS attempts to flush the write buffer in PR_Read on an SSL socket before + // pumping the read state machine, unless configured with SSL_ENABLE_FDX, so + // the write error stops future reads. EXPECT_EQ(ERR_CONNECTION_RESET, rv); +#endif +} + +// Tests that SSLClientSocket fails the handshake if the underlying +// transport is cleanly closed. +TEST_F(SSLClientSocketTest, Connect_WithZeroReturn) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + scoped_ptr<SynchronousErrorStreamSocket> transport( + new SynchronousErrorStreamSocket(real_transport.Pass())); + int rv = callback.GetResult(transport->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + + SynchronousErrorStreamSocket* raw_transport = transport.get(); + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.Pass(), + test_server.host_port_pair(), + kDefaultSSLConfig)); + + raw_transport->SetNextReadError(0); + + rv = callback.GetResult(sock->Connect(callback.callback())); + EXPECT_EQ(ERR_CONNECTION_CLOSED, rv); + EXPECT_FALSE(sock->IsConnected()); +} + +// Tests that SSLClientSocket cleanly returns a Read of size 0 if the +// underlying socket is cleanly closed. +// This is a regression test for https://crbug.com/422246 +TEST_F(SSLClientSocketTest, Read_WithZeroReturn) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + scoped_ptr<SynchronousErrorStreamSocket> transport( + new SynchronousErrorStreamSocket(real_transport.Pass())); + int rv = callback.GetResult(transport->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + + // Disable TLS False Start to ensure the handshake has completed. + SSLConfig ssl_config; + ssl_config.false_start_enabled = false; + + SynchronousErrorStreamSocket* raw_transport = transport.get(); + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.Pass(), + test_server.host_port_pair(), + ssl_config)); + + rv = callback.GetResult(sock->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + + raw_transport->SetNextReadError(0); + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); + rv = callback.GetResult(sock->Read(buf.get(), 4096, callback.callback())); + EXPECT_EQ(0, rv); +} + +// Tests that SSLClientSocket cleanly returns a Read of size 0 if the +// underlying socket is cleanly closed asynchronously. +// This is a regression test for https://crbug.com/422246 +TEST_F(SSLClientSocketTest, Read_WithAsyncZeroReturn) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + scoped_ptr<SynchronousErrorStreamSocket> error_socket( + new SynchronousErrorStreamSocket(real_transport.Pass())); + SynchronousErrorStreamSocket* raw_error_socket = error_socket.get(); + scoped_ptr<FakeBlockingStreamSocket> transport( + new FakeBlockingStreamSocket(error_socket.Pass())); + FakeBlockingStreamSocket* raw_transport = transport.get(); + int rv = callback.GetResult(transport->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + + // Disable TLS False Start to ensure the handshake has completed. + SSLConfig ssl_config; + ssl_config.false_start_enabled = false; + + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.Pass(), + test_server.host_port_pair(), + ssl_config)); + + rv = callback.GetResult(sock->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + + raw_error_socket->SetNextReadError(0); + raw_transport->BlockReadResult(); + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); + rv = sock->Read(buf.get(), 4096, callback.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + raw_transport->UnblockReadResult(); + rv = callback.GetResult(rv); + EXPECT_EQ(0, rv); } TEST_F(SSLClientSocketTest, Read_SmallChunks) { @@ -1782,10 +1992,8 @@ TEST_F(SSLClientSocketTest, Read_ManySmallRecords) { int rv = callback.GetResult(transport->Connect(callback.callback())); ASSERT_EQ(OK, rv); - scoped_ptr<SSLClientSocket> sock( - CreateSSLClientSocket(transport.PassAs<StreamSocket>(), - test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = callback.GetResult(sock->Connect(callback.callback())); ASSERT_EQ(OK, rv); @@ -2251,7 +2459,7 @@ TEST_F(SSLClientSocketTest, VerifyReturnChainProperlyOrdered) { // Load and install the root for the validated chain. scoped_refptr<X509Certificate> root_cert = ImportCertFromFile( GetTestCertsDirectory(), "redundant-validated-chain-root.pem"); - ASSERT_NE(static_cast<X509Certificate*>(NULL), root_cert); + ASSERT_NE(static_cast<X509Certificate*>(NULL), root_cert.get()); ScopedTestRoot scoped_root(root_cert.get()); // Set up a test server with CERT_CHAIN_WRONG_ROOT. @@ -2389,45 +2597,45 @@ TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledTLSExtension) { ASSERT_TRUE(test_server.GetAddressList(&addr)); TestCompletionCallback callback; - CapturingNetLog log; scoped_ptr<StreamSocket> transport( - new TCPClientSocket(addr, &log, NetLog::Source())); - int rv = transport->Connect(callback.callback()); - if (rv == ERR_IO_PENDING) - rv = callback.WaitForResult(); + new TCPClientSocket(addr, &log_, NetLog::Source())); + int rv = callback.GetResult(transport->Connect(callback.callback())); EXPECT_EQ(OK, rv); SSLConfig ssl_config; ssl_config.signed_cert_timestamps_enabled = true; - scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), ssl_config)); + MockCTVerifier ct_verifier; + SetCTVerifier(&ct_verifier); - EXPECT_FALSE(sock->IsConnected()); + // Check that the SCT list is extracted as expected. + EXPECT_CALL(ct_verifier, Verify(_, "", "test", _, _)).WillRepeatedly( + Return(ERR_CT_NO_SCTS_VERIFIED_OK)); - rv = sock->Connect(callback.callback()); - - CapturingNetLog::CapturedEntryList entries; - log.GetEntries(&entries); - EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); - if (rv == ERR_IO_PENDING) - rv = callback.WaitForResult(); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); + rv = callback.GetResult(sock->Connect(callback.callback())); EXPECT_EQ(OK, rv); - EXPECT_TRUE(sock->IsConnected()); - log.GetEntries(&entries); - EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); -#if !defined(USE_OPENSSL) EXPECT_TRUE(sock->signed_cert_timestamps_received_); -#else - // Enabling CT for OpenSSL is currently a noop. - EXPECT_FALSE(sock->signed_cert_timestamps_received_); -#endif +} - sock->Disconnect(); - EXPECT_FALSE(sock->IsConnected()); +namespace { + +bool IsValidOCSPResponse(const base::StringPiece& input) { + base::StringPiece ocsp_response = input; + base::StringPiece sequence, response_status, response_bytes; + return asn1::GetElement(&ocsp_response, asn1::kSEQUENCE, &sequence) && + ocsp_response.empty() && + asn1::GetElement(&sequence, asn1::kENUMERATED, &response_status) && + asn1::GetElement(&sequence, + asn1::kContextSpecific | asn1::kConstructed | 0, + &response_status) && + sequence.empty(); } +} // namespace + // Test that enabling Signed Certificate Timestamps enables OCSP stapling. TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledOCSP) { SpawnedTestServer::SSLOptions ssl_options; @@ -2445,12 +2653,9 @@ TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledOCSP) { ASSERT_TRUE(test_server.GetAddressList(&addr)); TestCompletionCallback callback; - CapturingNetLog log; scoped_ptr<StreamSocket> transport( - new TCPClientSocket(addr, &log, NetLog::Source())); - int rv = transport->Connect(callback.callback()); - if (rv == ERR_IO_PENDING) - rv = callback.WaitForResult(); + new TCPClientSocket(addr, &log_, NetLog::Source())); + int rv = callback.GetResult(transport->Connect(callback.callback())); EXPECT_EQ(OK, rv); SSLConfig ssl_config; @@ -2459,32 +2664,24 @@ TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledOCSP) { // is able to process the OCSP status itself. ssl_config.signed_cert_timestamps_enabled = true; - scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), ssl_config)); - - EXPECT_FALSE(sock->IsConnected()); + MockCTVerifier ct_verifier; + SetCTVerifier(&ct_verifier); - rv = sock->Connect(callback.callback()); + // Check that the OCSP response is extracted and well-formed. It should be the + // DER encoding of an OCSPResponse (RFC 2560), so check that it consists of a + // SEQUENCE of an ENUMERATED type and an element tagged with [0] EXPLICIT. In + // particular, it should not include the overall two-byte length prefix from + // TLS. + EXPECT_CALL(ct_verifier, + Verify(_, Truly(IsValidOCSPResponse), "", _, _)).WillRepeatedly( + Return(ERR_CT_NO_SCTS_VERIFIED_OK)); - CapturingNetLog::CapturedEntryList entries; - log.GetEntries(&entries); - EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); - if (rv == ERR_IO_PENDING) - rv = callback.WaitForResult(); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); + rv = callback.GetResult(sock->Connect(callback.callback())); EXPECT_EQ(OK, rv); - EXPECT_TRUE(sock->IsConnected()); - log.GetEntries(&entries); - EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); -#if !defined(USE_OPENSSL) EXPECT_TRUE(sock->stapled_ocsp_response_received_); -#else - // OCSP stapling isn't currently supported in the OpenSSL socket. - EXPECT_FALSE(sock->stapled_ocsp_response_received_); -#endif - - sock->Disconnect(); - EXPECT_FALSE(sock->IsConnected()); } TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsDisabled) { @@ -2500,12 +2697,9 @@ TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsDisabled) { ASSERT_TRUE(test_server.GetAddressList(&addr)); TestCompletionCallback callback; - CapturingNetLog log; scoped_ptr<StreamSocket> transport( - new TCPClientSocket(addr, &log, NetLog::Source())); - int rv = transport->Connect(callback.callback()); - if (rv == ERR_IO_PENDING) - rv = callback.WaitForResult(); + new TCPClientSocket(addr, &log_, NetLog::Source())); + int rv = callback.GetResult(transport->Connect(callback.callback())); EXPECT_EQ(OK, rv); SSLConfig ssl_config; @@ -2513,25 +2707,10 @@ TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsDisabled) { scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( transport.Pass(), test_server.host_port_pair(), ssl_config)); - - EXPECT_FALSE(sock->IsConnected()); - - rv = sock->Connect(callback.callback()); - - CapturingNetLog::CapturedEntryList entries; - log.GetEntries(&entries); - EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); - if (rv == ERR_IO_PENDING) - rv = callback.WaitForResult(); + rv = callback.GetResult(sock->Connect(callback.callback())); EXPECT_EQ(OK, rv); - EXPECT_TRUE(sock->IsConnected()); - log.GetEntries(&entries); - EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); EXPECT_FALSE(sock->signed_cert_timestamps_received_); - - sock->Disconnect(); - EXPECT_FALSE(sock->IsConnected()); } // Tests that IsConnectedAndIdle and WasEverUsed behave as expected. @@ -2590,6 +2769,148 @@ TEST_F(SSLClientSocketTest, ReuseStates) { // attempt to read one byte extra. } +#if defined(USE_OPENSSL) + +TEST_F(SSLClientSocketTest, HandshakeCallbackIsRun_WithFailure) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + scoped_ptr<SynchronousErrorStreamSocket> transport( + new SynchronousErrorStreamSocket(real_transport.Pass())); + int rv = callback.GetResult(transport->Connect(callback.callback())); + EXPECT_EQ(OK, rv); + + // Disable TLS False Start to avoid handshake non-determinism. + SSLConfig ssl_config; + ssl_config.false_start_enabled = false; + + SynchronousErrorStreamSocket* raw_transport = transport.get(); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); + + sock->SetHandshakeCompletionCallback(base::Bind( + &SSLClientSocketTest::RecordCompletedHandshake, base::Unretained(this))); + + raw_transport->SetNextWriteError(ERR_CONNECTION_RESET); + + rv = callback.GetResult(sock->Connect(callback.callback())); + EXPECT_EQ(ERR_CONNECTION_RESET, rv); + EXPECT_FALSE(sock->IsConnected()); + + EXPECT_TRUE(ran_handshake_completion_callback_); +} + +// Tests that the completion callback is run when an SSL connection +// completes successfully. +TEST_F(SSLClientSocketTest, HandshakeCallbackIsRun_WithSuccess) { + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + + TestCompletionCallback callback; + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + SSLConfig ssl_config = kDefaultSSLConfig; + ssl_config.false_start_enabled = false; + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); + + sock->SetHandshakeCompletionCallback(base::Bind( + &SSLClientSocketTest::RecordCompletedHandshake, base::Unretained(this))); + + rv = callback.GetResult(sock->Connect(callback.callback())); + + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + EXPECT_TRUE(ran_handshake_completion_callback_); +} + +// Tests that the completion callback is run with a server that doesn't cache +// sessions. +TEST_F(SSLClientSocketTest, HandshakeCallbackIsRun_WithDisabledSessionCache) { + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.disable_session_cache = true; + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); + ASSERT_TRUE(test_server.Start()); + + AddressList addr; + ASSERT_TRUE(test_server.GetAddressList(&addr)); + + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + + TestCompletionCallback callback; + int rv = transport->Connect(callback.callback()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + EXPECT_EQ(OK, rv); + + SSLConfig ssl_config = kDefaultSSLConfig; + ssl_config.false_start_enabled = false; + + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); + + sock->SetHandshakeCompletionCallback(base::Bind( + &SSLClientSocketTest::RecordCompletedHandshake, base::Unretained(this))); + + rv = callback.GetResult(sock->Connect(callback.callback())); + + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock->IsConnected()); + EXPECT_TRUE(ran_handshake_completion_callback_); +} + +TEST_F(SSLClientSocketFalseStartTest, + HandshakeCallbackIsRun_WithFalseStartFailure) { + // False Start requires NPN and a forward-secret cipher suite. + SpawnedTestServer::SSLOptions server_options; + server_options.key_exchanges = + SpawnedTestServer::SSLOptions::KEY_EXCHANGE_DHE_RSA; + server_options.enable_npn = true; + SSLConfig client_config; + client_config.next_protos.push_back("http/1.1"); + monitor_handshake_callback_ = true; + fail_handshake_after_false_start_ = true; + ASSERT_NO_FATAL_FAILURE(TestFalseStart(server_options, client_config, true)); + ASSERT_TRUE(ran_handshake_completion_callback_); +} + +TEST_F(SSLClientSocketFalseStartTest, + HandshakeCallbackIsRun_WithFalseStartSuccess) { + // False Start requires NPN and a forward-secret cipher suite. + SpawnedTestServer::SSLOptions server_options; + server_options.key_exchanges = + SpawnedTestServer::SSLOptions::KEY_EXCHANGE_DHE_RSA; + server_options.enable_npn = true; + SSLConfig client_config; + client_config.next_protos.push_back("http/1.1"); + monitor_handshake_callback_ = true; + ASSERT_NO_FATAL_FAILURE(TestFalseStart(server_options, client_config, true)); + ASSERT_TRUE(ran_handshake_completion_callback_); +} +#endif // defined(USE_OPENSSL) + TEST_F(SSLClientSocketFalseStartTest, FalseStartEnabled) { // False Start requires NPN and a forward-secret cipher suite. SpawnedTestServer::SSLOptions server_options; @@ -2718,8 +3039,8 @@ TEST_F(SSLClientSocketChannelIDTest, SendChannelID) { EXPECT_FALSE(sock_->IsConnected()); } -// Connect to a server using channel id but without sending a key. It should -// fail. +// Connect to a server using Channel ID but failing to look up the Channel +// ID. It should fail. TEST_F(SSLClientSocketChannelIDTest, FailingChannelID) { SpawnedTestServer::SSLOptions ssl_options; @@ -2740,4 +3061,22 @@ TEST_F(SSLClientSocketChannelIDTest, FailingChannelID) { EXPECT_FALSE(sock_->IsConnected()); } +// Connect to a server using Channel ID but asynchronously failing to look up +// the Channel ID. It should fail. +TEST_F(SSLClientSocketChannelIDTest, FailingChannelIDAsync) { + SpawnedTestServer::SSLOptions ssl_options; + + ASSERT_TRUE(ConnectToTestServer(ssl_options)); + + EnableAsyncFailingChannelID(); + SSLConfig ssl_config = kDefaultSSLConfig; + ssl_config.channel_id_enabled = true; + + int rv; + ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); + + EXPECT_EQ(ERR_UNEXPECTED, rv); + EXPECT_FALSE(sock_->IsConnected()); +} + } // namespace net diff --git a/chromium/net/socket/ssl_error_params.cc b/chromium/net/socket/ssl_error_params.cc deleted file mode 100644 index 37561f0de48..00000000000 --- a/chromium/net/socket/ssl_error_params.cc +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/socket/ssl_error_params.h" - -#include "base/bind.h" -#include "base/values.h" - -namespace net { - -namespace { - -base::Value* NetLogSSLErrorCallback(int net_error, - int ssl_lib_error, - NetLog::LogLevel /* log_level */) { - base::DictionaryValue* dict = new base::DictionaryValue(); - dict->SetInteger("net_error", net_error); - if (ssl_lib_error) - dict->SetInteger("ssl_lib_error", ssl_lib_error); - return dict; -} - -} // namespace - -NetLog::ParametersCallback CreateNetLogSSLErrorCallback(int net_error, - int ssl_lib_error) { - return base::Bind(&NetLogSSLErrorCallback, net_error, ssl_lib_error); -} - -} // namespace net diff --git a/chromium/net/socket/ssl_error_params.h b/chromium/net/socket/ssl_error_params.h deleted file mode 100644 index 07a1c4d99d9..00000000000 --- a/chromium/net/socket/ssl_error_params.h +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_SOCKET_SSL_ERROR_PARAMS_H_ -#define NET_SOCKET_SSL_ERROR_PARAMS_H_ - -#include "net/base/net_log.h" - -namespace net { - -// Creates NetLog callback for when we receive an SSL error. -NetLog::ParametersCallback CreateNetLogSSLErrorCallback(int net_error, - int ssl_lib_error); - -} // namespace net - -#endif // NET_SOCKET_SSL_ERROR_PARAMS_H_ diff --git a/chromium/net/socket/ssl_server_socket.h b/chromium/net/socket/ssl_server_socket.h index 8b607bf80cf..88f7f941439 100644 --- a/chromium/net/socket/ssl_server_socket.h +++ b/chromium/net/socket/ssl_server_socket.h @@ -23,7 +23,7 @@ class X509Certificate; class SSLServerSocket : public SSLSocket { public: - virtual ~SSLServerSocket() {} + ~SSLServerSocket() override {} // Perform the SSL server handshake, and notify the supplied callback // if the process completes asynchronously. If Disconnect is called before diff --git a/chromium/net/socket/ssl_server_socket_nss.cc b/chromium/net/socket/ssl_server_socket_nss.cc index cfdc5c545ef..7fa5835b430 100644 --- a/chromium/net/socket/ssl_server_socket_nss.cc +++ b/chromium/net/socket/ssl_server_socket_nss.cc @@ -38,7 +38,6 @@ #include "net/base/net_errors.h" #include "net/base/net_log.h" #include "net/socket/nss_ssl_util.h" -#include "net/socket/ssl_error_params.h" // SSL plaintext fragments are shorter than 16KB. Although the record layer // overhead is allowed to be 2K + 5 bytes, in practice the overhead is much @@ -60,7 +59,7 @@ class NSSSSLServerInitSingleton { NSSSSLServerInitSingleton() { EnsureNSSSSLInit(); - SSL_ConfigServerSessionIDCache(1024, 5, 5, NULL); + SSL_ConfigServerSessionIDCache(64, 28800, 28800, NULL); g_nss_server_sockets_init = true; } @@ -85,7 +84,7 @@ scoped_ptr<SSLServerSocket> CreateSSLServerSocket( crypto::RSAPrivateKey* key, const SSLConfig& ssl_config) { DCHECK(g_nss_server_sockets_init) << "EnableSSLServerSockets() has not been" - << "called yet!"; + << " called yet!"; return scoped_ptr<SSLServerSocket>( new SSLServerSocketNSS(socket.Pass(), cert, key, ssl_config)); @@ -446,7 +445,7 @@ int SSLServerSocketNSS::InitializeSSLOptions() { } SECKEYPrivateKeyStr* private_key = NULL; - PK11SlotInfo* slot = crypto::GetPrivateNSSKeySlot(); + PK11SlotInfo* slot = PK11_GetInternalSlot(); if (!slot) { CERT_DestroyCertificate(cert); return ERR_UNEXPECTED; diff --git a/chromium/net/socket/ssl_server_socket_nss.h b/chromium/net/socket/ssl_server_socket_nss.h index bc5b65d5368..d40b096577c 100644 --- a/chromium/net/socket/ssl_server_socket_nss.h +++ b/chromium/net/socket/ssl_server_socket_nss.h @@ -28,42 +28,44 @@ class SSLServerSocketNSS : public SSLServerSocket { scoped_refptr<X509Certificate> certificate, crypto::RSAPrivateKey* key, const SSLConfig& ssl_config); - virtual ~SSLServerSocketNSS(); + ~SSLServerSocketNSS() override; // SSLServerSocket interface. - virtual int Handshake(const CompletionCallback& callback) OVERRIDE; + int Handshake(const CompletionCallback& callback) override; // SSLSocket interface. - virtual int ExportKeyingMaterial(const base::StringPiece& label, - bool has_context, - const base::StringPiece& context, - unsigned char* out, - unsigned int outlen) OVERRIDE; - virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE; + int ExportKeyingMaterial(const base::StringPiece& label, + bool has_context, + const base::StringPiece& context, + unsigned char* out, + unsigned int outlen) override; + int GetTLSUniqueChannelBinding(std::string* out) override; // Socket interface (via StreamSocket). - virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int SetReceiveBufferSize(int32 size) OVERRIDE; - virtual int SetSendBufferSize(int32 size) OVERRIDE; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int SetReceiveBufferSize(int32 size) override; + int SetSendBufferSize(int32 size) override; // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE; - virtual void Disconnect() OVERRIDE; - virtual bool IsConnected() const OVERRIDE; - virtual bool IsConnectedAndIdle() const OVERRIDE; - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; - virtual const BoundNetLog& NetLog() const OVERRIDE; - virtual void SetSubresourceSpeculation() OVERRIDE; - virtual void SetOmniboxSpeculation() OVERRIDE; - virtual bool WasEverUsed() const OVERRIDE; - virtual bool UsingTCPFastOpen() const OVERRIDE; - virtual bool WasNpnNegotiated() const OVERRIDE; - virtual NextProto GetNegotiatedProtocol() const OVERRIDE; - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + int Connect(const CompletionCallback& callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + const BoundNetLog& NetLog() const override; + void SetSubresourceSpeculation() override; + void SetOmniboxSpeculation() override; + bool WasEverUsed() const override; + bool UsingTCPFastOpen() const override; + bool WasNpnNegotiated() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; private: enum State { diff --git a/chromium/net/socket/ssl_server_socket_openssl.cc b/chromium/net/socket/ssl_server_socket_openssl.cc index f6bd0cd8f81..29d1ffa0508 100644 --- a/chromium/net/socket/ssl_server_socket_openssl.cc +++ b/chromium/net/socket/ssl_server_socket_openssl.cc @@ -11,9 +11,9 @@ #include "base/logging.h" #include "crypto/openssl_util.h" #include "crypto/rsa_private_key.h" +#include "crypto/scoped_openssl_types.h" #include "net/base/net_errors.h" -#include "net/socket/openssl_ssl_util.h" -#include "net/socket/ssl_error_params.h" +#include "net/ssl/openssl_ssl_util.h" #define GotoState(s) next_handshake_state_ = s @@ -304,7 +304,7 @@ int SSLServerSocketOpenSSL::BufferSend() { if (!send_buffer_.get()) { // Get a fresh send buffer out of the send BIO. - size_t max_read = BIO_ctrl_pending(transport_bio_); + size_t max_read = BIO_pending(transport_bio_); if (!max_read) return 0; // Nothing pending in the OpenSSL write BIO. send_buffer_ = new DrainableIOBuffer(new IOBuffer(max_read), max_read); @@ -456,10 +456,13 @@ int SSLServerSocketOpenSSL::DoPayloadRead() { if (rv >= 0) return rv; int ssl_error = SSL_get_error(ssl_, rv); - int net_error = MapOpenSSLError(ssl_error, err_tracer); + OpenSSLErrorInfo error_info; + int net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer, + &error_info); if (net_error != ERR_IO_PENDING) { - net_log_.AddEvent(NetLog::TYPE_SSL_READ_ERROR, - CreateNetLogSSLErrorCallback(net_error, ssl_error)); + net_log_.AddEvent( + NetLog::TYPE_SSL_READ_ERROR, + CreateNetLogOpenSSLErrorCallback(net_error, ssl_error, error_info)); } return net_error; } @@ -471,10 +474,13 @@ int SSLServerSocketOpenSSL::DoPayloadWrite() { if (rv >= 0) return rv; int ssl_error = SSL_get_error(ssl_, rv); - int net_error = MapOpenSSLError(ssl_error, err_tracer); + OpenSSLErrorInfo error_info; + int net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer, + &error_info); if (net_error != ERR_IO_PENDING) { - net_log_.AddEvent(NetLog::TYPE_SSL_WRITE_ERROR, - CreateNetLogSSLErrorCallback(net_error, ssl_error)); + net_log_.AddEvent( + NetLog::TYPE_SSL_WRITE_ERROR, + CreateNetLogOpenSSLErrorCallback(net_error, ssl_error, error_info)); } return net_error; } @@ -553,7 +559,8 @@ int SSLServerSocketOpenSSL::DoHandshake() { completed_handshake_ = true; } else { int ssl_error = SSL_get_error(ssl_, rv); - net_error = MapOpenSSLError(ssl_error, err_tracer); + OpenSSLErrorInfo error_info; + net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer, &error_info); // If not done, stay in this state if (net_error == ERR_IO_PENDING) { @@ -562,8 +569,9 @@ int SSLServerSocketOpenSSL::DoHandshake() { LOG(ERROR) << "handshake failed; returned " << rv << ", SSL error code " << ssl_error << ", net_error " << net_error; - net_log_.AddEvent(NetLog::TYPE_SSL_HANDSHAKE_ERROR, - CreateNetLogSSLErrorCallback(net_error, ssl_error)); + net_log_.AddEvent( + NetLog::TYPE_SSL_HANDSHAKE_ERROR, + CreateNetLogOpenSSLErrorCallback(net_error, ssl_error, error_info)); } } return net_error; @@ -598,7 +606,7 @@ int SSLServerSocketOpenSSL::Init() { crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); - crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free> ssl_ctx( + crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free>::Type ssl_ctx( // It support SSLv2, SSLv3, and TLSv1. SSL_CTX_new(SSLv23_server_method())); ssl_ = SSL_new(ssl_ctx.get()); @@ -630,8 +638,8 @@ int SSLServerSocketOpenSSL::Init() { const unsigned char* der_string_array = reinterpret_cast<const unsigned char*>(der_string.data()); - crypto::ScopedOpenSSL<X509, X509_free> - x509(d2i_X509(NULL, &der_string_array, der_string.length())); + crypto::ScopedOpenSSL<X509, X509_free>::Type x509( + d2i_X509(NULL, &der_string_array, der_string.length())); if (!x509.get()) return ERR_UNEXPECTED; diff --git a/chromium/net/socket/ssl_server_socket_openssl.h b/chromium/net/socket/ssl_server_socket_openssl.h index e1c8aad7aea..c58bd569352 100644 --- a/chromium/net/socket/ssl_server_socket_openssl.h +++ b/chromium/net/socket/ssl_server_socket_openssl.h @@ -30,42 +30,44 @@ class SSLServerSocketOpenSSL : public SSLServerSocket { scoped_refptr<X509Certificate> certificate, crypto::RSAPrivateKey* key, const SSLConfig& ssl_config); - virtual ~SSLServerSocketOpenSSL(); + ~SSLServerSocketOpenSSL() override; // SSLServerSocket interface. - virtual int Handshake(const CompletionCallback& callback) OVERRIDE; + int Handshake(const CompletionCallback& callback) override; // SSLSocket interface. - virtual int ExportKeyingMaterial(const base::StringPiece& label, - bool has_context, - const base::StringPiece& context, - unsigned char* out, - unsigned int outlen) OVERRIDE; - virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE; + int ExportKeyingMaterial(const base::StringPiece& label, + bool has_context, + const base::StringPiece& context, + unsigned char* out, + unsigned int outlen) override; + int GetTLSUniqueChannelBinding(std::string* out) override; // Socket interface (via StreamSocket). - virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int SetReceiveBufferSize(int32 size) OVERRIDE; - virtual int SetSendBufferSize(int32 size) OVERRIDE; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int SetReceiveBufferSize(int32 size) override; + int SetSendBufferSize(int32 size) override; // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE; - virtual void Disconnect() OVERRIDE; - virtual bool IsConnected() const OVERRIDE; - virtual bool IsConnectedAndIdle() const OVERRIDE; - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; - virtual const BoundNetLog& NetLog() const OVERRIDE; - virtual void SetSubresourceSpeculation() OVERRIDE; - virtual void SetOmniboxSpeculation() OVERRIDE; - virtual bool WasEverUsed() const OVERRIDE; - virtual bool UsingTCPFastOpen() const OVERRIDE; - virtual bool WasNpnNegotiated() const OVERRIDE; - virtual NextProto GetNegotiatedProtocol() const OVERRIDE; - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + int Connect(const CompletionCallback& callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + const BoundNetLog& NetLog() const override; + void SetSubresourceSpeculation() override; + void SetOmniboxSpeculation() override; + bool WasEverUsed() const override; + bool UsingTCPFastOpen() const override; + bool WasNpnNegotiated() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; private: enum State { diff --git a/chromium/net/socket/ssl_server_socket_unittest.cc b/chromium/net/socket/ssl_server_socket_unittest.cc index 9da10f6bbd0..5fabdd2e90e 100644 --- a/chromium/net/socket/ssl_server_socket_unittest.cc +++ b/chromium/net/socket/ssl_server_socket_unittest.cc @@ -20,10 +20,9 @@ #include <queue> #include "base/compiler_specific.h" -#include "base/file_util.h" #include "base/files/file_path.h" +#include "base/files/file_util.h" #include "base/message_loop/message_loop.h" -#include "base/path_service.h" #include "crypto/nss_util.h" #include "crypto/rsa_private_key.h" #include "net/base/address_list.h" @@ -62,29 +61,35 @@ class FakeDataChannel { } int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { + DCHECK(read_callback_.is_null()); + DCHECK(!read_buf_.get()); if (closed_) return 0; if (data_.empty()) { read_callback_ = callback; read_buf_ = buf; read_buf_len_ = buf_len; - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } return PropogateData(buf, buf_len); } int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { + DCHECK(write_callback_.is_null()); if (closed_) { if (write_called_after_close_) - return net::ERR_CONNECTION_RESET; + return ERR_CONNECTION_RESET; write_called_after_close_ = true; write_callback_ = callback; base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind(&FakeDataChannel::DoWriteCallback, weak_factory_.GetWeakPtr())); - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } - data_.push(new net::DrainableIOBuffer(buf, buf_len)); + // This function returns synchronously, so make a copy of the buffer. + data_.push(new DrainableIOBuffer( + new StringIOBuffer(std::string(buf->data(), buf_len)), + buf_len)); base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback, weak_factory_.GetWeakPtr())); @@ -118,11 +123,11 @@ class FakeDataChannel { CompletionCallback callback = write_callback_; write_callback_.Reset(); - callback.Run(net::ERR_CONNECTION_RESET); + callback.Run(ERR_CONNECTION_RESET); } - int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) { - scoped_refptr<net::DrainableIOBuffer> buf = data_.front(); + int PropogateData(scoped_refptr<IOBuffer> read_buf, int read_buf_len) { + scoped_refptr<DrainableIOBuffer> buf = data_.front(); int copied = std::min(buf->BytesRemaining(), read_buf_len); memcpy(read_buf->data(), buf->data(), copied); buf->DidConsume(copied); @@ -133,12 +138,12 @@ class FakeDataChannel { } CompletionCallback read_callback_; - scoped_refptr<net::IOBuffer> read_buf_; + scoped_refptr<IOBuffer> read_buf_; int read_buf_len_; CompletionCallback write_callback_; - std::queue<scoped_refptr<net::DrainableIOBuffer> > data_; + std::queue<scoped_refptr<DrainableIOBuffer> > data_; // True if Close() has been called. bool closed_; @@ -161,90 +166,68 @@ class FakeSocket : public StreamSocket { outgoing_(outgoing_channel) { } - virtual ~FakeSocket() { - } + ~FakeSocket() override {} - virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override { // Read random number of bytes. buf_len = rand() % buf_len + 1; return incoming_->Read(buf, buf_len, callback); } - virtual int Write(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override { // Write random number of bytes. buf_len = rand() % buf_len + 1; return outgoing_->Write(buf, buf_len, callback); } - virtual int SetReceiveBufferSize(int32 size) OVERRIDE { - return net::OK; - } + int SetReceiveBufferSize(int32 size) override { return OK; } - virtual int SetSendBufferSize(int32 size) OVERRIDE { - return net::OK; - } + int SetSendBufferSize(int32 size) override { return OK; } - virtual int Connect(const CompletionCallback& callback) OVERRIDE { - return net::OK; - } + int Connect(const CompletionCallback& callback) override { return OK; } - virtual void Disconnect() OVERRIDE { + void Disconnect() override { incoming_->Close(); outgoing_->Close(); } - virtual bool IsConnected() const OVERRIDE { - return true; - } + bool IsConnected() const override { return true; } - virtual bool IsConnectedAndIdle() const OVERRIDE { - return true; - } + bool IsConnectedAndIdle() const override { return true; } - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { - net::IPAddressNumber ip_address(net::kIPv4AddressSize); - *address = net::IPEndPoint(ip_address, 0 /*port*/); - return net::OK; + int GetPeerAddress(IPEndPoint* address) const override { + IPAddressNumber ip_address(kIPv4AddressSize); + *address = IPEndPoint(ip_address, 0 /*port*/); + return OK; } - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { - net::IPAddressNumber ip_address(4); - *address = net::IPEndPoint(ip_address, 0); - return net::OK; + int GetLocalAddress(IPEndPoint* address) const override { + IPAddressNumber ip_address(4); + *address = IPEndPoint(ip_address, 0); + return OK; } - virtual const BoundNetLog& NetLog() const OVERRIDE { - return net_log_; - } + const BoundNetLog& NetLog() const override { return net_log_; } - virtual void SetSubresourceSpeculation() OVERRIDE {} - virtual void SetOmniboxSpeculation() OVERRIDE {} + void SetSubresourceSpeculation() override {} + void SetOmniboxSpeculation() override {} - virtual bool WasEverUsed() const OVERRIDE { - return true; - } - - virtual bool UsingTCPFastOpen() const OVERRIDE { - return false; - } + bool WasEverUsed() const override { return true; } + bool UsingTCPFastOpen() const override { return false; } - virtual bool WasNpnNegotiated() const OVERRIDE { - return false; - } + bool WasNpnNegotiated() const override { return false; } - virtual NextProto GetNegotiatedProtocol() const OVERRIDE { - return kProtoUnknown; - } + NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { - return false; - } + bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } private: - net::BoundNetLog net_log_; + BoundNetLog net_log_; FakeDataChannel* incoming_; FakeDataChannel* outgoing_; @@ -264,8 +247,8 @@ TEST(FakeSocketTest, DataTransfer) { const char kTestData[] = "testing123"; const int kTestDataSize = strlen(kTestData); const int kReadBufSize = 1024; - scoped_refptr<net::IOBuffer> write_buf = new net::StringIOBuffer(kTestData); - scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize); + scoped_refptr<IOBuffer> write_buf = new StringIOBuffer(kTestData); + scoped_refptr<IOBuffer> read_buf = new IOBuffer(kReadBufSize); // Write then read. int written = @@ -280,7 +263,7 @@ TEST(FakeSocketTest, DataTransfer) { // Read then write. TestCompletionCallback callback; - EXPECT_EQ(net::ERR_IO_PENDING, + EXPECT_EQ(ERR_IO_PENDING, server.Read(read_buf.get(), kReadBufSize, callback.callback())); written = client.Write(write_buf.get(), kTestDataSize, CompletionCallback()); @@ -296,10 +279,10 @@ TEST(FakeSocketTest, DataTransfer) { class SSLServerSocketTest : public PlatformTest { public: SSLServerSocketTest() - : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()), + : socket_factory_(ClientSocketFactory::GetDefaultFactory()), cert_verifier_(new MockCertVerifier()), transport_security_state_(new TransportSecurityState) { - cert_verifier_->set_default_result(net::CERT_STATUS_AUTHORITY_INVALID); + cert_verifier_->set_default_result(CERT_STATUS_AUTHORITY_INVALID); } protected: @@ -316,7 +299,7 @@ class SSLServerSocketTest : public PlatformTest { std::string cert_der; ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der)); - scoped_refptr<net::X509Certificate> cert = + scoped_refptr<X509Certificate> cert = X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size()); base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); @@ -330,35 +313,35 @@ class SSLServerSocketTest : public PlatformTest { scoped_ptr<crypto::RSAPrivateKey> private_key( crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); - net::SSLConfig ssl_config; + SSLConfig ssl_config; ssl_config.false_start_enabled = false; ssl_config.channel_id_enabled = false; // Certificate provided by the host doesn't need authority. - net::SSLConfig::CertAndStatus cert_and_status; + SSLConfig::CertAndStatus cert_and_status; cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID; cert_and_status.der_cert = cert_der; ssl_config.allowed_bad_certs.push_back(cert_and_status); - net::HostPortPair host_and_pair("unittest", 0); - net::SSLClientSocketContext context; + HostPortPair host_and_pair("unittest", 0); + SSLClientSocketContext context; context.cert_verifier = cert_verifier_.get(); context.transport_security_state = transport_security_state_.get(); client_socket_ = socket_factory_->CreateSSLClientSocket( client_connection.Pass(), host_and_pair, ssl_config, context); - server_socket_ = net::CreateSSLServerSocket( + server_socket_ = CreateSSLServerSocket( server_socket.Pass(), - cert.get(), private_key.get(), net::SSLConfig()); + cert.get(), private_key.get(), SSLConfig()); } FakeDataChannel channel_1_; FakeDataChannel channel_2_; - scoped_ptr<net::SSLClientSocket> client_socket_; - scoped_ptr<net::SSLServerSocket> server_socket_; - net::ClientSocketFactory* socket_factory_; - scoped_ptr<net::MockCertVerifier> cert_verifier_; - scoped_ptr<net::TransportSecurityState> transport_security_state_; + scoped_ptr<SSLClientSocket> client_socket_; + scoped_ptr<SSLServerSocket> server_socket_; + ClientSocketFactory* socket_factory_; + scoped_ptr<MockCertVerifier> cert_verifier_; + scoped_ptr<TransportSecurityState> transport_security_state_; }; // This test only executes creation of client and server sockets. This is to @@ -378,16 +361,16 @@ TEST_F(SSLServerSocketTest, Handshake) { TestCompletionCallback handshake_callback; int server_ret = server_socket_->Handshake(handshake_callback.callback()); - EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); + EXPECT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); int client_ret = client_socket_->Connect(connect_callback.callback()); - EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); + EXPECT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); - if (client_ret == net::ERR_IO_PENDING) { - EXPECT_EQ(net::OK, connect_callback.WaitForResult()); + if (client_ret == ERR_IO_PENDING) { + EXPECT_EQ(OK, connect_callback.WaitForResult()); } - if (server_ret == net::ERR_IO_PENDING) { - EXPECT_EQ(net::OK, handshake_callback.WaitForResult()); + if (server_ret == ERR_IO_PENDING) { + EXPECT_EQ(OK, handshake_callback.WaitForResult()); } // Make sure the cert status is expected. @@ -404,32 +387,31 @@ TEST_F(SSLServerSocketTest, DataTransfer) { // Establish connection. int client_ret = client_socket_->Connect(connect_callback.callback()); - ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); + ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); int server_ret = server_socket_->Handshake(handshake_callback.callback()); - ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); + ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); client_ret = connect_callback.GetResult(client_ret); - ASSERT_EQ(net::OK, client_ret); + ASSERT_EQ(OK, client_ret); server_ret = handshake_callback.GetResult(server_ret); - ASSERT_EQ(net::OK, server_ret); + ASSERT_EQ(OK, server_ret); const int kReadBufSize = 1024; - scoped_refptr<net::StringIOBuffer> write_buf = - new net::StringIOBuffer("testing123"); - scoped_refptr<net::DrainableIOBuffer> read_buf = - new net::DrainableIOBuffer(new net::IOBuffer(kReadBufSize), - kReadBufSize); + scoped_refptr<StringIOBuffer> write_buf = + new StringIOBuffer("testing123"); + scoped_refptr<DrainableIOBuffer> read_buf = + new DrainableIOBuffer(new IOBuffer(kReadBufSize), kReadBufSize); // Write then read. TestCompletionCallback write_callback; TestCompletionCallback read_callback; server_ret = server_socket_->Write( write_buf.get(), write_buf->size(), write_callback.callback()); - EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); + EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING); client_ret = client_socket_->Read( read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); - EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); + EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING); server_ret = write_callback.GetResult(server_ret); EXPECT_GT(server_ret, 0); @@ -440,7 +422,7 @@ TEST_F(SSLServerSocketTest, DataTransfer) { while (read_buf->BytesConsumed() < write_buf->size()) { client_ret = client_socket_->Read( read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); - EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); + EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING); client_ret = read_callback.GetResult(client_ret); ASSERT_GT(client_ret, 0); read_buf->DidConsume(client_ret); @@ -450,13 +432,13 @@ TEST_F(SSLServerSocketTest, DataTransfer) { EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); // Read then write. - write_buf = new net::StringIOBuffer("hello123"); + write_buf = new StringIOBuffer("hello123"); server_ret = server_socket_->Read( read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); - EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); + EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING); client_ret = client_socket_->Write( write_buf.get(), write_buf->size(), write_callback.callback()); - EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); + EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING); server_ret = read_callback.GetResult(server_ret); ASSERT_GT(server_ret, 0); @@ -467,7 +449,7 @@ TEST_F(SSLServerSocketTest, DataTransfer) { while (read_buf->BytesConsumed() < write_buf->size()) { server_ret = server_socket_->Read( read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); - EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); + EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING); server_ret = read_callback.GetResult(server_ret); ASSERT_GT(server_ret, 0); read_buf->DidConsume(server_ret); @@ -477,17 +459,11 @@ TEST_F(SSLServerSocketTest, DataTransfer) { EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); } -// Flaky on Android: http://crbug.com/381147 -#if defined(OS_ANDROID) -#define MAYBE_ClientWriteAfterServerClose DISABLED_ClientWriteAfterServerClose -#else -#define MAYBE_ClientWriteAfterServerClose ClientWriteAfterServerClose -#endif // A regression test for bug 127822 (http://crbug.com/127822). // If the server closes the connection after the handshake is finished, // the client's Write() call should not cause an infinite loop. // NOTE: this is a test for SSLClientSocket rather than SSLServerSocket. -TEST_F(SSLServerSocketTest, MAYBE_ClientWriteAfterServerClose) { +TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) { Initialize(); TestCompletionCallback connect_callback; @@ -495,18 +471,17 @@ TEST_F(SSLServerSocketTest, MAYBE_ClientWriteAfterServerClose) { // Establish connection. int client_ret = client_socket_->Connect(connect_callback.callback()); - ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); + ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); int server_ret = server_socket_->Handshake(handshake_callback.callback()); - ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); + ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); client_ret = connect_callback.GetResult(client_ret); - ASSERT_EQ(net::OK, client_ret); + ASSERT_EQ(OK, client_ret); server_ret = handshake_callback.GetResult(server_ret); - ASSERT_EQ(net::OK, server_ret); + ASSERT_EQ(OK, server_ret); - scoped_refptr<net::StringIOBuffer> write_buf = - new net::StringIOBuffer("testing123"); + scoped_refptr<StringIOBuffer> write_buf = new StringIOBuffer("testing123"); // The server closes the connection. The server needs to write some // data first so that the client's Read() calls from the transport @@ -516,7 +491,7 @@ TEST_F(SSLServerSocketTest, MAYBE_ClientWriteAfterServerClose) { server_ret = server_socket_->Write( write_buf.get(), write_buf->size(), write_callback.callback()); - EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); + EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING); server_ret = write_callback.GetResult(server_ret); EXPECT_GT(server_ret, 0); @@ -526,7 +501,7 @@ TEST_F(SSLServerSocketTest, MAYBE_ClientWriteAfterServerClose) { // The client writes some data. This should not cause an infinite loop. client_ret = client_socket_->Write( write_buf.get(), write_buf->size(), write_callback.callback()); - EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); + EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING); client_ret = write_callback.GetResult(client_ret); EXPECT_GT(client_ret, 0); @@ -547,16 +522,16 @@ TEST_F(SSLServerSocketTest, ExportKeyingMaterial) { TestCompletionCallback handshake_callback; int client_ret = client_socket_->Connect(connect_callback.callback()); - ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); + ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); int server_ret = server_socket_->Handshake(handshake_callback.callback()); - ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); + ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); - if (client_ret == net::ERR_IO_PENDING) { - ASSERT_EQ(net::OK, connect_callback.WaitForResult()); + if (client_ret == ERR_IO_PENDING) { + ASSERT_EQ(OK, connect_callback.WaitForResult()); } - if (server_ret == net::ERR_IO_PENDING) { - ASSERT_EQ(net::OK, handshake_callback.WaitForResult()); + if (server_ret == ERR_IO_PENDING) { + ASSERT_EQ(OK, handshake_callback.WaitForResult()); } const int kKeyingMaterialSize = 32; @@ -566,13 +541,13 @@ TEST_F(SSLServerSocketTest, ExportKeyingMaterial) { int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel, false, kKeyingContext, server_out, sizeof(server_out)); - ASSERT_EQ(net::OK, rv); + ASSERT_EQ(OK, rv); unsigned char client_out[kKeyingMaterialSize]; rv = client_socket_->ExportKeyingMaterial(kKeyingLabel, false, kKeyingContext, client_out, sizeof(client_out)); - ASSERT_EQ(net::OK, rv); + ASSERT_EQ(OK, rv); EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out))); const char* kKeyingLabelBad = "EXPERIMENTAL-server-socket-test-bad"; @@ -580,7 +555,7 @@ TEST_F(SSLServerSocketTest, ExportKeyingMaterial) { rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad, false, kKeyingContext, client_bad, sizeof(client_bad)); - ASSERT_EQ(rv, net::OK); + ASSERT_EQ(rv, OK); EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out))); } diff --git a/chromium/net/socket/ssl_session_cache_openssl.cc b/chromium/net/socket/ssl_session_cache_openssl.cc index d16bb8d6325..92ae44b9ac5 100644 --- a/chromium/net/socket/ssl_session_cache_openssl.cc +++ b/chromium/net/socket/ssl_session_cache_openssl.cc @@ -87,8 +87,9 @@ struct SessionId { // this one is just simple enough to do the job. size_t ComputeHash(const unsigned char* id, unsigned id_len) { size_t result = 0; - for (unsigned n = 0; n < id_len; ++n) - result += 131 * id[n]; + for (unsigned n = 0; n < id_len; ++n) { + result = (result * 131) + id[n]; + } return result; } }; @@ -236,10 +237,25 @@ class SSLSessionCacheOpenSSLImpl { return SSL_set_session(ssl, session) == 1; } + // Return true iff a cached session was associated with the given |cache_key|. + bool SSLSessionIsInCache(const std::string& cache_key) const { + base::AutoLock locked(lock_); + KeyIndex::const_iterator it = key_index_.find(cache_key); + if (it == key_index_.end()) + return false; + + SSL_SESSION* session = *it->second; + DCHECK(session); + + void* session_is_good = + SSL_SESSION_get_ex_data(session, GetSSLSessionExIndex()); + + return session_is_good != NULL; + } + void MarkSSLSessionAsGood(SSL* ssl) { SSL_SESSION* session = SSL_get_session(ssl); - if (!session) - return; + CHECK(session); // Mark the session as good, allowing it to be used for future connections. SSL_SESSION_set_ex_data( @@ -342,7 +358,8 @@ class SSLSessionCacheOpenSSLImpl { // to indicate that it took ownership of the session, i.e. that the caller // should not decrement its reference count after completion. static int NewSessionCallbackStatic(SSL* ssl, SSL_SESSION* session) { - GetCache(ssl->ctx)->OnSessionAdded(ssl, session); + SSLSessionCacheOpenSSLImpl* cache = GetCache(ssl->ctx); + cache->OnSessionAdded(ssl, session); return 1; } @@ -469,7 +486,7 @@ class SSLSessionCacheOpenSSLImpl { // method to get the index which can later be used with SSL_CTX_get_ex_data() // or SSL_CTX_set_ex_data(). - base::Lock lock_; // Protects access to containers below. + mutable base::Lock lock_; // Protects access to containers below. MRUSessionList ordering_; KeyIndex key_index_; @@ -499,6 +516,11 @@ bool SSLSessionCacheOpenSSL::SetSSLSessionWithKey( return impl_->SetSSLSessionWithKey(ssl, cache_key); } +bool SSLSessionCacheOpenSSL::SSLSessionIsInCache( + const std::string& cache_key) const { + return impl_->SSLSessionIsInCache(cache_key); +} + void SSLSessionCacheOpenSSL::MarkSSLSessionAsGood(SSL* ssl) { return impl_->MarkSSLSessionAsGood(ssl); } diff --git a/chromium/net/socket/ssl_session_cache_openssl.h b/chromium/net/socket/ssl_session_cache_openssl.h index bbd9659641d..abf5eab78cb 100644 --- a/chromium/net/socket/ssl_session_cache_openssl.h +++ b/chromium/net/socket/ssl_session_cache_openssl.h @@ -113,6 +113,9 @@ class NET_EXPORT SSLSessionCacheOpenSSL { // Return true iff a cached session was associated with the |ssl| connection. bool SetSSLSessionWithKey(SSL* ssl, const std::string& cache_key); + // Return true iff a cached session was associated with the given |cache_key|. + bool SSLSessionIsInCache(const std::string& cache_key) const; + // Indicates that the SSL session associated with |ssl| is "good" - that is, // that all associated cryptographic parameters that were negotiated, // including the peer's certificate, were successfully validated. Because diff --git a/chromium/net/socket/ssl_session_cache_openssl_unittest.cc b/chromium/net/socket/ssl_session_cache_openssl_unittest.cc index 22c4fbaeb9c..78bac63ccb2 100644 --- a/chromium/net/socket/ssl_session_cache_openssl_unittest.cc +++ b/chromium/net/socket/ssl_session_cache_openssl_unittest.cc @@ -10,6 +10,7 @@ #include "base/logging.h" #include "base/strings/stringprintf.h" #include "crypto/openssl_util.h" +#include "crypto/scoped_openssl_types.h" #include "testing/gtest/include/gtest/gtest.h" @@ -19,18 +20,19 @@ // |s| is the target SSL connection handle. // |session| is non-0 to ask for the creation of a new session. If 0, // this will set an empty session with no ID instead. -extern "C" int ssl_get_new_session(SSL* s, int session); +extern "C" OPENSSL_EXPORT int ssl_get_new_session(SSL* s, int session); // This is an internal OpenSSL function which is used internally to add // a new session to the cache. It is normally triggered by a succesful // connection. However, this unit test does not use the network at all. -extern "C" void ssl_update_cache(SSL* s, int mode); +extern "C" OPENSSL_EXPORT void ssl_update_cache(SSL* s, int mode); namespace net { namespace { -typedef crypto::ScopedOpenSSL<SSL, SSL_free> ScopedSSL; +typedef crypto::ScopedOpenSSL<SSL, SSL_free>::Type ScopedSSL; +typedef crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free>::Type ScopedSSL_CTX; // Helper class used to associate arbitrary std::string keys with SSL objects. class SSLKeyHelper { @@ -73,8 +75,8 @@ class SSLKeyHelper { // Called when an SSL object is copied through SSL_dup(). This needs to copy // the value as well. static int KeyDup(CRYPTO_EX_DATA* to, - CRYPTO_EX_DATA* from, - void* from_fd, + const CRYPTO_EX_DATA* from, + void** from_fd, int idx, long argl, void* argp) { @@ -142,7 +144,7 @@ class SSLSessionCacheOpenSSLTest : public testing::Test { static const SSLSessionCacheOpenSSL::Config kDefaultConfig; protected: - crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free> ctx_; + ScopedSSL_CTX ctx_; // |cache_| must be destroyed before |ctx_| and thus appears after it. SSLSessionCacheOpenSSL cache_; }; diff --git a/chromium/net/socket/ssl_socket.h b/chromium/net/socket/ssl_socket.h index 68d1e4a2bfe..0dc817becdd 100644 --- a/chromium/net/socket/ssl_socket.h +++ b/chromium/net/socket/ssl_socket.h @@ -15,7 +15,7 @@ namespace net { // and server SSL sockets. class NET_EXPORT SSLSocket : public StreamSocket { public: - virtual ~SSLSocket() {} + ~SSLSocket() override {} // Exports data derived from the SSL master-secret (see RFC 5705). // If |has_context| is false, uses the no-context construction from the diff --git a/chromium/net/socket/stream_listen_socket.cc b/chromium/net/socket/stream_listen_socket.cc index fd164a556d5..abb5fbc6b52 100644 --- a/chromium/net/socket/stream_listen_socket.cc +++ b/chromium/net/socket/stream_listen_socket.cc @@ -21,6 +21,7 @@ #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/posix/eintr_wrapper.h" +#include "base/profiler/scoped_tracker.h" #include "base/sys_byteorder.h" #include "base/threading/platform_thread.h" #include "build/build_config.h" @@ -246,6 +247,10 @@ void StreamListenSocket::UnwatchSocket() { #if defined(OS_WIN) // MessageLoop watcher callback. void StreamListenSocket::OnObjectSignaled(HANDLE object) { + // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION("StreamListenSocket_OnObjectSignaled")); + WSANETWORKEVENTS ev; if (kSocketError == WSAEnumNetworkEvents(socket_, socket_event_, &ev)) { // TODO diff --git a/chromium/net/socket/stream_listen_socket.h b/chromium/net/socket/stream_listen_socket.h index 813d96a22be..f8f9419484d 100644 --- a/chromium/net/socket/stream_listen_socket.h +++ b/chromium/net/socket/stream_listen_socket.h @@ -47,7 +47,7 @@ class NET_EXPORT StreamListenSocket #endif public: - virtual ~StreamListenSocket(); + ~StreamListenSocket() override; // TODO(erikkay): this delegate should really be split into two parts // to split up the listener from the connected socket. Perhaps this class @@ -116,8 +116,8 @@ class NET_EXPORT StreamListenSocket HANDLE socket_event_; #elif defined(OS_POSIX) // Called by MessagePumpLibevent when the socket is ready to do I/O. - virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE; - virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE; + void OnFileCanReadWithoutBlocking(int fd) override; + void OnFileCanWriteWithoutBlocking(int fd) override; WaitState wait_state_; // The socket's libevent wrapper. base::MessageLoopForIO::FileDescriptorWatcher watcher_; diff --git a/chromium/net/socket/stream_socket.h b/chromium/net/socket/stream_socket.h index 72088103a3d..b41fed8b51b 100644 --- a/chromium/net/socket/stream_socket.h +++ b/chromium/net/socket/stream_socket.h @@ -17,7 +17,7 @@ class SSLInfo; class NET_EXPORT_PRIVATE StreamSocket : public Socket { public: - virtual ~StreamSocket() {} + ~StreamSocket() override {} // Called to establish a connection. Returns OK if the connection could be // established synchronously. Otherwise, ERR_IO_PENDING is returned and the @@ -75,10 +75,15 @@ class NET_EXPORT_PRIVATE StreamSocket : public Socket { // Write() methods had been called, not the underlying transport's. virtual bool WasEverUsed() const = 0; + // TODO(jri): Clean up -- remove this method. // Returns true if the underlying transport socket is using TCP FastOpen. // TCP FastOpen is an experiment with sending data in the TCP SYN packet. virtual bool UsingTCPFastOpen() const = 0; + // TODO(jri): Clean up -- rename to a more general EnableAutoConnectOnWrite. + // Enables use of TCP FastOpen for the underlying transport socket. + virtual void EnableTCPFastOpenIfSupported() {} + // Returns true if NPN was negotiated during the connection of this socket. virtual bool WasNpnNegotiated() const = 0; diff --git a/chromium/net/socket/tcp_client_socket.cc b/chromium/net/socket/tcp_client_socket.cc index 53d8d1fe259..dcf124bfd6b 100644 --- a/chromium/net/socket/tcp_client_socket.cc +++ b/chromium/net/socket/tcp_client_socket.cc @@ -6,6 +6,7 @@ #include "base/callback_helpers.h" #include "base/logging.h" +#include "base/profiler/scoped_tracker.h" #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" @@ -225,6 +226,10 @@ bool TCPClientSocket::UsingTCPFastOpen() const { return socket_->UsingTCPFastOpen(); } +void TCPClientSocket::EnableTCPFastOpenIfSupported() { + socket_->EnableTCPFastOpenIfSupported(); +} + bool TCPClientSocket::WasNpnNegotiated() const { return false; } @@ -302,6 +307,10 @@ void TCPClientSocket::DidCompleteReadWrite(const CompletionCallback& callback, if (result > 0) use_history_.set_was_used_to_convey_data(); + // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "TCPClientSocket::DidCompleteReadWrite")); callback.Run(result); } diff --git a/chromium/net/socket/tcp_client_socket.h b/chromium/net/socket/tcp_client_socket.h index 970da2a026f..0deec2a0c9f 100644 --- a/chromium/net/socket/tcp_client_socket.h +++ b/chromium/net/socket/tcp_client_socket.h @@ -32,36 +32,39 @@ class NET_EXPORT TCPClientSocket : public StreamSocket { TCPClientSocket(scoped_ptr<TCPSocket> connected_socket, const IPEndPoint& peer_address); - virtual ~TCPClientSocket(); + ~TCPClientSocket() override; // Binds the socket to a local IP address and port. int Bind(const IPEndPoint& address); // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE; - virtual void Disconnect() OVERRIDE; - virtual bool IsConnected() const OVERRIDE; - virtual bool IsConnectedAndIdle() const OVERRIDE; - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; - virtual const BoundNetLog& NetLog() const OVERRIDE; - virtual void SetSubresourceSpeculation() OVERRIDE; - virtual void SetOmniboxSpeculation() OVERRIDE; - virtual bool WasEverUsed() const OVERRIDE; - virtual bool UsingTCPFastOpen() const OVERRIDE; - virtual bool WasNpnNegotiated() const OVERRIDE; - virtual NextProto GetNegotiatedProtocol() const OVERRIDE; - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + int Connect(const CompletionCallback& callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + const BoundNetLog& NetLog() const override; + void SetSubresourceSpeculation() override; + void SetOmniboxSpeculation() override; + bool WasEverUsed() const override; + bool UsingTCPFastOpen() const override; + void EnableTCPFastOpenIfSupported() override; + bool WasNpnNegotiated() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; // Socket implementation. // Multiple outstanding requests are not supported. // Full duplex mode (reading and writing at the same time) is supported. - virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int SetReceiveBufferSize(int32 size) OVERRIDE; - virtual int SetSendBufferSize(int32 size) OVERRIDE; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int SetReceiveBufferSize(int32 size) override; + int SetSendBufferSize(int32 size) override; virtual bool SetKeepAlive(bool enable, int delay); virtual bool SetNoDelay(bool no_delay); diff --git a/chromium/net/socket/tcp_listen_socket.cc b/chromium/net/socket/tcp_listen_socket.cc index 223abee2cba..585c41292de 100644 --- a/chromium/net/socket/tcp_listen_socket.cc +++ b/chromium/net/socket/tcp_listen_socket.cc @@ -107,7 +107,7 @@ void TCPListenSocket::Accept() { #if defined(OS_POSIX) sock->WatchSocket(WAITING_READ); #endif - socket_delegate_->DidAccept(this, sock.PassAs<StreamListenSocket>()); + socket_delegate_->DidAccept(this, sock.Pass()); } TCPListenSocketFactory::TCPListenSocketFactory(const string& ip, int port) @@ -119,8 +119,7 @@ TCPListenSocketFactory::~TCPListenSocketFactory() {} scoped_ptr<StreamListenSocket> TCPListenSocketFactory::CreateAndListen( StreamListenSocket::Delegate* delegate) const { - return TCPListenSocket::CreateAndListen(ip_, port_, delegate) - .PassAs<StreamListenSocket>(); + return TCPListenSocket::CreateAndListen(ip_, port_, delegate); } } // namespace net diff --git a/chromium/net/socket/tcp_listen_socket.h b/chromium/net/socket/tcp_listen_socket.h index 54a91de59bb..1702e50e8ed 100644 --- a/chromium/net/socket/tcp_listen_socket.h +++ b/chromium/net/socket/tcp_listen_socket.h @@ -17,7 +17,7 @@ namespace net { // Implements a TCP socket. class NET_EXPORT TCPListenSocket : public StreamListenSocket { public: - virtual ~TCPListenSocket(); + ~TCPListenSocket() override; // Listen on port for the specified IP address. Use 127.0.0.1 to only // accept local connections. static scoped_ptr<TCPListenSocket> CreateAndListen( @@ -34,7 +34,7 @@ class NET_EXPORT TCPListenSocket : public StreamListenSocket { TCPListenSocket(SocketDescriptor s, StreamListenSocket::Delegate* del); // Implements StreamListenSocket::Accept. - virtual void Accept() OVERRIDE; + void Accept() override; private: DISALLOW_COPY_AND_ASSIGN(TCPListenSocket); @@ -44,11 +44,11 @@ class NET_EXPORT TCPListenSocket : public StreamListenSocket { class NET_EXPORT TCPListenSocketFactory : public StreamListenSocketFactory { public: TCPListenSocketFactory(const std::string& ip, int port); - virtual ~TCPListenSocketFactory(); + ~TCPListenSocketFactory() override; // StreamListenSocketFactory overrides. - virtual scoped_ptr<StreamListenSocket> CreateAndListen( - StreamListenSocket::Delegate* delegate) const OVERRIDE; + scoped_ptr<StreamListenSocket> CreateAndListen( + StreamListenSocket::Delegate* delegate) const override; private: const std::string ip_; diff --git a/chromium/net/socket/tcp_listen_socket_unittest.cc b/chromium/net/socket/tcp_listen_socket_unittest.cc index 41c41f81fe7..b58642643ca 100644 --- a/chromium/net/socket/tcp_listen_socket_unittest.cc +++ b/chromium/net/socket/tcp_listen_socket_unittest.cc @@ -275,13 +275,13 @@ class TCPListenSocketTest : public PlatformTest { tester_ = NULL; } - virtual void SetUp() { + void SetUp() override { PlatformTest::SetUp(); tester_ = new TCPListenSocketTester(); tester_->SetUp(); } - virtual void TearDown() { + void TearDown() override { PlatformTest::TearDown(); tester_->TearDown(); tester_ = NULL; diff --git a/chromium/net/socket/tcp_listen_socket_unittest.h b/chromium/net/socket/tcp_listen_socket_unittest.h index 1bc31a8d1ce..984442afdc0 100644 --- a/chromium/net/socket/tcp_listen_socket_unittest.h +++ b/chromium/net/socket/tcp_listen_socket_unittest.h @@ -90,11 +90,12 @@ class TCPListenSocketTester : virtual bool Send(SocketDescriptor sock, const std::string& str); // StreamListenSocket::Delegate: - virtual void DidAccept(StreamListenSocket* server, - scoped_ptr<StreamListenSocket> connection) OVERRIDE; - virtual void DidRead(StreamListenSocket* connection, const char* data, - int len) OVERRIDE; - virtual void DidClose(StreamListenSocket* sock) OVERRIDE; + void DidAccept(StreamListenSocket* server, + scoped_ptr<StreamListenSocket> connection) override; + void DidRead(StreamListenSocket* connection, + const char* data, + int len) override; + void DidClose(StreamListenSocket* sock) override; scoped_ptr<base::Thread> thread_; base::MessageLoopForIO* loop_; @@ -111,7 +112,7 @@ class TCPListenSocketTester : private: friend class base::RefCountedThreadSafe<TCPListenSocketTester>; - virtual ~TCPListenSocketTester(); + ~TCPListenSocketTester() override; virtual scoped_ptr<TCPListenSocket> DoListen(); diff --git a/chromium/net/socket/tcp_server_socket.h b/chromium/net/socket/tcp_server_socket.h index faff9ad826a..a3919e6845a 100644 --- a/chromium/net/socket/tcp_server_socket.h +++ b/chromium/net/socket/tcp_server_socket.h @@ -19,13 +19,13 @@ namespace net { class NET_EXPORT_PRIVATE TCPServerSocket : public ServerSocket { public: TCPServerSocket(NetLog* net_log, const NetLog::Source& source); - virtual ~TCPServerSocket(); + ~TCPServerSocket() override; // net::ServerSocket implementation. - virtual int Listen(const IPEndPoint& address, int backlog) OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; - virtual int Accept(scoped_ptr<StreamSocket>* socket, - const CompletionCallback& callback) OVERRIDE; + int Listen(const IPEndPoint& address, int backlog) override; + int GetLocalAddress(IPEndPoint* address) const override; + int Accept(scoped_ptr<StreamSocket>* socket, + const CompletionCallback& callback) override; private: // Converts |accepted_socket_| and stores the result in diff --git a/chromium/net/socket/tcp_server_socket_unittest.cc b/chromium/net/socket/tcp_server_socket_unittest.cc index fd81e550d08..01bae9ff188 100644 --- a/chromium/net/socket/tcp_server_socket_unittest.cc +++ b/chromium/net/socket/tcp_server_socket_unittest.cc @@ -215,8 +215,8 @@ TEST_F(TCPServerSocketTest, AcceptIO) { size_t bytes_written = 0; while (bytes_written < message.size()) { - scoped_refptr<net::IOBufferWithSize> write_buffer( - new net::IOBufferWithSize(message.size() - bytes_written)); + scoped_refptr<IOBufferWithSize> write_buffer( + new IOBufferWithSize(message.size() - bytes_written)); memmove(write_buffer->data(), message.data(), message.size()); TestCompletionCallback write_callback; @@ -230,8 +230,8 @@ TEST_F(TCPServerSocketTest, AcceptIO) { size_t bytes_read = 0; while (bytes_read < message.size()) { - scoped_refptr<net::IOBufferWithSize> read_buffer( - new net::IOBufferWithSize(message.size() - bytes_read)); + scoped_refptr<IOBufferWithSize> read_buffer( + new IOBufferWithSize(message.size() - bytes_read)); TestCompletionCallback read_callback; int read_result = connecting_socket.Read( read_buffer.get(), read_buffer->size(), read_callback.callback()); diff --git a/chromium/net/socket/tcp_socket.cc b/chromium/net/socket/tcp_socket.cc deleted file mode 100644 index 63703627134..00000000000 --- a/chromium/net/socket/tcp_socket.cc +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/socket/tcp_socket.h" - -#include "base/file_util.h" -#include "base/files/file_path.h" -#include "base/memory/ref_counted.h" -#include "base/threading/worker_pool.h" - -namespace net { - -namespace { - -bool g_tcp_fastopen_enabled = false; - -#if defined(OS_LINUX) || defined(OS_ANDROID) - -typedef base::RefCountedData<bool> SharedBoolean; - -// Checks to see if the system supports TCP FastOpen. Notably, it requires -// kernel support. Additionally, this checks system configuration to ensure that -// it's enabled. -void SystemSupportsTCPFastOpen(scoped_refptr<SharedBoolean> supported) { - supported->data = false; - static const base::FilePath::CharType kTCPFastOpenProcFilePath[] = - "/proc/sys/net/ipv4/tcp_fastopen"; - std::string system_enabled_tcp_fastopen; - if (!base::ReadFileToString(base::FilePath(kTCPFastOpenProcFilePath), - &system_enabled_tcp_fastopen)) { - return; - } - - // As per http://lxr.linux.no/linux+v3.7.7/include/net/tcp.h#L225 - // TFO_CLIENT_ENABLE is the LSB - if (system_enabled_tcp_fastopen.empty() || - (system_enabled_tcp_fastopen[0] & 0x1) == 0) { - return; - } - - supported->data = true; -} - -void EnableCallback(scoped_refptr<SharedBoolean> supported) { - g_tcp_fastopen_enabled = supported->data; -} - -// This is asynchronous because it needs to do file IO, and it isn't allowed to -// do that on the IO thread. -void EnableFastOpenIfSupported() { - scoped_refptr<SharedBoolean> supported = new SharedBoolean; - base::WorkerPool::PostTaskAndReply( - FROM_HERE, - base::Bind(SystemSupportsTCPFastOpen, supported), - base::Bind(EnableCallback, supported), - false); -} - -#else - -void EnableFastOpenIfSupported() { - g_tcp_fastopen_enabled = false; -} - -#endif - -} // namespace - -void SetTCPFastOpenEnabled(bool value) { - if (value) { - EnableFastOpenIfSupported(); - } else { - g_tcp_fastopen_enabled = false; - } -} - -bool IsTCPFastOpenEnabled() { - return g_tcp_fastopen_enabled; -} - -} // namespace net diff --git a/chromium/net/socket/tcp_socket.h b/chromium/net/socket/tcp_socket.h index 8b36fade758..04fd7d2ba6f 100644 --- a/chromium/net/socket/tcp_socket.h +++ b/chromium/net/socket/tcp_socket.h @@ -16,13 +16,6 @@ namespace net { -// Enable/disable experimental TCP FastOpen option. -// Not thread safe. Must be called during initialization/startup only. -NET_EXPORT void SetTCPFastOpenEnabled(bool value); - -// Check if the TCP FastOpen option is enabled. -bool IsTCPFastOpenEnabled(); - // TCPSocket provides a platform-independent interface for TCP sockets. // // It is recommended to use TCPClientSocket/TCPServerSocket instead of this @@ -35,6 +28,17 @@ typedef TCPSocketWin TCPSocket; typedef TCPSocketLibevent TCPSocket; #endif +// Check if TCP FastOpen is supported by the OS. +bool IsTCPFastOpenSupported(); + +// Check if TCP FastOpen is enabled by the user. +bool IsTCPFastOpenUserEnabled(); + +// Checks if TCP FastOpen is supported by the kernel. Also enables TFO for all +// connections if indicated by user. +// Not thread safe. Must be called during initialization/startup only. +NET_EXPORT void CheckSupportAndMaybeEnableTCPFastOpen(bool user_enabled); + } // namespace net #endif // NET_SOCKET_TCP_SOCKET_H_ diff --git a/chromium/net/socket/tcp_socket_libevent.cc b/chromium/net/socket/tcp_socket_libevent.cc index 72ae5809e10..cc2376590f5 100644 --- a/chromium/net/socket/tcp_socket_libevent.cc +++ b/chromium/net/socket/tcp_socket_libevent.cc @@ -5,18 +5,16 @@ #include "net/socket/tcp_socket.h" #include <errno.h> -#include <fcntl.h> -#include <netdb.h> -#include <netinet/in.h> #include <netinet/tcp.h> #include <sys/socket.h> -#include "base/callback_helpers.h" +#include "base/bind.h" #include "base/logging.h" #include "base/metrics/histogram.h" #include "base/metrics/stats_counters.h" #include "base/posix/eintr_wrapper.h" -#include "build/build_config.h" +#include "base/task_runner_util.h" +#include "base/threading/worker_pool.h" #include "net/base/address_list.h" #include "net/base/connection_type_histograms.h" #include "net/base/io_buffer.h" @@ -24,6 +22,7 @@ #include "net/base/net_errors.h" #include "net/base/net_util.h" #include "net/base/network_change_notifier.h" +#include "net/socket/socket_libevent.h" #include "net/socket/socket_net_log_params.h" // If we don't have a definition for TCPI_OPT_SYN_DATA, create one. @@ -35,6 +34,14 @@ namespace net { namespace { +// True if OS supports TCP FastOpen. +bool g_tcp_fastopen_supported = false; +// True if TCP FastOpen is user-enabled for all connections. +// TODO(jri): Change global variable to param in HttpNetworkSession::Params. +bool g_tcp_fastopen_user_enabled = false; +// True if TCP FastOpen connect-with-write has failed at least once. +bool g_tcp_fastopen_has_failed = false; + // SetTCPNoDelay turns on/off buffering in the kernel. By default, TCP sockets // will wait up to 200ms for more data to complete a packet before transmitting. // After calling this function, the kernel will not wait. See TCP_NODELAY in @@ -72,90 +79,61 @@ bool SetTCPKeepAlive(int fd, bool enable, int delay) { return true; } -int MapAcceptError(int os_error) { - switch (os_error) { - // If the client aborts the connection before the server calls accept, - // POSIX specifies accept should fail with ECONNABORTED. The server can - // ignore the error and just call accept again, so we map the error to - // ERR_IO_PENDING. See UNIX Network Programming, Vol. 1, 3rd Ed., Sec. - // 5.11, "Connection Abort before accept Returns". - case ECONNABORTED: - return ERR_IO_PENDING; - default: - return MapSystemError(os_error); +#if defined(OS_LINUX) || defined(OS_ANDROID) +// Checks if the kernel supports TCP FastOpen. +bool SystemSupportsTCPFastOpen() { + const base::FilePath::CharType kTCPFastOpenProcFilePath[] = + "/proc/sys/net/ipv4/tcp_fastopen"; + std::string system_supports_tcp_fastopen; + if (!base::ReadFileToString(base::FilePath(kTCPFastOpenProcFilePath), + &system_supports_tcp_fastopen)) { + return false; } + // The read from /proc should return '1' if TCP FastOpen is enabled in the OS. + if (system_supports_tcp_fastopen.empty() || + (system_supports_tcp_fastopen[0] != '1')) { + return false; + } + return true; } -int MapConnectError(int os_error) { - switch (os_error) { - case EACCES: - return ERR_NETWORK_ACCESS_DENIED; - case ETIMEDOUT: - return ERR_CONNECTION_TIMED_OUT; - default: { - int net_error = MapSystemError(os_error); - if (net_error == ERR_FAILED) - return ERR_CONNECTION_FAILED; // More specific than ERR_FAILED. - - // Give a more specific error when the user is offline. - if (net_error == ERR_ADDRESS_UNREACHABLE && - NetworkChangeNotifier::IsOffline()) { - return ERR_INTERNET_DISCONNECTED; - } - return net_error; - } - } +void RegisterTCPFastOpenIntentAndSupport(bool user_enabled, + bool system_supported) { + g_tcp_fastopen_supported = system_supported; + g_tcp_fastopen_user_enabled = user_enabled; } +#endif } // namespace //----------------------------------------------------------------------------- -TCPSocketLibevent::Watcher::Watcher( - const base::Closure& read_ready_callback, - const base::Closure& write_ready_callback) - : read_ready_callback_(read_ready_callback), - write_ready_callback_(write_ready_callback) { +bool IsTCPFastOpenSupported() { + return g_tcp_fastopen_supported; } -TCPSocketLibevent::Watcher::~Watcher() { +bool IsTCPFastOpenUserEnabled() { + return g_tcp_fastopen_user_enabled; } -void TCPSocketLibevent::Watcher::OnFileCanReadWithoutBlocking(int /* fd */) { - if (!read_ready_callback_.is_null()) - read_ready_callback_.Run(); - else - NOTREACHED(); -} - -void TCPSocketLibevent::Watcher::OnFileCanWriteWithoutBlocking(int /* fd */) { - if (!write_ready_callback_.is_null()) - write_ready_callback_.Run(); - else - NOTREACHED(); +// This is asynchronous because it needs to do file IO, and it isn't allowed to +// do that on the IO thread. +void CheckSupportAndMaybeEnableTCPFastOpen(bool user_enabled) { +#if defined(OS_LINUX) || defined(OS_ANDROID) + base::PostTaskAndReplyWithResult( + base::WorkerPool::GetTaskRunner(/*task_is_slow=*/false).get(), + FROM_HERE, + base::Bind(SystemSupportsTCPFastOpen), + base::Bind(RegisterTCPFastOpenIntentAndSupport, user_enabled)); +#endif } TCPSocketLibevent::TCPSocketLibevent(NetLog* net_log, const NetLog::Source& source) - : socket_(kInvalidSocket), - accept_watcher_(base::Bind(&TCPSocketLibevent::DidCompleteAccept, - base::Unretained(this)), - base::Closure()), - accept_socket_(NULL), - accept_address_(NULL), - read_watcher_(base::Bind(&TCPSocketLibevent::DidCompleteRead, - base::Unretained(this)), - base::Closure()), - write_watcher_(base::Closure(), - base::Bind(&TCPSocketLibevent::DidCompleteConnectOrWrite, - base::Unretained(this))), - read_buf_len_(0), - write_buf_len_(0), - use_tcp_fastopen_(IsTCPFastOpenEnabled()), + : use_tcp_fastopen_(false), + tcp_fastopen_write_attempted_(false), tcp_fastopen_connected_(false), - fast_open_status_(FAST_OPEN_STATUS_UNKNOWN), - waiting_connect_(false), - connect_os_error_(0), + tcp_fastopen_status_(TCP_FASTOPEN_STATUS_UNKNOWN), logging_multiple_connect_attempts_(false), net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, @@ -164,274 +142,173 @@ TCPSocketLibevent::TCPSocketLibevent(NetLog* net_log, TCPSocketLibevent::~TCPSocketLibevent() { net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE); - if (tcp_fastopen_connected_) { - UMA_HISTOGRAM_ENUMERATION("Net.TcpFastOpenSocketConnection", - fast_open_status_, FAST_OPEN_MAX_VALUE); - } Close(); } int TCPSocketLibevent::Open(AddressFamily family) { - DCHECK(CalledOnValidThread()); - DCHECK_EQ(socket_, kInvalidSocket); - - socket_ = CreatePlatformSocket(ConvertAddressFamily(family), SOCK_STREAM, - IPPROTO_TCP); - if (socket_ < 0) { - PLOG(ERROR) << "CreatePlatformSocket() returned an error"; - return MapSystemError(errno); - } - - if (SetNonBlocking(socket_)) { - int result = MapSystemError(errno); - Close(); - return result; - } - - return OK; + DCHECK(!socket_); + socket_.reset(new SocketLibevent); + int rv = socket_->Open(ConvertAddressFamily(family)); + if (rv != OK) + socket_.reset(); + return rv; } -int TCPSocketLibevent::AdoptConnectedSocket(int socket, +int TCPSocketLibevent::AdoptConnectedSocket(int socket_fd, const IPEndPoint& peer_address) { - DCHECK(CalledOnValidThread()); - DCHECK_EQ(socket_, kInvalidSocket); + DCHECK(!socket_); - socket_ = socket; - - if (SetNonBlocking(socket_)) { - int result = MapSystemError(errno); - Close(); - return result; + SockaddrStorage storage; + if (!peer_address.ToSockAddr(storage.addr, &storage.addr_len) && + // For backward compatibility, allows the empty address. + !(peer_address == IPEndPoint())) { + return ERR_ADDRESS_INVALID; } - peer_address_.reset(new IPEndPoint(peer_address)); - - return OK; + socket_.reset(new SocketLibevent); + int rv = socket_->AdoptConnectedSocket(socket_fd, storage); + if (rv != OK) + socket_.reset(); + return rv; } int TCPSocketLibevent::Bind(const IPEndPoint& address) { - DCHECK(CalledOnValidThread()); - DCHECK_NE(socket_, kInvalidSocket); + DCHECK(socket_); SockaddrStorage storage; if (!address.ToSockAddr(storage.addr, &storage.addr_len)) return ERR_ADDRESS_INVALID; - int result = bind(socket_, storage.addr, storage.addr_len); - if (result < 0) { - PLOG(ERROR) << "bind() returned an error"; - return MapSystemError(errno); - } - - return OK; + return socket_->Bind(storage); } int TCPSocketLibevent::Listen(int backlog) { - DCHECK(CalledOnValidThread()); - DCHECK_GT(backlog, 0); - DCHECK_NE(socket_, kInvalidSocket); - - int result = listen(socket_, backlog); - if (result < 0) { - PLOG(ERROR) << "listen() returned an error"; - return MapSystemError(errno); - } - - return OK; + DCHECK(socket_); + return socket_->Listen(backlog); } -int TCPSocketLibevent::Accept(scoped_ptr<TCPSocketLibevent>* socket, +int TCPSocketLibevent::Accept(scoped_ptr<TCPSocketLibevent>* tcp_socket, IPEndPoint* address, const CompletionCallback& callback) { - DCHECK(CalledOnValidThread()); - DCHECK(socket); - DCHECK(address); + DCHECK(tcp_socket); DCHECK(!callback.is_null()); - DCHECK(accept_callback_.is_null()); + DCHECK(socket_); + DCHECK(!accept_socket_); net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT); - int result = AcceptInternal(socket, address); - - if (result == ERR_IO_PENDING) { - if (!base::MessageLoopForIO::current()->WatchFileDescriptor( - socket_, true, base::MessageLoopForIO::WATCH_READ, - &accept_socket_watcher_, &accept_watcher_)) { - PLOG(ERROR) << "WatchFileDescriptor failed on read"; - return MapSystemError(errno); - } - - accept_socket_ = socket; - accept_address_ = address; - accept_callback_ = callback; - } - - return result; + int rv = socket_->Accept( + &accept_socket_, + base::Bind(&TCPSocketLibevent::AcceptCompleted, + base::Unretained(this), tcp_socket, address, callback)); + if (rv != ERR_IO_PENDING) + rv = HandleAcceptCompleted(tcp_socket, address, rv); + return rv; } int TCPSocketLibevent::Connect(const IPEndPoint& address, const CompletionCallback& callback) { - DCHECK(CalledOnValidThread()); - DCHECK_NE(socket_, kInvalidSocket); - DCHECK(!waiting_connect_); - - // |peer_address_| will be non-NULL if Connect() has been called. Unless - // Close() is called to reset the internal state, a second call to Connect() - // is not allowed. - // Please note that we don't allow a second Connect() even if the previous - // Connect() has failed. Connecting the same |socket_| again after a - // connection attempt failed results in unspecified behavior according to - // POSIX. - DCHECK(!peer_address_); + DCHECK(socket_); if (!logging_multiple_connect_attempts_) LogConnectBegin(AddressList(address)); - peer_address_.reset(new IPEndPoint(address)); + net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, + CreateNetLogIPEndPointCallback(&address)); - int rv = DoConnect(); - if (rv == ERR_IO_PENDING) { - // Synchronous operation not supported. - DCHECK(!callback.is_null()); - write_callback_ = callback; - waiting_connect_ = true; - } else { - DoConnectComplete(rv); + SockaddrStorage storage; + if (!address.ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_ADDRESS_INVALID; + + if (use_tcp_fastopen_) { + // With TCP FastOpen, we pretend that the socket is connected. + DCHECK(!tcp_fastopen_write_attempted_); + socket_->SetPeerAddress(storage); + return OK; } + int rv = socket_->Connect(storage, + base::Bind(&TCPSocketLibevent::ConnectCompleted, + base::Unretained(this), callback)); + if (rv != ERR_IO_PENDING) + rv = HandleConnectCompleted(rv); return rv; } bool TCPSocketLibevent::IsConnected() const { - DCHECK(CalledOnValidThread()); - - if (socket_ == kInvalidSocket || waiting_connect_) + if (!socket_) return false; - if (use_tcp_fastopen_ && !tcp_fastopen_connected_ && peer_address_) { + if (use_tcp_fastopen_ && !tcp_fastopen_write_attempted_ && + socket_->HasPeerAddress()) { // With TCP FastOpen, we pretend that the socket is connected. // This allows GetPeerAddress() to return peer_address_. return true; } - // Check if connection is alive. - char c; - int rv = HANDLE_EINTR(recv(socket_, &c, 1, MSG_PEEK)); - if (rv == 0) - return false; - if (rv == -1 && errno != EAGAIN && errno != EWOULDBLOCK) - return false; - - return true; + return socket_->IsConnected(); } bool TCPSocketLibevent::IsConnectedAndIdle() const { - DCHECK(CalledOnValidThread()); - - if (socket_ == kInvalidSocket || waiting_connect_) - return false; - // TODO(wtc): should we also handle the TCP FastOpen case here, // as we do in IsConnected()? - - // Check if connection is alive and we haven't received any data - // unexpectedly. - char c; - int rv = HANDLE_EINTR(recv(socket_, &c, 1, MSG_PEEK)); - if (rv >= 0) - return false; - if (errno != EAGAIN && errno != EWOULDBLOCK) - return false; - - return true; + return socket_ && socket_->IsConnectedAndIdle(); } int TCPSocketLibevent::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { - DCHECK(CalledOnValidThread()); - DCHECK_NE(kInvalidSocket, socket_); - DCHECK(!waiting_connect_); - DCHECK(read_callback_.is_null()); - // Synchronous operation not supported + DCHECK(socket_); DCHECK(!callback.is_null()); - DCHECK_GT(buf_len, 0); - - int nread = HANDLE_EINTR(read(socket_, buf->data(), buf_len)); - if (nread >= 0) { - base::StatsCounter read_bytes("tcp.read_bytes"); - read_bytes.Add(nread); - net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, nread, - buf->data()); - RecordFastOpenStatus(); - return nread; - } - if (errno != EAGAIN && errno != EWOULDBLOCK) { - int net_error = MapSystemError(errno); - net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR, - CreateNetLogSocketErrorCallback(net_error, errno)); - return net_error; - } - if (!base::MessageLoopForIO::current()->WatchFileDescriptor( - socket_, true, base::MessageLoopForIO::WATCH_READ, - &read_socket_watcher_, &read_watcher_)) { - DVLOG(1) << "WatchFileDescriptor failed on read, errno " << errno; - return MapSystemError(errno); - } - - read_buf_ = buf; - read_buf_len_ = buf_len; - read_callback_ = callback; - return ERR_IO_PENDING; + int rv = socket_->Read( + buf, buf_len, + base::Bind(&TCPSocketLibevent::ReadCompleted, + // Grab a reference to |buf| so that ReadCompleted() can still + // use it when Read() completes, as otherwise, this transfers + // ownership of buf to socket. + base::Unretained(this), make_scoped_refptr(buf), callback)); + if (rv != ERR_IO_PENDING) + rv = HandleReadCompleted(buf, rv); + return rv; } int TCPSocketLibevent::Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { - DCHECK(CalledOnValidThread()); - DCHECK_NE(kInvalidSocket, socket_); - DCHECK(!waiting_connect_); - DCHECK(write_callback_.is_null()); - // Synchronous operation not supported + DCHECK(socket_); DCHECK(!callback.is_null()); - DCHECK_GT(buf_len, 0); - - int nwrite = InternalWrite(buf, buf_len); - if (nwrite >= 0) { - base::StatsCounter write_bytes("tcp.write_bytes"); - write_bytes.Add(nwrite); - net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, nwrite, - buf->data()); - return nwrite; - } - if (errno != EAGAIN && errno != EWOULDBLOCK) { - int net_error = MapSystemError(errno); - net_log_.AddEvent(NetLog::TYPE_SOCKET_WRITE_ERROR, - CreateNetLogSocketErrorCallback(net_error, errno)); - return net_error; - } - if (!base::MessageLoopForIO::current()->WatchFileDescriptor( - socket_, true, base::MessageLoopForIO::WATCH_WRITE, - &write_socket_watcher_, &write_watcher_)) { - DVLOG(1) << "WatchFileDescriptor failed on write, errno " << errno; - return MapSystemError(errno); + CompletionCallback write_callback = + base::Bind(&TCPSocketLibevent::WriteCompleted, + // Grab a reference to |buf| so that WriteCompleted() can still + // use it when Write() completes, as otherwise, this transfers + // ownership of buf to socket. + base::Unretained(this), make_scoped_refptr(buf), callback); + int rv; + + if (use_tcp_fastopen_ && !tcp_fastopen_write_attempted_) { + rv = TcpFastOpenWrite(buf, buf_len, write_callback); + } else { + rv = socket_->Write(buf, buf_len, write_callback); } - write_buf_ = buf; - write_buf_len_ = buf_len; - write_callback_ = callback; - return ERR_IO_PENDING; + if (rv != ERR_IO_PENDING) + rv = HandleWriteCompleted(buf, rv); + return rv; } int TCPSocketLibevent::GetLocalAddress(IPEndPoint* address) const { - DCHECK(CalledOnValidThread()); DCHECK(address); + if (!socket_) + return ERR_SOCKET_NOT_CONNECTED; + SockaddrStorage storage; - if (getsockname(socket_, storage.addr, &storage.addr_len) < 0) - return MapSystemError(errno); + int rv = socket_->GetLocalAddress(&storage); + if (rv != OK) + return rv; + if (!address->FromSockAddr(storage.addr, storage.addr_len)) return ERR_ADDRESS_INVALID; @@ -439,25 +316,34 @@ int TCPSocketLibevent::GetLocalAddress(IPEndPoint* address) const { } int TCPSocketLibevent::GetPeerAddress(IPEndPoint* address) const { - DCHECK(CalledOnValidThread()); DCHECK(address); + if (!IsConnected()) return ERR_SOCKET_NOT_CONNECTED; - *address = *peer_address_; + + SockaddrStorage storage; + int rv = socket_->GetPeerAddress(&storage); + if (rv != OK) + return rv; + + if (!address->FromSockAddr(storage.addr, storage.addr_len)) + return ERR_ADDRESS_INVALID; + return OK; } int TCPSocketLibevent::SetDefaultOptionsForServer() { - DCHECK(CalledOnValidThread()); + DCHECK(socket_); return SetAddressReuse(true); } void TCPSocketLibevent::SetDefaultOptionsForClient() { - DCHECK(CalledOnValidThread()); + DCHECK(socket_); // This mirrors the behaviour on Windows. See the comment in // tcp_socket_win.cc after searching for "NODELAY". - SetTCPNoDelay(socket_, true); // If SetTCPNoDelay fails, we don't care. + // If SetTCPNoDelay fails, we don't care. + SetTCPNoDelay(socket_->socket_fd(), true); // TCP keep alive wakes up the radio, which is expensive on mobile. Do not // enable it there. It's useful to prevent TCP middleboxes from timing out @@ -473,12 +359,12 @@ void TCPSocketLibevent::SetDefaultOptionsForClient() { #if !defined(OS_ANDROID) && !defined(OS_IOS) const int kTCPKeepAliveSeconds = 45; - SetTCPKeepAlive(socket_, true, kTCPKeepAliveSeconds); + SetTCPKeepAlive(socket_->socket_fd(), true, kTCPKeepAliveSeconds); #endif } int TCPSocketLibevent::SetAddressReuse(bool allow) { - DCHECK(CalledOnValidThread()); + DCHECK(socket_); // SO_REUSEADDR is useful for server sockets to bind to a recently unbound // port. When a socket is closed, the end point changes its state to TIME_WAIT @@ -494,82 +380,74 @@ int TCPSocketLibevent::SetAddressReuse(bool allow) { // // SO_REUSEPORT is provided in MacOS X and iOS. int boolean_value = allow ? 1 : 0; - int rv = setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &boolean_value, - sizeof(boolean_value)); + int rv = setsockopt(socket_->socket_fd(), SOL_SOCKET, SO_REUSEADDR, + &boolean_value, sizeof(boolean_value)); if (rv < 0) return MapSystemError(errno); return OK; } int TCPSocketLibevent::SetReceiveBufferSize(int32 size) { - DCHECK(CalledOnValidThread()); - int rv = setsockopt(socket_, SOL_SOCKET, SO_RCVBUF, + DCHECK(socket_); + int rv = setsockopt(socket_->socket_fd(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast<const char*>(&size), sizeof(size)); return (rv == 0) ? OK : MapSystemError(errno); } int TCPSocketLibevent::SetSendBufferSize(int32 size) { - DCHECK(CalledOnValidThread()); - int rv = setsockopt(socket_, SOL_SOCKET, SO_SNDBUF, + DCHECK(socket_); + int rv = setsockopt(socket_->socket_fd(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast<const char*>(&size), sizeof(size)); return (rv == 0) ? OK : MapSystemError(errno); } bool TCPSocketLibevent::SetKeepAlive(bool enable, int delay) { - DCHECK(CalledOnValidThread()); - return SetTCPKeepAlive(socket_, enable, delay); + DCHECK(socket_); + return SetTCPKeepAlive(socket_->socket_fd(), enable, delay); } bool TCPSocketLibevent::SetNoDelay(bool no_delay) { - DCHECK(CalledOnValidThread()); - return SetTCPNoDelay(socket_, no_delay); + DCHECK(socket_); + return SetTCPNoDelay(socket_->socket_fd(), no_delay); } void TCPSocketLibevent::Close() { - DCHECK(CalledOnValidThread()); - - bool ok = accept_socket_watcher_.StopWatchingFileDescriptor(); - DCHECK(ok); - ok = read_socket_watcher_.StopWatchingFileDescriptor(); - DCHECK(ok); - ok = write_socket_watcher_.StopWatchingFileDescriptor(); - DCHECK(ok); - - if (socket_ != kInvalidSocket) { - if (IGNORE_EINTR(close(socket_)) < 0) - PLOG(ERROR) << "close"; - socket_ = kInvalidSocket; - } - - if (!accept_callback_.is_null()) { - accept_socket_ = NULL; - accept_address_ = NULL; - accept_callback_.Reset(); - } - - if (!read_callback_.is_null()) { - read_buf_ = NULL; - read_buf_len_ = 0; - read_callback_.Reset(); - } + socket_.reset(); - if (!write_callback_.is_null()) { - write_buf_ = NULL; - write_buf_len_ = 0; - write_callback_.Reset(); + // Record and reset TCP FastOpen state. + if (tcp_fastopen_write_attempted_ || + tcp_fastopen_status_ == TCP_FASTOPEN_PREVIOUSLY_FAILED) { + UMA_HISTOGRAM_ENUMERATION("Net.TcpFastOpenSocketConnection", + tcp_fastopen_status_, TCP_FASTOPEN_MAX_VALUE); } - + use_tcp_fastopen_ = false; tcp_fastopen_connected_ = false; - fast_open_status_ = FAST_OPEN_STATUS_UNKNOWN; - waiting_connect_ = false; - peer_address_.reset(); - connect_os_error_ = 0; + tcp_fastopen_write_attempted_ = false; + tcp_fastopen_status_ = TCP_FASTOPEN_STATUS_UNKNOWN; } bool TCPSocketLibevent::UsingTCPFastOpen() const { return use_tcp_fastopen_; } +void TCPSocketLibevent::EnableTCPFastOpenIfSupported() { + if (!IsTCPFastOpenSupported()) + return; + + // Do not enable TCP FastOpen if it had previously failed. + // This check conservatively avoids middleboxes that may blackhole + // TCP FastOpen SYN+Data packets; on such a failure, subsequent sockets + // should not use TCP FastOpen. + if(!g_tcp_fastopen_has_failed) + use_tcp_fastopen_ = true; + else + tcp_fastopen_status_ = TCP_FASTOPEN_PREVIOUSLY_FAILED; +} + +bool TCPSocketLibevent::IsValid() const { + return socket_ != NULL && socket_->socket_fd() != kInvalidSocket; +} + void TCPSocketLibevent::StartLoggingMultipleConnectAttempts( const AddressList& addresses) { if (!logging_multiple_connect_attempts_) { @@ -589,98 +467,76 @@ void TCPSocketLibevent::EndLoggingMultipleConnectAttempts(int net_error) { } } -int TCPSocketLibevent::AcceptInternal(scoped_ptr<TCPSocketLibevent>* socket, - IPEndPoint* address) { - SockaddrStorage storage; - int new_socket = HANDLE_EINTR(accept(socket_, - storage.addr, - &storage.addr_len)); - if (new_socket < 0) { - int net_error = MapAcceptError(errno); - if (net_error != ERR_IO_PENDING) - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error); - return net_error; - } - - IPEndPoint ip_end_point; - if (!ip_end_point.FromSockAddr(storage.addr, storage.addr_len)) { - NOTREACHED(); - if (IGNORE_EINTR(close(new_socket)) < 0) - PLOG(ERROR) << "close"; - int net_error = ERR_ADDRESS_INVALID; - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error); - return net_error; - } - scoped_ptr<TCPSocketLibevent> tcp_socket(new TCPSocketLibevent( - net_log_.net_log(), net_log_.source())); - int adopt_result = tcp_socket->AdoptConnectedSocket(new_socket, ip_end_point); - if (adopt_result != OK) { - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result); - return adopt_result; - } - *socket = tcp_socket.Pass(); - *address = ip_end_point; - net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT, - CreateNetLogIPEndPointCallback(&ip_end_point)); - return OK; +void TCPSocketLibevent::AcceptCompleted( + scoped_ptr<TCPSocketLibevent>* tcp_socket, + IPEndPoint* address, + const CompletionCallback& callback, + int rv) { + DCHECK_NE(ERR_IO_PENDING, rv); + callback.Run(HandleAcceptCompleted(tcp_socket, address, rv)); } -int TCPSocketLibevent::DoConnect() { - DCHECK_EQ(0, connect_os_error_); +int TCPSocketLibevent::HandleAcceptCompleted( + scoped_ptr<TCPSocketLibevent>* tcp_socket, + IPEndPoint* address, + int rv) { + if (rv == OK) + rv = BuildTcpSocketLibevent(tcp_socket, address); - net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, - CreateNetLogIPEndPointCallback(peer_address_.get())); - - // Connect the socket. - if (!use_tcp_fastopen_) { - SockaddrStorage storage; - if (!peer_address_->ToSockAddr(storage.addr, &storage.addr_len)) - return ERR_ADDRESS_INVALID; - - if (!HANDLE_EINTR(connect(socket_, storage.addr, storage.addr_len))) { - // Connected without waiting! - return OK; - } + if (rv == OK) { + net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT, + CreateNetLogIPEndPointCallback(address)); } else { - // With TCP FastOpen, we pretend that the socket is connected. - DCHECK(!tcp_fastopen_connected_); - return OK; + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, rv); } - // Check if the connect() failed synchronously. - connect_os_error_ = errno; - if (connect_os_error_ != EINPROGRESS) - return MapConnectError(connect_os_error_); - - // Otherwise the connect() is going to complete asynchronously, so watch - // for its completion. - if (!base::MessageLoopForIO::current()->WatchFileDescriptor( - socket_, true, base::MessageLoopForIO::WATCH_WRITE, - &write_socket_watcher_, &write_watcher_)) { - connect_os_error_ = errno; - DVLOG(1) << "WatchFileDescriptor failed: " << connect_os_error_; - return MapSystemError(connect_os_error_); + return rv; +} + +int TCPSocketLibevent::BuildTcpSocketLibevent( + scoped_ptr<TCPSocketLibevent>* tcp_socket, + IPEndPoint* address) { + DCHECK(accept_socket_); + + SockaddrStorage storage; + if (accept_socket_->GetPeerAddress(&storage) != OK || + !address->FromSockAddr(storage.addr, storage.addr_len)) { + accept_socket_.reset(); + return ERR_ADDRESS_INVALID; } - return ERR_IO_PENDING; + tcp_socket->reset(new TCPSocketLibevent(net_log_.net_log(), + net_log_.source())); + (*tcp_socket)->socket_.reset(accept_socket_.release()); + return OK; +} + +void TCPSocketLibevent::ConnectCompleted(const CompletionCallback& callback, + int rv) const { + DCHECK_NE(ERR_IO_PENDING, rv); + callback.Run(HandleConnectCompleted(rv)); } -void TCPSocketLibevent::DoConnectComplete(int result) { +int TCPSocketLibevent::HandleConnectCompleted(int rv) const { // Log the end of this attempt (and any OS error it threw). - int os_error = connect_os_error_; - connect_os_error_ = 0; - if (result != OK) { + if (rv != OK) { net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, - NetLog::IntegerCallback("os_error", os_error)); + NetLog::IntegerCallback("os_error", errno)); } else { net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT); } + // Give a more specific error when the user is offline. + if (rv == ERR_ADDRESS_UNREACHABLE && NetworkChangeNotifier::IsOffline()) + rv = ERR_INTERNET_DISCONNECTED; + if (!logging_multiple_connect_attempts_) - LogConnectEnd(result); + LogConnectEnd(rv); + + return rv; } -void TCPSocketLibevent::LogConnectBegin(const AddressList& addresses) { +void TCPSocketLibevent::LogConnectBegin(const AddressList& addresses) const { base::StatsCounter connects("tcp.connect"); connects.Increment(); @@ -688,19 +544,18 @@ void TCPSocketLibevent::LogConnectBegin(const AddressList& addresses) { addresses.CreateNetLogCallback()); } -void TCPSocketLibevent::LogConnectEnd(int net_error) { - if (net_error == OK) - UpdateConnectionTypeHistograms(CONNECTION_ANY); - +void TCPSocketLibevent::LogConnectEnd(int net_error) const { if (net_error != OK) { net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_CONNECT, net_error); return; } + UpdateConnectionTypeHistograms(CONNECTION_ANY); + SockaddrStorage storage; - int rv = getsockname(socket_, storage.addr, &storage.addr_len); - if (rv != 0) { - PLOG(ERROR) << "getsockname() [rv: " << rv << "] error: "; + int rv = socket_->GetLocalAddress(&storage); + if (rv != OK) { + PLOG(ERROR) << "GetLocalAddress() [rv: " << rv << "] error: "; NOTREACHED(); net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_CONNECT, rv); return; @@ -711,191 +566,179 @@ void TCPSocketLibevent::LogConnectEnd(int net_error) { storage.addr_len)); } -void TCPSocketLibevent::DidCompleteRead() { - RecordFastOpenStatus(); - if (read_callback_.is_null()) - return; - - int bytes_transferred; - bytes_transferred = HANDLE_EINTR(read(socket_, read_buf_->data(), - read_buf_len_)); - - int result; - if (bytes_transferred >= 0) { - result = bytes_transferred; - base::StatsCounter read_bytes("tcp.read_bytes"); - read_bytes.Add(bytes_transferred); - net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, result, - read_buf_->data()); - } else { - result = MapSystemError(errno); - if (result != ERR_IO_PENDING) { - net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR, - CreateNetLogSocketErrorCallback(result, errno)); - } +void TCPSocketLibevent::ReadCompleted(const scoped_refptr<IOBuffer>& buf, + const CompletionCallback& callback, + int rv) { + DCHECK_NE(ERR_IO_PENDING, rv); + callback.Run(HandleReadCompleted(buf.get(), rv)); +} + +int TCPSocketLibevent::HandleReadCompleted(IOBuffer* buf, int rv) { + if (tcp_fastopen_write_attempted_ && !tcp_fastopen_connected_) { + // A TCP FastOpen connect-with-write was attempted. This read was a + // subsequent read, which either succeeded or failed. If the read + // succeeded, the socket is considered connected via TCP FastOpen. + // If the read failed, TCP FastOpen is (conservatively) turned off for all + // subsequent connections. TCP FastOpen status is recorded in both cases. + // TODO (jri): This currently results in conservative behavior, where TCP + // FastOpen is turned off on _any_ error. Implement optimizations, + // such as turning off TCP FastOpen on more specific errors, and + // re-attempting TCP FastOpen after a certain amount of time has passed. + if (rv >= 0) + tcp_fastopen_connected_ = true; + else + g_tcp_fastopen_has_failed = true; + UpdateTCPFastOpenStatusAfterRead(); + } + + if (rv < 0) { + net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR, + CreateNetLogSocketErrorCallback(rv, errno)); + return rv; } - if (result != ERR_IO_PENDING) { - read_buf_ = NULL; - read_buf_len_ = 0; - bool ok = read_socket_watcher_.StopWatchingFileDescriptor(); - DCHECK(ok); - base::ResetAndReturn(&read_callback_).Run(result); - } + base::StatsCounter read_bytes("tcp.read_bytes"); + read_bytes.Add(rv); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, rv, + buf->data()); + return rv; } -void TCPSocketLibevent::DidCompleteWrite() { - if (write_callback_.is_null()) - return; - - int bytes_transferred; - bytes_transferred = HANDLE_EINTR(write(socket_, write_buf_->data(), - write_buf_len_)); - - int result; - if (bytes_transferred >= 0) { - result = bytes_transferred; - base::StatsCounter write_bytes("tcp.write_bytes"); - write_bytes.Add(bytes_transferred); - net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, result, - write_buf_->data()); - } else { - result = MapSystemError(errno); - if (result != ERR_IO_PENDING) { - net_log_.AddEvent(NetLog::TYPE_SOCKET_WRITE_ERROR, - CreateNetLogSocketErrorCallback(result, errno)); +void TCPSocketLibevent::WriteCompleted(const scoped_refptr<IOBuffer>& buf, + const CompletionCallback& callback, + int rv) { + DCHECK_NE(ERR_IO_PENDING, rv); + callback.Run(HandleWriteCompleted(buf.get(), rv)); +} + +int TCPSocketLibevent::HandleWriteCompleted(IOBuffer* buf, int rv) { + if (rv < 0) { + if (tcp_fastopen_write_attempted_ && !tcp_fastopen_connected_) { + // TCP FastOpen connect-with-write was attempted, and the write failed + // for unknown reasons. Record status and (conservatively) turn off + // TCP FastOpen for all subsequent connections. + // TODO (jri): This currently results in conservative behavior, where TCP + // FastOpen is turned off on _any_ error. Implement optimizations, + // such as turning off TCP FastOpen on more specific errors, and + // re-attempting TCP FastOpen after a certain amount of time has passed. + tcp_fastopen_status_ = TCP_FASTOPEN_ERROR; + g_tcp_fastopen_has_failed = true; } + net_log_.AddEvent(NetLog::TYPE_SOCKET_WRITE_ERROR, + CreateNetLogSocketErrorCallback(rv, errno)); + return rv; } - if (result != ERR_IO_PENDING) { - write_buf_ = NULL; - write_buf_len_ = 0; - write_socket_watcher_.StopWatchingFileDescriptor(); - base::ResetAndReturn(&write_callback_).Run(result); - } + base::StatsCounter write_bytes("tcp.write_bytes"); + write_bytes.Add(rv); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, rv, + buf->data()); + return rv; } -void TCPSocketLibevent::DidCompleteConnect() { - DCHECK(waiting_connect_); - - // Get the error that connect() completed with. - int os_error = 0; - socklen_t len = sizeof(os_error); - if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &os_error, &len) < 0) - os_error = errno; +int TCPSocketLibevent::TcpFastOpenWrite( + IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + SockaddrStorage storage; + int rv = socket_->GetPeerAddress(&storage); + if (rv != OK) + return rv; - int result = MapConnectError(os_error); - connect_os_error_ = os_error; - if (result != ERR_IO_PENDING) { - DoConnectComplete(result); - waiting_connect_ = false; - write_socket_watcher_.StopWatchingFileDescriptor(); - base::ResetAndReturn(&write_callback_).Run(result); + int flags = 0x20000000; // Magic flag to enable TCP_FASTOPEN. +#if defined(OS_LINUX) || defined(OS_ANDROID) + // sendto() will fail with EPIPE when the system doesn't implement TCP + // FastOpen, and with EOPNOTSUPP when the system implements TCP FastOpen + // but it is disabled. Theoretically these shouldn't happen + // since the caller should check for system support on startup, but + // users may dynamically disable TCP FastOpen via sysctl. + flags |= MSG_NOSIGNAL; +#endif // defined(OS_LINUX) || defined(OS_ANDROID) + rv = HANDLE_EINTR(sendto(socket_->socket_fd(), + buf->data(), + buf_len, + flags, + storage.addr, + storage.addr_len)); + tcp_fastopen_write_attempted_ = true; + + if (rv >= 0) { + tcp_fastopen_status_ = TCP_FASTOPEN_FAST_CONNECT_RETURN; + return rv; + } + + DCHECK_NE(EPIPE, errno); + + // If errno == EINPROGRESS, that means the kernel didn't have a cookie + // and would block. The kernel is internally doing a connect() though. + // Remap EINPROGRESS to EAGAIN so we treat this the same as our other + // asynchronous cases. Note that the user buffer has not been copied to + // kernel space. + if (errno == EINPROGRESS) { + rv = ERR_IO_PENDING; + } else { + rv = MapSystemError(errno); } -} - -void TCPSocketLibevent::DidCompleteConnectOrWrite() { - if (waiting_connect_) - DidCompleteConnect(); - else - DidCompleteWrite(); -} -void TCPSocketLibevent::DidCompleteAccept() { - DCHECK(CalledOnValidThread()); - - int result = AcceptInternal(accept_socket_, accept_address_); - if (result != ERR_IO_PENDING) { - accept_socket_ = NULL; - accept_address_ = NULL; - bool ok = accept_socket_watcher_.StopWatchingFileDescriptor(); - DCHECK(ok); - CompletionCallback callback = accept_callback_; - accept_callback_.Reset(); - callback.Run(result); + if (rv != ERR_IO_PENDING) { + // TCP FastOpen connect-with-write was attempted, and the write failed + // since TCP FastOpen was not implemented or disabled in the OS. + // Record status and turn off TCP FastOpen for all subsequent connections. + // TODO (jri): This is almost certainly too conservative, since it blanket + // turns off TCP FastOpen on any write error. Two things need to be done + // here: (i) record a histogram of write errors; in particular, record + // occurrences of EOPNOTSUPP and EPIPE, and (ii) afterwards, consider + // turning off TCP FastOpen on more specific errors. + tcp_fastopen_status_ = TCP_FASTOPEN_ERROR; + g_tcp_fastopen_has_failed = true; + return rv; } + + tcp_fastopen_status_ = TCP_FASTOPEN_SLOW_CONNECT_RETURN; + return socket_->WaitForWrite(buf, buf_len, callback); } -int TCPSocketLibevent::InternalWrite(IOBuffer* buf, int buf_len) { - int nwrite; - if (use_tcp_fastopen_ && !tcp_fastopen_connected_) { - SockaddrStorage storage; - if (!peer_address_->ToSockAddr(storage.addr, &storage.addr_len)) { - // Set errno to EADDRNOTAVAIL so that MapSystemError will map it to - // ERR_ADDRESS_INVALID later. - errno = EADDRNOTAVAIL; - return -1; - } +void TCPSocketLibevent::UpdateTCPFastOpenStatusAfterRead() { + DCHECK(tcp_fastopen_status_ == TCP_FASTOPEN_FAST_CONNECT_RETURN || + tcp_fastopen_status_ == TCP_FASTOPEN_SLOW_CONNECT_RETURN); - int flags = 0x20000000; // Magic flag to enable TCP_FASTOPEN. -#if defined(OS_LINUX) - // sendto() will fail with EPIPE when the system doesn't support TCP Fast - // Open. Theoretically that shouldn't happen since the caller should check - // for system support on startup, but users may dynamically disable TCP Fast - // Open via sysctl. - flags |= MSG_NOSIGNAL; -#endif // defined(OS_LINUX) - nwrite = HANDLE_EINTR(sendto(socket_, - buf->data(), - buf_len, - flags, - storage.addr, - storage.addr_len)); - tcp_fastopen_connected_ = true; - - if (nwrite < 0) { - DCHECK_NE(EPIPE, errno); - - // If errno == EINPROGRESS, that means the kernel didn't have a cookie - // and would block. The kernel is internally doing a connect() though. - // Remap EINPROGRESS to EAGAIN so we treat this the same as our other - // asynchronous cases. Note that the user buffer has not been copied to - // kernel space. - if (errno == EINPROGRESS) { - errno = EAGAIN; - fast_open_status_ = FAST_OPEN_SLOW_CONNECT_RETURN; - } else { - fast_open_status_ = FAST_OPEN_ERROR; - } - } else { - fast_open_status_ = FAST_OPEN_FAST_CONNECT_RETURN; - } - } else { - nwrite = HANDLE_EINTR(write(socket_, buf->data(), buf_len)); + if (tcp_fastopen_write_attempted_ && !tcp_fastopen_connected_) { + // TCP FastOpen connect-with-write was attempted, and failed. + tcp_fastopen_status_ = + (tcp_fastopen_status_ == TCP_FASTOPEN_FAST_CONNECT_RETURN ? + TCP_FASTOPEN_FAST_CONNECT_READ_FAILED : + TCP_FASTOPEN_SLOW_CONNECT_READ_FAILED); + return; } - return nwrite; -} -void TCPSocketLibevent::RecordFastOpenStatus() { - if (use_tcp_fastopen_ && - (fast_open_status_ == FAST_OPEN_FAST_CONNECT_RETURN || - fast_open_status_ == FAST_OPEN_SLOW_CONNECT_RETURN)) { - DCHECK_NE(FAST_OPEN_STATUS_UNKNOWN, fast_open_status_); - bool getsockopt_success(false); - bool server_acked_data(false); + bool getsockopt_success = false; + bool server_acked_data = false; #if defined(TCP_INFO) - // Probe to see the if the socket used TCP Fast Open. - tcp_info info; - socklen_t info_len = sizeof(tcp_info); - getsockopt_success = - getsockopt(socket_, IPPROTO_TCP, TCP_INFO, &info, &info_len) == 0 && - info_len == sizeof(tcp_info); - server_acked_data = getsockopt_success && - (info.tcpi_options & TCPI_OPT_SYN_DATA); + // Probe to see the if the socket used TCP FastOpen. + tcp_info info; + socklen_t info_len = sizeof(tcp_info); + getsockopt_success = getsockopt(socket_->socket_fd(), IPPROTO_TCP, TCP_INFO, + &info, &info_len) == 0 && + info_len == sizeof(tcp_info); + server_acked_data = getsockopt_success && + (info.tcpi_options & TCPI_OPT_SYN_DATA); #endif - if (getsockopt_success) { - if (fast_open_status_ == FAST_OPEN_FAST_CONNECT_RETURN) { - fast_open_status_ = (server_acked_data ? FAST_OPEN_SYN_DATA_ACK : - FAST_OPEN_SYN_DATA_NACK); - } else { - fast_open_status_ = (server_acked_data ? FAST_OPEN_NO_SYN_DATA_ACK : - FAST_OPEN_NO_SYN_DATA_NACK); - } + + if (getsockopt_success) { + if (tcp_fastopen_status_ == TCP_FASTOPEN_FAST_CONNECT_RETURN) { + tcp_fastopen_status_ = (server_acked_data ? + TCP_FASTOPEN_SYN_DATA_ACK : + TCP_FASTOPEN_SYN_DATA_NACK); } else { - fast_open_status_ = (fast_open_status_ == FAST_OPEN_FAST_CONNECT_RETURN ? - FAST_OPEN_SYN_DATA_FAILED : - FAST_OPEN_NO_SYN_DATA_FAILED); + tcp_fastopen_status_ = (server_acked_data ? + TCP_FASTOPEN_NO_SYN_DATA_ACK : + TCP_FASTOPEN_NO_SYN_DATA_NACK); } + } else { + tcp_fastopen_status_ = + (tcp_fastopen_status_ == TCP_FASTOPEN_FAST_CONNECT_RETURN ? + TCP_FASTOPEN_SYN_DATA_GETSOCKOPT_FAILED : + TCP_FASTOPEN_NO_SYN_DATA_GETSOCKOPT_FAILED); } } diff --git a/chromium/net/socket/tcp_socket_libevent.h b/chromium/net/socket/tcp_socket_libevent.h index 9ef235be85d..0958b6d25d0 100644 --- a/chromium/net/socket/tcp_socket_libevent.h +++ b/chromium/net/socket/tcp_socket_libevent.h @@ -8,30 +8,27 @@ #include "base/basictypes.h" #include "base/callback.h" #include "base/compiler_specific.h" -#include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" -#include "base/message_loop/message_loop.h" -#include "base/threading/non_thread_safe.h" #include "net/base/address_family.h" #include "net/base/completion_callback.h" #include "net/base/net_export.h" #include "net/base/net_log.h" -#include "net/socket/socket_descriptor.h" namespace net { class AddressList; class IOBuffer; class IPEndPoint; +class SocketLibevent; -class NET_EXPORT TCPSocketLibevent : public base::NonThreadSafe { +class NET_EXPORT TCPSocketLibevent { public: TCPSocketLibevent(NetLog* net_log, const NetLog::Source& source); virtual ~TCPSocketLibevent(); int Open(AddressFamily family); - // Takes ownership of |socket|. - int AdoptConnectedSocket(int socket, const IPEndPoint& peer_address); + // Takes ownership of |socket_fd|. + int AdoptConnectedSocket(int socket_fd, const IPEndPoint& peer_address); int Bind(const IPEndPoint& address); @@ -68,8 +65,11 @@ class NET_EXPORT TCPSocketLibevent : public base::NonThreadSafe { void Close(); + // Setter/Getter methods for TCP FastOpen socket option. bool UsingTCPFastOpen() const; - bool IsValid() const { return socket_ != kInvalidSocket; } + void EnableTCPFastOpenIfSupported(); + + bool IsValid() const; // Marks the start/end of a series of connect attempts for logging purpose. // @@ -87,141 +87,120 @@ class NET_EXPORT TCPSocketLibevent : public base::NonThreadSafe { const BoundNetLog& net_log() const { return net_log_; } private: - // States that a fast open socket attempt can result in. - enum FastOpenStatus { - FAST_OPEN_STATUS_UNKNOWN, + // States that using a socket with TCP FastOpen can lead to. + enum TCPFastOpenStatus { + TCP_FASTOPEN_STATUS_UNKNOWN, - // The initial fast open connect attempted returned synchronously, + // The initial FastOpen connect attempted returned synchronously, // indicating that we had and sent a cookie along with the initial data. - FAST_OPEN_FAST_CONNECT_RETURN, + TCP_FASTOPEN_FAST_CONNECT_RETURN, - // The initial fast open connect attempted returned asynchronously, + // The initial FastOpen connect attempted returned asynchronously, // indicating that we did not have a cookie for the server. - FAST_OPEN_SLOW_CONNECT_RETURN, + TCP_FASTOPEN_SLOW_CONNECT_RETURN, // Some other error occurred on connection, so we couldn't tell if - // fast open would have worked. - FAST_OPEN_ERROR, + // FastOpen would have worked. + TCP_FASTOPEN_ERROR, - // An attempt to do a fast open succeeded immediately - // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server + // An attempt to do a FastOpen succeeded immediately + // (TCP_FASTOPEN_FAST_CONNECT_RETURN) and we later confirmed that the server // had acked the data we sent. - FAST_OPEN_SYN_DATA_ACK, + TCP_FASTOPEN_SYN_DATA_ACK, - // An attempt to do a fast open succeeded immediately - // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server + // An attempt to do a FastOpen succeeded immediately + // (TCP_FASTOPEN_FAST_CONNECT_RETURN) and we later confirmed that the server // had nacked the data we sent. - FAST_OPEN_SYN_DATA_NACK, + TCP_FASTOPEN_SYN_DATA_NACK, - // An attempt to do a fast open succeeded immediately - // (FAST_OPEN_FAST_CONNECT_RETURN) and our probe to determine if the - // socket was using fast open failed. - FAST_OPEN_SYN_DATA_FAILED, + // An attempt to do a FastOpen succeeded immediately + // (TCP_FASTOPEN_FAST_CONNECT_RETURN) and our probe to determine if the + // socket was using FastOpen failed. + TCP_FASTOPEN_SYN_DATA_GETSOCKOPT_FAILED, - // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) + // An attempt to do a FastOpen failed (TCP_FASTOPEN_SLOW_CONNECT_RETURN) // and we later confirmed that the server had acked initial data. This // should never happen (we didn't send data, so it shouldn't have // been acked). - FAST_OPEN_NO_SYN_DATA_ACK, + TCP_FASTOPEN_NO_SYN_DATA_ACK, - // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) + // An attempt to do a FastOpen failed (TCP_FASTOPEN_SLOW_CONNECT_RETURN) // and we later discovered that the server had nacked initial data. This - // is the expected case results for FAST_OPEN_SLOW_CONNECT_RETURN. - FAST_OPEN_NO_SYN_DATA_NACK, + // is the expected case results for TCP_FASTOPEN_SLOW_CONNECT_RETURN. + TCP_FASTOPEN_NO_SYN_DATA_NACK, - // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) + // An attempt to do a FastOpen failed (TCP_FASTOPEN_SLOW_CONNECT_RETURN) // and our later probe for ack/nack state failed. - FAST_OPEN_NO_SYN_DATA_FAILED, - - FAST_OPEN_MAX_VALUE - }; - - // Watcher simply forwards notifications to Closure objects set via the - // constructor. - class Watcher: public base::MessageLoopForIO::Watcher { - public: - Watcher(const base::Closure& read_ready_callback, - const base::Closure& write_ready_callback); - virtual ~Watcher(); - - // base::MessageLoopForIO::Watcher methods. - virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE; - virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE; - - private: - base::Closure read_ready_callback_; - base::Closure write_ready_callback_; - - DISALLOW_COPY_AND_ASSIGN(Watcher); + TCP_FASTOPEN_NO_SYN_DATA_GETSOCKOPT_FAILED, + + // The initial FastOpen connect+write succeeded immediately + // (TCP_FASTOPEN_FAST_CONNECT_RETURN) and a subsequent attempt to read from + // the connection failed. + TCP_FASTOPEN_FAST_CONNECT_READ_FAILED, + + // The initial FastOpen connect+write failed + // (TCP_FASTOPEN_SLOW_CONNECT_RETURN) + // and a subsequent attempt to read from the connection failed. + TCP_FASTOPEN_SLOW_CONNECT_READ_FAILED, + + // We didn't try FastOpen because it had failed in the past + // (g_tcp_fastopen_has_failed was true.) + // NOTE: This status is currently registered before a connect/write call + // is attempted, and may capture some cases where the status is registered + // but no connect is subsequently attempted. + // TODO(jri): The expectation is that such cases are not the common case + // with TCP FastOpen for SSL sockets however. Change code to be more + // accurate when TCP FastOpen is used for more than just SSL sockets. + TCP_FASTOPEN_PREVIOUSLY_FAILED, + + TCP_FASTOPEN_MAX_VALUE }; - int AcceptInternal(scoped_ptr<TCPSocketLibevent>* socket, - IPEndPoint* address); - - int DoConnect(); - void DoConnectComplete(int result); - - void LogConnectBegin(const AddressList& addresses); - void LogConnectEnd(int net_error); - - void DidCompleteRead(); - void DidCompleteWrite(); - void DidCompleteConnect(); - void DidCompleteConnectOrWrite(); - void DidCompleteAccept(); - - // Internal function to write to a socket. Returns an OS error. - int InternalWrite(IOBuffer* buf, int buf_len); - - // Called when the socket is known to be in a connected state. - void RecordFastOpenStatus(); - - int socket_; - - base::MessageLoopForIO::FileDescriptorWatcher accept_socket_watcher_; - Watcher accept_watcher_; - - scoped_ptr<TCPSocketLibevent>* accept_socket_; - IPEndPoint* accept_address_; - CompletionCallback accept_callback_; - - // The socket's libevent wrappers for reads and writes. - base::MessageLoopForIO::FileDescriptorWatcher read_socket_watcher_; - base::MessageLoopForIO::FileDescriptorWatcher write_socket_watcher_; - - // The corresponding watchers for reads and writes. - Watcher read_watcher_; - Watcher write_watcher_; - - // The buffer used for reads. - scoped_refptr<IOBuffer> read_buf_; - int read_buf_len_; - - // The buffer used for writes. - scoped_refptr<IOBuffer> write_buf_; - int write_buf_len_; - - // External callback; called when read is complete. - CompletionCallback read_callback_; - - // External callback; called when write or connect is complete. - CompletionCallback write_callback_; + void AcceptCompleted(scoped_ptr<TCPSocketLibevent>* tcp_socket, + IPEndPoint* address, + const CompletionCallback& callback, + int rv); + int HandleAcceptCompleted(scoped_ptr<TCPSocketLibevent>* tcp_socket, + IPEndPoint* address, + int rv); + int BuildTcpSocketLibevent(scoped_ptr<TCPSocketLibevent>* tcp_socket, + IPEndPoint* address); + + void ConnectCompleted(const CompletionCallback& callback, int rv) const; + int HandleConnectCompleted(int rv) const; + void LogConnectBegin(const AddressList& addresses) const; + void LogConnectEnd(int net_error) const; + + void ReadCompleted(const scoped_refptr<IOBuffer>& buf, + const CompletionCallback& callback, + int rv); + int HandleReadCompleted(IOBuffer* buf, int rv); + + void WriteCompleted(const scoped_refptr<IOBuffer>& buf, + const CompletionCallback& callback, + int rv); + int HandleWriteCompleted(IOBuffer* buf, int rv); + int TcpFastOpenWrite(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback); + + // Called after the first read completes on a TCP FastOpen socket. + void UpdateTCPFastOpenStatusAfterRead(); + + scoped_ptr<SocketLibevent> socket_; + scoped_ptr<SocketLibevent> accept_socket_; // Enables experimental TCP FastOpen option. - const bool use_tcp_fastopen_; + bool use_tcp_fastopen_; + + // True when TCP FastOpen is in use and we have attempted the + // connect with write. + bool tcp_fastopen_write_attempted_; // True when TCP FastOpen is in use and we have done the connect. bool tcp_fastopen_connected_; - FastOpenStatus fast_open_status_; - - // A connect operation is pending. In this case, |write_callback_| needs to be - // called when connect is complete. - bool waiting_connect_; - - scoped_ptr<IPEndPoint> peer_address_; - // The OS error that a connect attempt last completed with. - int connect_os_error_; + TCPFastOpenStatus tcp_fastopen_status_; bool logging_multiple_connect_attempts_; diff --git a/chromium/net/socket/tcp_socket_win.cc b/chromium/net/socket/tcp_socket_win.cc index 88db36fd41c..d5565ad669c 100644 --- a/chromium/net/socket/tcp_socket_win.cc +++ b/chromium/net/socket/tcp_socket_win.cc @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include "net/socket/tcp_socket.h" #include "net/socket/tcp_socket_win.h" #include <mstcpip.h> @@ -9,6 +10,7 @@ #include "base/callback_helpers.h" #include "base/logging.h" #include "base/metrics/stats_counters.h" +#include "base/profiler/scoped_tracker.h" #include "base/win/windows_version.h" #include "net/base/address_list.h" #include "net/base/connection_type_histograms.h" @@ -123,6 +125,12 @@ int MapConnectError(int os_error) { //----------------------------------------------------------------------------- +// Nothing to do for Windows since it doesn't support TCP FastOpen. +// TODO(jri): Remove these along with the corresponding global variables. +bool IsTCPFastOpenSupported() { return false; } +bool IsTCPFastOpenUserEnabled() { return false; } +void CheckSupportAndMaybeEnableTCPFastOpen(bool user_enabled) {} + // This class encapsulates all the state that has to be preserved as long as // there is a network IO operation in progress. If the owner TCPSocketWin is // destroyed while an operation is in progress, the Core is detached and it @@ -238,6 +246,11 @@ void TCPSocketWin::Core::WatchForWrite() { } void TCPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) { + // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "TCPSocketWin_Core_ReadDelegate_OnObjectSignaled")); + DCHECK_EQ(object, core_->read_overlapped_.hEvent); if (core_->socket_) { if (core_->socket_->waiting_connect_) @@ -251,6 +264,11 @@ void TCPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) { void TCPSocketWin::Core::WriteDelegate::OnObjectSignaled( HANDLE object) { + // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "TCPSocketWin_Core_WriteDelegate_OnObjectSignaled")); + DCHECK_EQ(object, core_->write_overlapped_.hEvent); if (core_->socket_) core_->socket_->DidCompleteWrite(); @@ -485,7 +503,7 @@ int TCPSocketWin::Read(IOBuffer* buf, DCHECK(CalledOnValidThread()); DCHECK_NE(socket_, INVALID_SOCKET); DCHECK(!waiting_read_); - DCHECK(read_callback_.is_null()); + CHECK(read_callback_.is_null()); DCHECK(!core_->read_iobuffer_); return DoRead(buf, buf_len, callback); @@ -497,7 +515,7 @@ int TCPSocketWin::Write(IOBuffer* buf, DCHECK(CalledOnValidThread()); DCHECK_NE(socket_, INVALID_SOCKET); DCHECK(!waiting_write_); - DCHECK(write_callback_.is_null()); + CHECK(write_callback_.is_null()); DCHECK_GT(buf_len, 0); DCHECK(!core_->write_iobuffer_); @@ -694,11 +712,6 @@ void TCPSocketWin::Close() { connect_os_error_ = 0; } -bool TCPSocketWin::UsingTCPFastOpen() const { - // Not supported on windows. - return false; -} - void TCPSocketWin::StartLoggingMultipleConnectAttempts( const AddressList& addresses) { if (!logging_multiple_connect_attempts_) { @@ -753,6 +766,10 @@ int TCPSocketWin::AcceptInternal(scoped_ptr<TCPSocketWin>* socket, } void TCPSocketWin::OnObjectSignaled(HANDLE object) { + // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION("TCPSocketWin_OnObjectSignaled")); + WSANETWORKEVENTS ev; if (WSAEnumNetworkEvents(socket_, accept_event_, &ev) == SOCKET_ERROR) { PLOG(ERROR) << "WSAEnumNetworkEvents()"; @@ -1017,8 +1034,10 @@ void TCPSocketWin::DidSignalRead() { core_->read_buffer_length_ = 0; DCHECK_NE(rv, ERR_IO_PENDING); + // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION("TCPSocketWin::DidSignalRead")); base::ResetAndReturn(&read_callback_).Run(rv); } } // namespace net - diff --git a/chromium/net/socket/tcp_socket_win.h b/chromium/net/socket/tcp_socket_win.h index 2ddd538f211..80174adea9c 100644 --- a/chromium/net/socket/tcp_socket_win.h +++ b/chromium/net/socket/tcp_socket_win.h @@ -75,7 +75,11 @@ class NET_EXPORT TCPSocketWin : NON_EXPORTED_BASE(public base::NonThreadSafe), void Close(); - bool UsingTCPFastOpen() const; + // Setter/Getter methods for TCP FastOpen socket option. + // NOOPs since TCP FastOpen is not implemented in Windows. + bool UsingTCPFastOpen() const { return false; } + void EnableTCPFastOpenIfSupported() {} + bool IsValid() const { return socket_ != INVALID_SOCKET; } // Marks the start/end of a series of connect attempts for logging purpose. @@ -97,7 +101,7 @@ class NET_EXPORT TCPSocketWin : NON_EXPORTED_BASE(public base::NonThreadSafe), class Core; // base::ObjectWatcher::Delegate implementation. - virtual void OnObjectSignaled(HANDLE object) OVERRIDE; + virtual void OnObjectSignaled(HANDLE object) override; int AcceptInternal(scoped_ptr<TCPSocketWin>* socket, IPEndPoint* address); @@ -152,4 +156,3 @@ class NET_EXPORT TCPSocketWin : NON_EXPORTED_BASE(public base::NonThreadSafe), } // namespace net #endif // NET_SOCKET_TCP_SOCKET_WIN_H_ - diff --git a/chromium/net/socket/transport_client_socket_pool.cc b/chromium/net/socket/transport_client_socket_pool.cc index dc481fa4b6d..06202f13d81 100644 --- a/chromium/net/socket/transport_client_socket_pool.cc +++ b/chromium/net/socket/transport_client_socket_pool.cc @@ -31,7 +31,7 @@ namespace net { // TODO(willchan): Base this off RTT instead of statically setting it. Note we // choose a timeout that is different from the backup connect job timer so they // don't synchronize. -const int TransportConnectJob::kIPv6FallbackTimerInMs = 300; +const int TransportConnectJobHelper::kIPv6FallbackTimerInMs = 300; namespace { @@ -60,12 +60,21 @@ TransportSocketParams::TransportSocketParams( const HostPortPair& host_port_pair, bool disable_resolver_cache, bool ignore_limits, - const OnHostResolutionCallback& host_resolution_callback) + const OnHostResolutionCallback& host_resolution_callback, + CombineConnectAndWritePolicy combine_connect_and_write_if_supported) : destination_(host_port_pair), ignore_limits_(ignore_limits), - host_resolution_callback_(host_resolution_callback) { + host_resolution_callback_(host_resolution_callback), + combine_connect_and_write_(combine_connect_and_write_if_supported) { if (disable_resolver_cache) destination_.set_allow_cached_response(false); + // combine_connect_and_write currently translates to TCP FastOpen. + // Enable TCP FastOpen if user wants it. + if (combine_connect_and_write_ == COMBINE_CONNECT_AND_WRITE_DEFAULT) { + IsTCPFastOpenUserEnabled() ? combine_connect_and_write_ = + COMBINE_CONNECT_AND_WRITE_DESIRED : + COMBINE_CONNECT_AND_WRITE_PROHIBITED; + } } TransportSocketParams::~TransportSocketParams() {} @@ -81,6 +90,107 @@ TransportSocketParams::~TransportSocketParams() {} // See comment #12 at http://crbug.com/23364 for specifics. static const int kTransportConnectJobTimeoutInSeconds = 240; // 4 minutes. +TransportConnectJobHelper::TransportConnectJobHelper( + const scoped_refptr<TransportSocketParams>& params, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + LoadTimingInfo::ConnectTiming* connect_timing) + : params_(params), + client_socket_factory_(client_socket_factory), + resolver_(host_resolver), + next_state_(STATE_NONE), + connect_timing_(connect_timing) {} + +TransportConnectJobHelper::~TransportConnectJobHelper() {} + +int TransportConnectJobHelper::DoResolveHost(RequestPriority priority, + const BoundNetLog& net_log) { + next_state_ = STATE_RESOLVE_HOST_COMPLETE; + connect_timing_->dns_start = base::TimeTicks::Now(); + + return resolver_.Resolve( + params_->destination(), priority, &addresses_, on_io_complete_, net_log); +} + +int TransportConnectJobHelper::DoResolveHostComplete( + int result, + const BoundNetLog& net_log) { + connect_timing_->dns_end = base::TimeTicks::Now(); + // Overwrite connection start time, since for connections that do not go + // through proxies, |connect_start| should not include dns lookup time. + connect_timing_->connect_start = connect_timing_->dns_end; + + if (result == OK) { + // Invoke callback, and abort if it fails. + if (!params_->host_resolution_callback().is_null()) + result = params_->host_resolution_callback().Run(addresses_, net_log); + + if (result == OK) + next_state_ = STATE_TRANSPORT_CONNECT; + } + return result; +} + +base::TimeDelta TransportConnectJobHelper::HistogramDuration( + ConnectionLatencyHistogram race_result) { + DCHECK(!connect_timing_->connect_start.is_null()); + DCHECK(!connect_timing_->dns_start.is_null()); + base::TimeTicks now = base::TimeTicks::Now(); + base::TimeDelta total_duration = now - connect_timing_->dns_start; + UMA_HISTOGRAM_CUSTOM_TIMES("Net.DNS_Resolution_And_TCP_Connection_Latency2", + total_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + + base::TimeDelta connect_duration = now - connect_timing_->connect_start; + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + + switch (race_result) { + case CONNECTION_LATENCY_IPV4_WINS_RACE: + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_Wins_Race", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + break; + + case CONNECTION_LATENCY_IPV4_NO_RACE: + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_No_Race", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + break; + + case CONNECTION_LATENCY_IPV6_RACEABLE: + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Raceable", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + break; + + case CONNECTION_LATENCY_IPV6_SOLO: + UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Solo", + connect_duration, + base::TimeDelta::FromMilliseconds(1), + base::TimeDelta::FromMinutes(10), + 100); + break; + + default: + NOTREACHED(); + break; + } + + return connect_duration; +} + TransportConnectJob::TransportConnectJob( const std::string& group_name, RequestPriority priority, @@ -92,11 +202,9 @@ TransportConnectJob::TransportConnectJob( NetLog* net_log) : ConnectJob(group_name, timeout_duration, priority, delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), - params_(params), - client_socket_factory_(client_socket_factory), - resolver_(host_resolver), - next_state_(STATE_NONE), + helper_(params, client_socket_factory, host_resolver, &connect_timing_), interval_between_connects_(CONNECT_INTERVAL_GT_20MS) { + helper_.SetOnIOComplete(this); } TransportConnectJob::~TransportConnectJob() { @@ -105,14 +213,14 @@ TransportConnectJob::~TransportConnectJob() { } LoadState TransportConnectJob::GetLoadState() const { - switch (next_state_) { - case STATE_RESOLVE_HOST: - case STATE_RESOLVE_HOST_COMPLETE: + switch (helper_.next_state()) { + case TransportConnectJobHelper::STATE_RESOLVE_HOST: + case TransportConnectJobHelper::STATE_RESOLVE_HOST_COMPLETE: return LOAD_STATE_RESOLVING_HOST; - case STATE_TRANSPORT_CONNECT: - case STATE_TRANSPORT_CONNECT_COMPLETE: + case TransportConnectJobHelper::STATE_TRANSPORT_CONNECT: + case TransportConnectJobHelper::STATE_TRANSPORT_CONNECT_COMPLETE: return LOAD_STATE_CONNECTING; - case STATE_NONE: + case TransportConnectJobHelper::STATE_NONE: return LOAD_STATE_IDLE; } NOTREACHED(); @@ -129,71 +237,12 @@ void TransportConnectJob::MakeAddressListStartWithIPv4(AddressList* list) { } } -void TransportConnectJob::OnIOComplete(int result) { - int rv = DoLoop(result); - if (rv != ERR_IO_PENDING) - NotifyDelegateOfCompletion(rv); // Deletes |this| -} - -int TransportConnectJob::DoLoop(int result) { - DCHECK_NE(next_state_, STATE_NONE); - - int rv = result; - do { - State state = next_state_; - next_state_ = STATE_NONE; - switch (state) { - case STATE_RESOLVE_HOST: - DCHECK_EQ(OK, rv); - rv = DoResolveHost(); - break; - case STATE_RESOLVE_HOST_COMPLETE: - rv = DoResolveHostComplete(rv); - break; - case STATE_TRANSPORT_CONNECT: - DCHECK_EQ(OK, rv); - rv = DoTransportConnect(); - break; - case STATE_TRANSPORT_CONNECT_COMPLETE: - rv = DoTransportConnectComplete(rv); - break; - default: - NOTREACHED(); - rv = ERR_FAILED; - break; - } - } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); - - return rv; -} - int TransportConnectJob::DoResolveHost() { - next_state_ = STATE_RESOLVE_HOST_COMPLETE; - connect_timing_.dns_start = base::TimeTicks::Now(); - - return resolver_.Resolve( - params_->destination(), - priority(), - &addresses_, - base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this)), - net_log()); + return helper_.DoResolveHost(priority(), net_log()); } int TransportConnectJob::DoResolveHostComplete(int result) { - connect_timing_.dns_end = base::TimeTicks::Now(); - // Overwrite connection start time, since for connections that do not go - // through proxies, |connect_start| should not include dns lookup time. - connect_timing_.connect_start = connect_timing_.dns_end; - - if (result == OK) { - // Invoke callback, and abort if it fails. - if (!params_->host_resolution_callback().is_null()) - result = params_->host_resolution_callback().Run(addresses_, net_log()); - - if (result == OK) - next_state_ = STATE_TRANSPORT_CONNECT; - } - return result; + return helper_.DoResolveHostComplete(result, net_log()); } int TransportConnectJob::DoTransportConnect() { @@ -216,42 +265,57 @@ int TransportConnectJob::DoTransportConnect() { interval_between_connects_ = CONNECT_INTERVAL_GT_20MS; } - next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; - transport_socket_ = client_socket_factory_->CreateTransportClientSocket( - addresses_, net_log().net_log(), net_log().source()); - int rv = transport_socket_->Connect( - base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this))); - if (rv == ERR_IO_PENDING && - addresses_.front().GetFamily() == ADDRESS_FAMILY_IPV6 && - !AddressListOnlyContainsIPv6(addresses_)) { - fallback_timer_.Start(FROM_HERE, - base::TimeDelta::FromMilliseconds(kIPv6FallbackTimerInMs), - this, &TransportConnectJob::DoIPv6FallbackTransportConnect); + helper_.set_next_state( + TransportConnectJobHelper::STATE_TRANSPORT_CONNECT_COMPLETE); + transport_socket_ = + helper_.client_socket_factory()->CreateTransportClientSocket( + helper_.addresses(), net_log().net_log(), net_log().source()); + + // If the list contains IPv6 and IPv4 addresses, the first address will + // be IPv6, and the IPv4 addresses will be tried as fallback addresses, + // per "Happy Eyeballs" (RFC 6555). + bool try_ipv6_connect_with_ipv4_fallback = + helper_.addresses().front().GetFamily() == ADDRESS_FAMILY_IPV6 && + !AddressListOnlyContainsIPv6(helper_.addresses()); + + // Enable TCP FastOpen if indicated by transport socket params. + // Note: We currently do not turn on TCP FastOpen for destinations where + // we try a TCP connect over IPv6 with fallback to IPv4. + if (!try_ipv6_connect_with_ipv4_fallback && + helper_.params()->combine_connect_and_write() == + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DESIRED) { + transport_socket_->EnableTCPFastOpenIfSupported(); + } + + int rv = transport_socket_->Connect(helper_.on_io_complete()); + if (rv == ERR_IO_PENDING && try_ipv6_connect_with_ipv4_fallback) { + fallback_timer_.Start( + FROM_HERE, + base::TimeDelta::FromMilliseconds( + TransportConnectJobHelper::kIPv6FallbackTimerInMs), + this, + &TransportConnectJob::DoIPv6FallbackTransportConnect); } return rv; } int TransportConnectJob::DoTransportConnectComplete(int result) { if (result == OK) { - bool is_ipv4 = addresses_.front().GetFamily() == ADDRESS_FAMILY_IPV4; - DCHECK(!connect_timing_.connect_start.is_null()); - DCHECK(!connect_timing_.dns_start.is_null()); - base::TimeTicks now = base::TimeTicks::Now(); - base::TimeDelta total_duration = now - connect_timing_.dns_start; - UMA_HISTOGRAM_CUSTOM_TIMES( - "Net.DNS_Resolution_And_TCP_Connection_Latency2", - total_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); - - base::TimeDelta connect_duration = now - connect_timing_.connect_start; - UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency", - connect_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); - + bool is_ipv4 = + helper_.addresses().front().GetFamily() == ADDRESS_FAMILY_IPV4; + TransportConnectJobHelper::ConnectionLatencyHistogram race_result = + TransportConnectJobHelper::CONNECTION_LATENCY_UNKNOWN; + if (is_ipv4) { + race_result = TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_NO_RACE; + } else { + if (AddressListOnlyContainsIPv6(helper_.addresses())) { + race_result = TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_SOLO; + } else { + race_result = + TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_RACEABLE; + } + } + base::TimeDelta connect_duration = helper_.HistogramDuration(race_result); switch (interval_between_connects_) { case CONNECT_INTERVAL_LE_10MS: UMA_HISTOGRAM_CUSTOM_TIMES( @@ -282,27 +346,6 @@ int TransportConnectJob::DoTransportConnectComplete(int result) { break; } - if (is_ipv4) { - UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_No_Race", - connect_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); - } else { - if (AddressListOnlyContainsIPv6(addresses_)) { - UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Solo", - connect_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); - } else { - UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv6_Raceable", - connect_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); - } - } SetSocket(transport_socket_.Pass()); fallback_timer_.Stop(); } else { @@ -317,7 +360,8 @@ int TransportConnectJob::DoTransportConnectComplete(int result) { void TransportConnectJob::DoIPv6FallbackTransportConnect() { // The timer should only fire while we're waiting for the main connect to // succeed. - if (next_state_ != STATE_TRANSPORT_CONNECT_COMPLETE) { + if (helper_.next_state() != + TransportConnectJobHelper::STATE_TRANSPORT_CONNECT_COMPLETE) { NOTREACHED(); return; } @@ -325,10 +369,10 @@ void TransportConnectJob::DoIPv6FallbackTransportConnect() { DCHECK(!fallback_transport_socket_.get()); DCHECK(!fallback_addresses_.get()); - fallback_addresses_.reset(new AddressList(addresses_)); + fallback_addresses_.reset(new AddressList(helper_.addresses())); MakeAddressListStartWithIPv4(fallback_addresses_.get()); fallback_transport_socket_ = - client_socket_factory_->CreateTransportClientSocket( + helper_.client_socket_factory()->CreateTransportClientSocket( *fallback_addresses_, net_log().net_log(), net_log().source()); fallback_connect_start_time_ = base::TimeTicks::Now(); int rv = fallback_transport_socket_->Connect( @@ -341,7 +385,8 @@ void TransportConnectJob::DoIPv6FallbackTransportConnect() { void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) { // This should only happen when we're waiting for the main connect to succeed. - if (next_state_ != STATE_TRANSPORT_CONNECT_COMPLETE) { + if (helper_.next_state() != + TransportConnectJobHelper::STATE_TRANSPORT_CONNECT_COMPLETE) { NOTREACHED(); return; } @@ -352,30 +397,11 @@ void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) { if (result == OK) { DCHECK(!fallback_connect_start_time_.is_null()); - DCHECK(!connect_timing_.dns_start.is_null()); - base::TimeTicks now = base::TimeTicks::Now(); - base::TimeDelta total_duration = now - connect_timing_.dns_start; - UMA_HISTOGRAM_CUSTOM_TIMES( - "Net.DNS_Resolution_And_TCP_Connection_Latency2", - total_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); - - base::TimeDelta connect_duration = now - fallback_connect_start_time_; - UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency", - connect_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); - - UMA_HISTOGRAM_CUSTOM_TIMES("Net.TCP_Connection_Latency_IPv4_Wins_Race", - connect_duration, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100); + connect_timing_.connect_start = fallback_connect_start_time_; + helper_.HistogramDuration( + TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_WINS_RACE); SetSocket(fallback_transport_socket_.Pass()); - next_state_ = STATE_NONE; + helper_.set_next_state(TransportConnectJobHelper::STATE_NONE); transport_socket_.reset(); } else { // Be a bit paranoid and kill off the fallback members to prevent reuse. @@ -386,12 +412,11 @@ void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) { } int TransportConnectJob::ConnectInternal() { - next_state_ = STATE_RESOLVE_HOST; - return DoLoop(OK); + return helper_.DoConnectInternal(this); } scoped_ptr<ConnectJob> - TransportClientSocketPool::TransportConnectJobFactory::NewConnectJob( +TransportClientSocketPool::TransportConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { @@ -439,6 +464,15 @@ int TransportClientSocketPool::RequestSocket( const scoped_refptr<TransportSocketParams>* casted_params = static_cast<const scoped_refptr<TransportSocketParams>*>(params); + NetLogTcpClientSocketPoolRequestedSocket(net_log, casted_params); + + return base_.RequestSocket(group_name, *casted_params, priority, handle, + callback, net_log); +} + +void TransportClientSocketPool::NetLogTcpClientSocketPoolRequestedSocket( + const BoundNetLog& net_log, + const scoped_refptr<TransportSocketParams>* casted_params) { if (net_log.IsLogging()) { // TODO(eroman): Split out the host and port parameters. net_log.AddEvent( @@ -446,9 +480,6 @@ int TransportClientSocketPool::RequestSocket( CreateNetLogHostPortPairCallback( &casted_params->get()->destination().host_port_pair())); } - - return base_.RequestSocket(group_name, *casted_params, priority, handle, - callback, net_log); } void TransportClientSocketPool::RequestSockets( diff --git a/chromium/net/socket/transport_client_socket_pool.h b/chromium/net/socket/transport_client_socket_pool.h index 1c22bf29ec3..15cef5c02ac 100644 --- a/chromium/net/socket/transport_client_socket_pool.h +++ b/chromium/net/socket/transport_client_socket_pool.h @@ -29,14 +29,30 @@ OnHostResolutionCallback; class NET_EXPORT_PRIVATE TransportSocketParams : public base::RefCounted<TransportSocketParams> { public: + // CombineConnectAndWrite currently translates to using TCP FastOpen. + // TCP FastOpen should not be used if the first write to the socket may + // be non-idempotent, as the underlying socket could retransmit the data + // on failure of the first transmission. + // NOTE: Currently, COMBINE_CONNECT_AND_WRITE_DESIRED is used if the data in + // the write is known to be idempotent, and COMBINE_CONNECT_AND_WRITE_DEFAULT + // is used as a default for other cases (including non-idempotent writes). + enum CombineConnectAndWritePolicy { + COMBINE_CONNECT_AND_WRITE_DEFAULT, // Default policy, implemented in + // TransportSocketParams constructor. + COMBINE_CONNECT_AND_WRITE_DESIRED, // Combine if supported by socket. + COMBINE_CONNECT_AND_WRITE_PROHIBITED // Do not combine. + }; + // |host_resolution_callback| will be invoked after the the hostname is // resolved. If |host_resolution_callback| does not return OK, then the - // connection will be aborted with that value. + // connection will be aborted with that value. |combine_connect_and_write| + // defines the policy for use of TCP FastOpen on this socket. TransportSocketParams( const HostPortPair& host_port_pair, bool disable_resolver_cache, bool ignore_limits, - const OnHostResolutionCallback& host_resolution_callback); + const OnHostResolutionCallback& host_resolution_callback, + CombineConnectAndWritePolicy combine_connect_and_write); const HostResolver::RequestInfo& destination() const { return destination_; } bool ignore_limits() const { return ignore_limits_; } @@ -44,6 +60,10 @@ class NET_EXPORT_PRIVATE TransportSocketParams return host_resolution_callback_; } + CombineConnectAndWritePolicy combine_connect_and_write() const { + return combine_connect_and_write_; + } + private: friend class base::RefCounted<TransportSocketParams>; ~TransportSocketParams(); @@ -51,10 +71,81 @@ class NET_EXPORT_PRIVATE TransportSocketParams HostResolver::RequestInfo destination_; bool ignore_limits_; const OnHostResolutionCallback host_resolution_callback_; + CombineConnectAndWritePolicy combine_connect_and_write_; DISALLOW_COPY_AND_ASSIGN(TransportSocketParams); }; +// Common data and logic shared between TransportConnectJob and +// WebSocketTransportConnectJob. +class NET_EXPORT_PRIVATE TransportConnectJobHelper { + public: + enum State { + STATE_RESOLVE_HOST, + STATE_RESOLVE_HOST_COMPLETE, + STATE_TRANSPORT_CONNECT, + STATE_TRANSPORT_CONNECT_COMPLETE, + STATE_NONE, + }; + + // For recording the connection time in the appropriate bucket. + enum ConnectionLatencyHistogram { + CONNECTION_LATENCY_UNKNOWN, + CONNECTION_LATENCY_IPV4_WINS_RACE, + CONNECTION_LATENCY_IPV4_NO_RACE, + CONNECTION_LATENCY_IPV6_RACEABLE, + CONNECTION_LATENCY_IPV6_SOLO, + }; + + TransportConnectJobHelper(const scoped_refptr<TransportSocketParams>& params, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + LoadTimingInfo::ConnectTiming* connect_timing); + ~TransportConnectJobHelper(); + + ClientSocketFactory* client_socket_factory() { + return client_socket_factory_; + } + + const AddressList& addresses() const { return addresses_; } + State next_state() const { return next_state_; } + void set_next_state(State next_state) { next_state_ = next_state; } + CompletionCallback on_io_complete() const { return on_io_complete_; } + const TransportSocketParams* params() { return params_.get(); } + + int DoResolveHost(RequestPriority priority, const BoundNetLog& net_log); + int DoResolveHostComplete(int result, const BoundNetLog& net_log); + + template <class T> + int DoConnectInternal(T* job); + + template <class T> + void SetOnIOComplete(T* job); + + template <class T> + void OnIOComplete(T* job, int result); + + // Record the histograms Net.DNS_Resolution_And_TCP_Connection_Latency2 and + // Net.TCP_Connection_Latency and return the connect duration. + base::TimeDelta HistogramDuration(ConnectionLatencyHistogram race_result); + + static const int kIPv6FallbackTimerInMs; + + private: + template <class T> + int DoLoop(T* job, int result); + + scoped_refptr<TransportSocketParams> params_; + ClientSocketFactory* const client_socket_factory_; + SingleRequestHostResolver resolver_; + AddressList addresses_; + State next_state_; + CompletionCallback on_io_complete_; + LoadTimingInfo::ConnectTiming* connect_timing_; + + DISALLOW_COPY_AND_ASSIGN(TransportConnectJobHelper); +}; + // TransportConnectJob handles the host resolution necessary for socket creation // and the transport (likely TCP) connect. TransportConnectJob also has fallback // logic for IPv6 connect() timeouts (which may happen due to networks / routers @@ -73,36 +164,23 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { HostResolver* host_resolver, Delegate* delegate, NetLog* net_log); - virtual ~TransportConnectJob(); + ~TransportConnectJob() override; // ConnectJob methods. - virtual LoadState GetLoadState() const OVERRIDE; + LoadState GetLoadState() const override; // Rolls |addrlist| forward until the first IPv4 address, if any. // WARNING: this method should only be used to implement the prefer-IPv4 hack. static void MakeAddressListStartWithIPv4(AddressList* addrlist); - static const int kIPv6FallbackTimerInMs; - private: - enum State { - STATE_RESOLVE_HOST, - STATE_RESOLVE_HOST_COMPLETE, - STATE_TRANSPORT_CONNECT, - STATE_TRANSPORT_CONNECT_COMPLETE, - STATE_NONE, - }; - enum ConnectInterval { CONNECT_INTERVAL_LE_10MS, CONNECT_INTERVAL_LE_20MS, CONNECT_INTERVAL_GT_20MS, }; - void OnIOComplete(int result); - - // Runs the state transition loop. - int DoLoop(int result); + friend class TransportConnectJobHelper; int DoResolveHost(); int DoResolveHostComplete(int result); @@ -116,13 +194,9 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { // Begins the host resolution and the TCP connect. Returns OK on success // and ERR_IO_PENDING if it cannot immediately service the request. // Otherwise, it returns a net error code. - virtual int ConnectInternal() OVERRIDE; + int ConnectInternal() override; - scoped_refptr<TransportSocketParams> params_; - ClientSocketFactory* const client_socket_factory_; - SingleRequestHostResolver resolver_; - AddressList addresses_; - State next_state_; + TransportConnectJobHelper helper_; scoped_ptr<StreamSocket> transport_socket_; @@ -149,43 +223,47 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { ClientSocketFactory* client_socket_factory, NetLog* net_log); - virtual ~TransportClientSocketPool(); + ~TransportClientSocketPool() override; // ClientSocketPool implementation. - virtual int RequestSocket(const std::string& group_name, - const void* resolve_info, - RequestPriority priority, - ClientSocketHandle* handle, - const CompletionCallback& callback, - const BoundNetLog& net_log) OVERRIDE; - virtual void RequestSockets(const std::string& group_name, - const void* params, - int num_sockets, - const BoundNetLog& net_log) OVERRIDE; - virtual void CancelRequest(const std::string& group_name, - ClientSocketHandle* handle) OVERRIDE; - virtual void ReleaseSocket(const std::string& group_name, - scoped_ptr<StreamSocket> socket, - int id) OVERRIDE; - virtual void FlushWithError(int error) OVERRIDE; - virtual void CloseIdleSockets() OVERRIDE; - virtual int IdleSocketCount() const OVERRIDE; - virtual int IdleSocketCountInGroup( - const std::string& group_name) const OVERRIDE; - virtual LoadState GetLoadState( - const std::string& group_name, - const ClientSocketHandle* handle) const OVERRIDE; - virtual base::DictionaryValue* GetInfoAsValue( + int RequestSocket(const std::string& group_name, + const void* resolve_info, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) override; + void RequestSockets(const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) override; + void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) override; + void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) override; + void FlushWithError(int error) override; + void CloseIdleSockets() override; + int IdleSocketCount() const override; + int IdleSocketCountInGroup(const std::string& group_name) const override; + LoadState GetLoadState(const std::string& group_name, + const ClientSocketHandle* handle) const override; + base::DictionaryValue* GetInfoAsValue( const std::string& name, const std::string& type, - bool include_nested_pools) const OVERRIDE; - virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; - virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; + bool include_nested_pools) const override; + base::TimeDelta ConnectionTimeout() const override; + ClientSocketPoolHistograms* histograms() const override; // HigherLayeredPool implementation. - virtual bool IsStalled() const OVERRIDE; - virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; - virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + bool IsStalled() const override; + void AddHigherLayeredPool(HigherLayeredPool* higher_pool) override; + void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) override; + + protected: + // Methods shared with WebSocketTransportClientSocketPool + void NetLogTcpClientSocketPoolRequestedSocket( + const BoundNetLog& net_log, + const scoped_refptr<TransportSocketParams>* casted_params); private: typedef ClientSocketPoolBase<TransportSocketParams> PoolBase; @@ -200,16 +278,16 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { host_resolver_(host_resolver), net_log_(net_log) {} - virtual ~TransportConnectJobFactory() {} + ~TransportConnectJobFactory() override {} // ClientSocketPoolBase::ConnectJobFactory methods. - virtual scoped_ptr<ConnectJob> NewConnectJob( + scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const PoolBase::Request& request, - ConnectJob::Delegate* delegate) const OVERRIDE; + ConnectJob::Delegate* delegate) const override; - virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; + base::TimeDelta ConnectionTimeout() const override; private: ClientSocketFactory* const client_socket_factory_; @@ -224,6 +302,61 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { DISALLOW_COPY_AND_ASSIGN(TransportClientSocketPool); }; +template <class T> +int TransportConnectJobHelper::DoConnectInternal(T* job) { + next_state_ = STATE_RESOLVE_HOST; + return this->DoLoop(job, OK); +} + +template <class T> +void TransportConnectJobHelper::SetOnIOComplete(T* job) { + // These usages of base::Unretained() are safe because IO callbacks are + // guaranteed not to be called after the object is destroyed. + on_io_complete_ = base::Bind(&TransportConnectJobHelper::OnIOComplete<T>, + base::Unretained(this), + base::Unretained(job)); +} + +template <class T> +void TransportConnectJobHelper::OnIOComplete(T* job, int result) { + result = this->DoLoop(job, result); + if (result != ERR_IO_PENDING) + job->NotifyDelegateOfCompletion(result); // Deletes |job| and |this| +} + +template <class T> +int TransportConnectJobHelper::DoLoop(T* job, int result) { + DCHECK_NE(next_state_, STATE_NONE); + + int rv = result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_RESOLVE_HOST: + DCHECK_EQ(OK, rv); + rv = job->DoResolveHost(); + break; + case STATE_RESOLVE_HOST_COMPLETE: + rv = job->DoResolveHostComplete(rv); + break; + case STATE_TRANSPORT_CONNECT: + DCHECK_EQ(OK, rv); + rv = job->DoTransportConnect(); + break; + case STATE_TRANSPORT_CONNECT_COMPLETE: + rv = job->DoTransportConnectComplete(rv); + break; + default: + NOTREACHED(); + rv = ERR_FAILED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + + return rv; +} + } // namespace net #endif // NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ diff --git a/chromium/net/socket/transport_client_socket_pool_test_util.cc b/chromium/net/socket/transport_client_socket_pool_test_util.cc new file mode 100644 index 00000000000..82ed8e6a78e --- /dev/null +++ b/chromium/net/socket/transport_client_socket_pool_test_util.cc @@ -0,0 +1,424 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/transport_client_socket_pool_test_util.h" + +#include <string> + +#include "base/logging.h" +#include "base/memory/weak_ptr.h" +#include "base/run_loop.h" +#include "net/base/ip_endpoint.h" +#include "net/base/load_timing_info.h" +#include "net/base/load_timing_info_test_util.h" +#include "net/base/net_util.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/ssl_client_socket.h" +#include "net/udp/datagram_client_socket.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +IPAddressNumber ParseIP(const std::string& ip) { + IPAddressNumber number; + CHECK(ParseIPLiteralToNumber(ip, &number)); + return number; +} + +// A StreamSocket which connects synchronously and successfully. +class MockConnectClientSocket : public StreamSocket { + public: + MockConnectClientSocket(const AddressList& addrlist, net::NetLog* net_log) + : connected_(false), + addrlist_(addrlist), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), + use_tcp_fastopen_(false) {} + + // StreamSocket implementation. + int Connect(const CompletionCallback& callback) override { + connected_ = true; + return OK; + } + void Disconnect() override { connected_ = false; } + bool IsConnected() const override { return connected_; } + bool IsConnectedAndIdle() const override { return connected_; } + + int GetPeerAddress(IPEndPoint* address) const override { + *address = addrlist_.front(); + return OK; + } + int GetLocalAddress(IPEndPoint* address) const override { + if (!connected_) + return ERR_SOCKET_NOT_CONNECTED; + if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4) + SetIPv4Address(address); + else + SetIPv6Address(address); + return OK; + } + const BoundNetLog& NetLog() const override { return net_log_; } + + void SetSubresourceSpeculation() override {} + void SetOmniboxSpeculation() override {} + bool WasEverUsed() const override { return false; } + void EnableTCPFastOpenIfSupported() override { use_tcp_fastopen_ = true; } + bool UsingTCPFastOpen() const override { return use_tcp_fastopen_; } + bool WasNpnNegotiated() const override { return false; } + NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } + bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } + + // Socket implementation. + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override { + return ERR_FAILED; + } + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override { + return ERR_FAILED; + } + int SetReceiveBufferSize(int32 size) override { return OK; } + int SetSendBufferSize(int32 size) override { return OK; } + + private: + bool connected_; + const AddressList addrlist_; + BoundNetLog net_log_; + bool use_tcp_fastopen_; + + DISALLOW_COPY_AND_ASSIGN(MockConnectClientSocket); +}; + +class MockFailingClientSocket : public StreamSocket { + public: + MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log) + : addrlist_(addrlist), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), + use_tcp_fastopen_(false) {} + + // StreamSocket implementation. + int Connect(const CompletionCallback& callback) override { + return ERR_CONNECTION_FAILED; + } + + void Disconnect() override {} + + bool IsConnected() const override { return false; } + bool IsConnectedAndIdle() const override { return false; } + int GetPeerAddress(IPEndPoint* address) const override { + return ERR_UNEXPECTED; + } + int GetLocalAddress(IPEndPoint* address) const override { + return ERR_UNEXPECTED; + } + const BoundNetLog& NetLog() const override { return net_log_; } + + void SetSubresourceSpeculation() override {} + void SetOmniboxSpeculation() override {} + bool WasEverUsed() const override { return false; } + void EnableTCPFastOpenIfSupported() override { use_tcp_fastopen_ = true; } + bool UsingTCPFastOpen() const override { return use_tcp_fastopen_; } + bool WasNpnNegotiated() const override { return false; } + NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } + bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } + + // Socket implementation. + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override { + return ERR_FAILED; + } + + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override { + return ERR_FAILED; + } + int SetReceiveBufferSize(int32 size) override { return OK; } + int SetSendBufferSize(int32 size) override { return OK; } + + private: + const AddressList addrlist_; + BoundNetLog net_log_; + bool use_tcp_fastopen_; + + DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket); +}; + +class MockTriggerableClientSocket : public StreamSocket { + public: + // |should_connect| indicates whether the socket should successfully complete + // or fail. + MockTriggerableClientSocket(const AddressList& addrlist, + bool should_connect, + net::NetLog* net_log) + : should_connect_(should_connect), + is_connected_(false), + addrlist_(addrlist), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), + use_tcp_fastopen_(false), + weak_factory_(this) {} + + // Call this method to get a closure which will trigger the connect callback + // when called. The closure can be called even after the socket is deleted; it + // will safely do nothing. + base::Closure GetConnectCallback() { + return base::Bind(&MockTriggerableClientSocket::DoCallback, + weak_factory_.GetWeakPtr()); + } + + static scoped_ptr<StreamSocket> MakeMockPendingClientSocket( + const AddressList& addrlist, + bool should_connect, + net::NetLog* net_log) { + scoped_ptr<MockTriggerableClientSocket> socket( + new MockTriggerableClientSocket(addrlist, should_connect, net_log)); + base::MessageLoop::current()->PostTask(FROM_HERE, + socket->GetConnectCallback()); + return socket.Pass(); + } + + static scoped_ptr<StreamSocket> MakeMockDelayedClientSocket( + const AddressList& addrlist, + bool should_connect, + const base::TimeDelta& delay, + net::NetLog* net_log) { + scoped_ptr<MockTriggerableClientSocket> socket( + new MockTriggerableClientSocket(addrlist, should_connect, net_log)); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, socket->GetConnectCallback(), delay); + return socket.Pass(); + } + + static scoped_ptr<StreamSocket> MakeMockStalledClientSocket( + const AddressList& addrlist, + net::NetLog* net_log) { + scoped_ptr<MockTriggerableClientSocket> socket( + new MockTriggerableClientSocket(addrlist, true, net_log)); + return socket.Pass(); + } + + // StreamSocket implementation. + int Connect(const CompletionCallback& callback) override { + DCHECK(callback_.is_null()); + callback_ = callback; + return ERR_IO_PENDING; + } + + void Disconnect() override {} + + bool IsConnected() const override { return is_connected_; } + bool IsConnectedAndIdle() const override { return is_connected_; } + int GetPeerAddress(IPEndPoint* address) const override { + *address = addrlist_.front(); + return OK; + } + int GetLocalAddress(IPEndPoint* address) const override { + if (!is_connected_) + return ERR_SOCKET_NOT_CONNECTED; + if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4) + SetIPv4Address(address); + else + SetIPv6Address(address); + return OK; + } + const BoundNetLog& NetLog() const override { return net_log_; } + + void SetSubresourceSpeculation() override {} + void SetOmniboxSpeculation() override {} + bool WasEverUsed() const override { return false; } + void EnableTCPFastOpenIfSupported() override { use_tcp_fastopen_ = true; } + bool UsingTCPFastOpen() const override { return use_tcp_fastopen_; } + bool WasNpnNegotiated() const override { return false; } + NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } + bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } + + // Socket implementation. + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override { + return ERR_FAILED; + } + + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override { + return ERR_FAILED; + } + int SetReceiveBufferSize(int32 size) override { return OK; } + int SetSendBufferSize(int32 size) override { return OK; } + + private: + void DoCallback() { + is_connected_ = should_connect_; + callback_.Run(is_connected_ ? OK : ERR_CONNECTION_FAILED); + } + + bool should_connect_; + bool is_connected_; + const AddressList addrlist_; + BoundNetLog net_log_; + CompletionCallback callback_; + bool use_tcp_fastopen_; + + base::WeakPtrFactory<MockTriggerableClientSocket> weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(MockTriggerableClientSocket); +}; + +} // namespace + +void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) { + LoadTimingInfo load_timing_info; + // Only pass true in as |is_reused|, as in general, HttpStream types should + // have stricter concepts of reuse than socket pools. + EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info)); + + EXPECT_TRUE(load_timing_info.socket_reused); + EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); + + ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing); + ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); +} + +void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) { + EXPECT_FALSE(handle.is_reused()); + + LoadTimingInfo load_timing_info; + EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info)); + + EXPECT_FALSE(load_timing_info.socket_reused); + EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); + + ExpectConnectTimingHasTimes(load_timing_info.connect_timing, + CONNECT_TIMING_HAS_DNS_TIMES); + ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); + + TestLoadTimingInfoConnectedReused(handle); +} + +void SetIPv4Address(IPEndPoint* address) { + *address = IPEndPoint(ParseIP("1.1.1.1"), 80); +} + +void SetIPv6Address(IPEndPoint* address) { + *address = IPEndPoint(ParseIP("1:abcd::3:4:ff"), 80); +} + +MockTransportClientSocketFactory::MockTransportClientSocketFactory( + NetLog* net_log) + : net_log_(net_log), + allocation_count_(0), + client_socket_type_(MOCK_CLIENT_SOCKET), + client_socket_types_(NULL), + client_socket_index_(0), + client_socket_index_max_(0), + delay_(base::TimeDelta::FromMilliseconds( + ClientSocketPool::kMaxConnectRetryIntervalMs)) {} + +MockTransportClientSocketFactory::~MockTransportClientSocketFactory() {} + +scoped_ptr<DatagramClientSocket> +MockTransportClientSocketFactory::CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + NetLog* net_log, + const NetLog::Source& source) { + NOTREACHED(); + return scoped_ptr<DatagramClientSocket>(); +} + +scoped_ptr<StreamSocket> +MockTransportClientSocketFactory::CreateTransportClientSocket( + const AddressList& addresses, + NetLog* /* net_log */, + const NetLog::Source& /* source */) { + allocation_count_++; + + ClientSocketType type = client_socket_type_; + if (client_socket_types_ && client_socket_index_ < client_socket_index_max_) { + type = client_socket_types_[client_socket_index_++]; + } + + switch (type) { + case MOCK_CLIENT_SOCKET: + return scoped_ptr<StreamSocket>( + new MockConnectClientSocket(addresses, net_log_)); + case MOCK_FAILING_CLIENT_SOCKET: + return scoped_ptr<StreamSocket>( + new MockFailingClientSocket(addresses, net_log_)); + case MOCK_PENDING_CLIENT_SOCKET: + return MockTriggerableClientSocket::MakeMockPendingClientSocket( + addresses, true, net_log_); + case MOCK_PENDING_FAILING_CLIENT_SOCKET: + return MockTriggerableClientSocket::MakeMockPendingClientSocket( + addresses, false, net_log_); + case MOCK_DELAYED_CLIENT_SOCKET: + return MockTriggerableClientSocket::MakeMockDelayedClientSocket( + addresses, true, delay_, net_log_); + case MOCK_DELAYED_FAILING_CLIENT_SOCKET: + return MockTriggerableClientSocket::MakeMockDelayedClientSocket( + addresses, false, delay_, net_log_); + case MOCK_STALLED_CLIENT_SOCKET: + return MockTriggerableClientSocket::MakeMockStalledClientSocket(addresses, + net_log_); + case MOCK_TRIGGERABLE_CLIENT_SOCKET: { + scoped_ptr<MockTriggerableClientSocket> rv( + new MockTriggerableClientSocket(addresses, true, net_log_)); + triggerable_sockets_.push(rv->GetConnectCallback()); + // run_loop_quit_closure_ behaves like a condition variable. It will + // wake up WaitForTriggerableSocketCreation() if it is sleeping. We + // don't need to worry about atomicity because this code is + // single-threaded. + if (!run_loop_quit_closure_.is_null()) + run_loop_quit_closure_.Run(); + return rv.Pass(); + } + default: + NOTREACHED(); + return scoped_ptr<StreamSocket>( + new MockConnectClientSocket(addresses, net_log_)); + } +} + +scoped_ptr<SSLClientSocket> +MockTransportClientSocketFactory::CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context) { + NOTIMPLEMENTED(); + return scoped_ptr<SSLClientSocket>(); +} + +void MockTransportClientSocketFactory::ClearSSLSessionCache() { + NOTIMPLEMENTED(); +} + +void MockTransportClientSocketFactory::set_client_socket_types( + ClientSocketType* type_list, + int num_types) { + DCHECK_GT(num_types, 0); + client_socket_types_ = type_list; + client_socket_index_ = 0; + client_socket_index_max_ = num_types; +} + +base::Closure +MockTransportClientSocketFactory::WaitForTriggerableSocketCreation() { + while (triggerable_sockets_.empty()) { + base::RunLoop run_loop; + run_loop_quit_closure_ = run_loop.QuitClosure(); + run_loop.Run(); + run_loop_quit_closure_.Reset(); + } + base::Closure trigger = triggerable_sockets_.front(); + triggerable_sockets_.pop(); + return trigger; +} + +} // namespace net diff --git a/chromium/net/socket/transport_client_socket_pool_test_util.h b/chromium/net/socket/transport_client_socket_pool_test_util.h new file mode 100644 index 00000000000..b375353f06f --- /dev/null +++ b/chromium/net/socket/transport_client_socket_pool_test_util.h @@ -0,0 +1,127 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Test methods and classes common to transport_client_socket_pool_unittest.cc +// and websocket_transport_client_socket_pool_unittest.cc. If you find you need +// to use these for another purpose, consider moving them to socket_test_util.h. + +#ifndef NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_TEST_UTIL_H_ +#define NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_TEST_UTIL_H_ + +#include <queue> + +#include "base/callback.h" +#include "base/compiler_specific.h" +#include "base/macros.h" +#include "base/memory/scoped_ptr.h" +#include "base/time/time.h" +#include "net/base/address_list.h" +#include "net/base/net_log.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/stream_socket.h" + +namespace net { + +class ClientSocketHandle; +class IPEndPoint; + +// Make sure |handle| sets load times correctly when it has been assigned a +// reused socket. Uses gtest expectations. +void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle); + +// Make sure |handle| sets load times correctly when it has been assigned a +// fresh socket. Also runs TestLoadTimingInfoConnectedReused, since the owner +// of a connection where |is_reused| is false may consider the connection +// reused. Uses gtest expectations. +void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle); + +// Set |address| to 1.1.1.1:80 +void SetIPv4Address(IPEndPoint* address); + +// Set |address| to [1:abcd::3:4:ff]:80 +void SetIPv6Address(IPEndPoint* address); + +// A ClientSocketFactory that produces sockets with the specified connection +// behaviours. +class MockTransportClientSocketFactory : public ClientSocketFactory { + public: + enum ClientSocketType { + // Connects successfully, synchronously. + MOCK_CLIENT_SOCKET, + // Fails to connect, synchronously. + MOCK_FAILING_CLIENT_SOCKET, + // Connects successfully, asynchronously. + MOCK_PENDING_CLIENT_SOCKET, + // Fails to connect, asynchronously. + MOCK_PENDING_FAILING_CLIENT_SOCKET, + // A delayed socket will pause before connecting through the message loop. + MOCK_DELAYED_CLIENT_SOCKET, + // A delayed socket that fails. + MOCK_DELAYED_FAILING_CLIENT_SOCKET, + // A stalled socket that never connects at all. + MOCK_STALLED_CLIENT_SOCKET, + // A socket that can be triggered to connect explicitly, asynchronously. + MOCK_TRIGGERABLE_CLIENT_SOCKET, + }; + + explicit MockTransportClientSocketFactory(NetLog* net_log); + ~MockTransportClientSocketFactory() override; + + scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( + DatagramSocket::BindType bind_type, + const RandIntCallback& rand_int_cb, + NetLog* net_log, + const NetLog::Source& source) override; + + scoped_ptr<StreamSocket> CreateTransportClientSocket( + const AddressList& addresses, + NetLog* /* net_log */, + const NetLog::Source& /* source */) override; + + scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + const SSLClientSocketContext& context) override; + + void ClearSSLSessionCache() override; + + int allocation_count() const { return allocation_count_; } + + // Set the default ClientSocketType. + void set_default_client_socket_type(ClientSocketType type) { + client_socket_type_ = type; + } + + // Set a list of ClientSocketTypes to be used. + void set_client_socket_types(ClientSocketType* type_list, int num_types); + + void set_delay(base::TimeDelta delay) { delay_ = delay; } + + // If one or more MOCK_TRIGGERABLE_CLIENT_SOCKETs has already been created, + // then returns a Closure that can be called to cause the first + // not-yet-connected one to connect. If no MOCK_TRIGGERABLE_CLIENT_SOCKETs + // have been created yet, wait for one to be created before returning the + // Closure. This method should be called the same number of times as + // MOCK_TRIGGERABLE_CLIENT_SOCKETs are created in the test. + base::Closure WaitForTriggerableSocketCreation(); + + private: + NetLog* net_log_; + int allocation_count_; + ClientSocketType client_socket_type_; + ClientSocketType* client_socket_types_; + int client_socket_index_; + int client_socket_index_max_; + base::TimeDelta delay_; + std::queue<base::Closure> triggerable_sockets_; + base::Closure run_loop_quit_closure_; + + DISALLOW_COPY_AND_ASSIGN(MockTransportClientSocketFactory); +}; + +} // namespace net + +#endif // NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_TEST_UTIL_H_ diff --git a/chromium/net/socket/transport_client_socket_pool_unittest.cc b/chromium/net/socket/transport_client_socket_pool_unittest.cc index 425bb8cc421..c0687ef5a43 100644 --- a/chromium/net/socket/transport_client_socket_pool_unittest.cc +++ b/chromium/net/socket/transport_client_socket_pool_unittest.cc @@ -7,8 +7,6 @@ #include "base/bind.h" #include "base/bind_helpers.h" #include "base/callback.h" -#include "base/compiler_specific.h" -#include "base/logging.h" #include "base/message_loop/message_loop.h" #include "base/threading/platform_thread.h" #include "net/base/capturing_net_log.h" @@ -19,12 +17,11 @@ #include "net/base/net_util.h" #include "net/base/test_completion_callback.h" #include "net/dns/mock_host_resolver.h" -#include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket_test_util.h" -#include "net/socket/ssl_client_socket.h" #include "net/socket/stream_socket.h" +#include "net/socket/transport_client_socket_pool_test_util.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { @@ -35,411 +32,7 @@ namespace { const int kMaxSockets = 32; const int kMaxSocketsPerGroup = 6; -const net::RequestPriority kDefaultPriority = LOW; - -// Make sure |handle| sets load times correctly when it has been assigned a -// reused socket. -void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) { - LoadTimingInfo load_timing_info; - // Only pass true in as |is_reused|, as in general, HttpStream types should - // have stricter concepts of reuse than socket pools. - EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info)); - - EXPECT_TRUE(load_timing_info.socket_reused); - EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); - - ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing); - ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); -} - -// Make sure |handle| sets load times correctly when it has been assigned a -// fresh socket. Also runs TestLoadTimingInfoConnectedReused, since the owner -// of a connection where |is_reused| is false may consider the connection -// reused. -void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) { - EXPECT_FALSE(handle.is_reused()); - - LoadTimingInfo load_timing_info; - EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info)); - - EXPECT_FALSE(load_timing_info.socket_reused); - EXPECT_NE(NetLog::Source::kInvalidId, load_timing_info.socket_log_id); - - ExpectConnectTimingHasTimes(load_timing_info.connect_timing, - CONNECT_TIMING_HAS_DNS_TIMES); - ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); - - TestLoadTimingInfoConnectedReused(handle); -} - -void SetIPv4Address(IPEndPoint* address) { - IPAddressNumber number; - CHECK(ParseIPLiteralToNumber("1.1.1.1", &number)); - *address = IPEndPoint(number, 80); -} - -void SetIPv6Address(IPEndPoint* address) { - IPAddressNumber number; - CHECK(ParseIPLiteralToNumber("1:abcd::3:4:ff", &number)); - *address = IPEndPoint(number, 80); -} - -class MockClientSocket : public StreamSocket { - public: - MockClientSocket(const AddressList& addrlist, net::NetLog* net_log) - : connected_(false), - addrlist_(addrlist), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { - } - - // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE { - connected_ = true; - return OK; - } - virtual void Disconnect() OVERRIDE { - connected_ = false; - } - virtual bool IsConnected() const OVERRIDE { - return connected_; - } - virtual bool IsConnectedAndIdle() const OVERRIDE { - return connected_; - } - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { - return ERR_UNEXPECTED; - } - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { - if (!connected_) - return ERR_SOCKET_NOT_CONNECTED; - if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4) - SetIPv4Address(address); - else - SetIPv6Address(address); - return OK; - } - virtual const BoundNetLog& NetLog() const OVERRIDE { - return net_log_; - } - - virtual void SetSubresourceSpeculation() OVERRIDE {} - virtual void SetOmniboxSpeculation() OVERRIDE {} - virtual bool WasEverUsed() const OVERRIDE { return false; } - virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } - virtual bool WasNpnNegotiated() const OVERRIDE { - return false; - } - virtual NextProto GetNegotiatedProtocol() const OVERRIDE { - return kProtoUnknown; - } - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { - return false; - } - - // Socket implementation. - virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { - return ERR_FAILED; - } - virtual int Write(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { - return ERR_FAILED; - } - virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; } - virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; } - - private: - bool connected_; - const AddressList addrlist_; - BoundNetLog net_log_; - - DISALLOW_COPY_AND_ASSIGN(MockClientSocket); -}; - -class MockFailingClientSocket : public StreamSocket { - public: - MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log) - : addrlist_(addrlist), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { - } - - // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE { - return ERR_CONNECTION_FAILED; - } - - virtual void Disconnect() OVERRIDE {} - - virtual bool IsConnected() const OVERRIDE { - return false; - } - virtual bool IsConnectedAndIdle() const OVERRIDE { - return false; - } - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { - return ERR_UNEXPECTED; - } - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { - return ERR_UNEXPECTED; - } - virtual const BoundNetLog& NetLog() const OVERRIDE { - return net_log_; - } - - virtual void SetSubresourceSpeculation() OVERRIDE {} - virtual void SetOmniboxSpeculation() OVERRIDE {} - virtual bool WasEverUsed() const OVERRIDE { return false; } - virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } - virtual bool WasNpnNegotiated() const OVERRIDE { - return false; - } - virtual NextProto GetNegotiatedProtocol() const OVERRIDE { - return kProtoUnknown; - } - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { - return false; - } - - // Socket implementation. - virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { - return ERR_FAILED; - } - - virtual int Write(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { - return ERR_FAILED; - } - virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; } - virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; } - - private: - const AddressList addrlist_; - BoundNetLog net_log_; - - DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket); -}; - -class MockPendingClientSocket : public StreamSocket { - public: - // |should_connect| indicates whether the socket should successfully complete - // or fail. - // |should_stall| indicates that this socket should never connect. - // |delay_ms| is the delay, in milliseconds, before simulating a connect. - MockPendingClientSocket( - const AddressList& addrlist, - bool should_connect, - bool should_stall, - base::TimeDelta delay, - net::NetLog* net_log) - : should_connect_(should_connect), - should_stall_(should_stall), - delay_(delay), - is_connected_(false), - addrlist_(addrlist), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), - weak_factory_(this) { - } - - // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE { - base::MessageLoop::current()->PostDelayedTask( - FROM_HERE, - base::Bind(&MockPendingClientSocket::DoCallback, - weak_factory_.GetWeakPtr(), callback), - delay_); - return ERR_IO_PENDING; - } - - virtual void Disconnect() OVERRIDE {} - - virtual bool IsConnected() const OVERRIDE { - return is_connected_; - } - virtual bool IsConnectedAndIdle() const OVERRIDE { - return is_connected_; - } - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { - return ERR_UNEXPECTED; - } - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { - if (!is_connected_) - return ERR_SOCKET_NOT_CONNECTED; - if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4) - SetIPv4Address(address); - else - SetIPv6Address(address); - return OK; - } - virtual const BoundNetLog& NetLog() const OVERRIDE { - return net_log_; - } - - virtual void SetSubresourceSpeculation() OVERRIDE {} - virtual void SetOmniboxSpeculation() OVERRIDE {} - virtual bool WasEverUsed() const OVERRIDE { return false; } - virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } - virtual bool WasNpnNegotiated() const OVERRIDE { - return false; - } - virtual NextProto GetNegotiatedProtocol() const OVERRIDE { - return kProtoUnknown; - } - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { - return false; - } - - // Socket implementation. - virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { - return ERR_FAILED; - } - - virtual int Write(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) OVERRIDE { - return ERR_FAILED; - } - virtual int SetReceiveBufferSize(int32 size) OVERRIDE { return OK; } - virtual int SetSendBufferSize(int32 size) OVERRIDE { return OK; } - - private: - void DoCallback(const CompletionCallback& callback) { - if (should_stall_) - return; - - if (should_connect_) { - is_connected_ = true; - callback.Run(OK); - } else { - is_connected_ = false; - callback.Run(ERR_CONNECTION_FAILED); - } - } - - bool should_connect_; - bool should_stall_; - base::TimeDelta delay_; - bool is_connected_; - const AddressList addrlist_; - BoundNetLog net_log_; - - base::WeakPtrFactory<MockPendingClientSocket> weak_factory_; - - DISALLOW_COPY_AND_ASSIGN(MockPendingClientSocket); -}; - -class MockClientSocketFactory : public ClientSocketFactory { - public: - enum ClientSocketType { - MOCK_CLIENT_SOCKET, - MOCK_FAILING_CLIENT_SOCKET, - MOCK_PENDING_CLIENT_SOCKET, - MOCK_PENDING_FAILING_CLIENT_SOCKET, - // A delayed socket will pause before connecting through the message loop. - MOCK_DELAYED_CLIENT_SOCKET, - // A stalled socket that never connects at all. - MOCK_STALLED_CLIENT_SOCKET, - }; - - explicit MockClientSocketFactory(NetLog* net_log) - : net_log_(net_log), allocation_count_(0), - client_socket_type_(MOCK_CLIENT_SOCKET), client_socket_types_(NULL), - client_socket_index_(0), client_socket_index_max_(0), - delay_(base::TimeDelta::FromMilliseconds( - ClientSocketPool::kMaxConnectRetryIntervalMs)) {} - - virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( - DatagramSocket::BindType bind_type, - const RandIntCallback& rand_int_cb, - NetLog* net_log, - const NetLog::Source& source) OVERRIDE { - NOTREACHED(); - return scoped_ptr<DatagramClientSocket>(); - } - - virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( - const AddressList& addresses, - NetLog* /* net_log */, - const NetLog::Source& /* source */) OVERRIDE { - allocation_count_++; - - ClientSocketType type = client_socket_type_; - if (client_socket_types_ && - client_socket_index_ < client_socket_index_max_) { - type = client_socket_types_[client_socket_index_++]; - } - - switch (type) { - case MOCK_CLIENT_SOCKET: - return scoped_ptr<StreamSocket>( - new MockClientSocket(addresses, net_log_)); - case MOCK_FAILING_CLIENT_SOCKET: - return scoped_ptr<StreamSocket>( - new MockFailingClientSocket(addresses, net_log_)); - case MOCK_PENDING_CLIENT_SOCKET: - return scoped_ptr<StreamSocket>( - new MockPendingClientSocket( - addresses, true, false, base::TimeDelta(), net_log_)); - case MOCK_PENDING_FAILING_CLIENT_SOCKET: - return scoped_ptr<StreamSocket>( - new MockPendingClientSocket( - addresses, false, false, base::TimeDelta(), net_log_)); - case MOCK_DELAYED_CLIENT_SOCKET: - return scoped_ptr<StreamSocket>( - new MockPendingClientSocket( - addresses, true, false, delay_, net_log_)); - case MOCK_STALLED_CLIENT_SOCKET: - return scoped_ptr<StreamSocket>( - new MockPendingClientSocket( - addresses, true, true, base::TimeDelta(), net_log_)); - default: - NOTREACHED(); - return scoped_ptr<StreamSocket>( - new MockClientSocket(addresses, net_log_)); - } - } - - virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( - scoped_ptr<ClientSocketHandle> transport_socket, - const HostPortPair& host_and_port, - const SSLConfig& ssl_config, - const SSLClientSocketContext& context) OVERRIDE { - NOTIMPLEMENTED(); - return scoped_ptr<SSLClientSocket>(); - } - - virtual void ClearSSLSessionCache() OVERRIDE { - NOTIMPLEMENTED(); - } - - int allocation_count() const { return allocation_count_; } - - // Set the default ClientSocketType. - void set_client_socket_type(ClientSocketType type) { - client_socket_type_ = type; - } - - // Set a list of ClientSocketTypes to be used. - void set_client_socket_types(ClientSocketType* type_list, int num_types) { - DCHECK_GT(num_types, 0); - client_socket_types_ = type_list; - client_socket_index_ = 0; - client_socket_index_max_ = num_types; - } - - void set_delay(base::TimeDelta delay) { delay_ = delay; } - - private: - NetLog* net_log_; - int allocation_count_; - ClientSocketType client_socket_type_; - ClientSocketType* client_socket_types_; - int client_socket_index_; - int client_socket_index_max_; - base::TimeDelta delay_; - - DISALLOW_COPY_AND_ASSIGN(MockClientSocketFactory); -}; +const RequestPriority kDefaultPriority = LOW; class TransportClientSocketPoolTest : public testing::Test { protected: @@ -447,9 +40,12 @@ class TransportClientSocketPoolTest : public testing::Test { : connect_backup_jobs_enabled_( ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(true)), params_( - new TransportSocketParams(HostPortPair("www.google.com", 80), - false, false, - OnHostResolutionCallback())), + new TransportSocketParams( + HostPortPair("www.google.com", 80), + false, + false, + OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), histograms_(new ClientSocketPoolHistograms("TCPUnitTest")), host_resolver_(new MockHostResolver), client_socket_factory_(&net_log_), @@ -461,15 +57,22 @@ class TransportClientSocketPoolTest : public testing::Test { NULL) { } - virtual ~TransportClientSocketPoolTest() { + ~TransportClientSocketPoolTest() override { internal::ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled( connect_backup_jobs_enabled_); } + scoped_refptr<TransportSocketParams> CreateParamsForTCPFastOpen() { + return new TransportSocketParams(HostPortPair("www.google.com", 80), + false, false, OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DESIRED); + } + int StartRequest(const std::string& group_name, RequestPriority priority) { scoped_refptr<TransportSocketParams> params(new TransportSocketParams( HostPortPair("www.google.com", 80), false, false, - OnHostResolutionCallback())); + OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); return test_base_.StartRequestUsingPool( &pool_, group_name, priority, params); } @@ -494,10 +97,11 @@ class TransportClientSocketPoolTest : public testing::Test { scoped_refptr<TransportSocketParams> params_; scoped_ptr<ClientSocketPoolHistograms> histograms_; scoped_ptr<MockHostResolver> host_resolver_; - MockClientSocketFactory client_socket_factory_; + MockTransportClientSocketFactory client_socket_factory_; TransportClientSocketPool pool_; ClientSocketPoolTest test_base_; + private: DISALLOW_COPY_AND_ASSIGN(TransportClientSocketPoolTest); }; @@ -607,8 +211,8 @@ TEST_F(TransportClientSocketPoolTest, InitHostResolutionFailure) { ClientSocketHandle handle; HostPortPair host_port_pair("unresolvable.host.name", 80); scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( - host_port_pair, false, false, - OnHostResolutionCallback())); + host_port_pair, false, false, OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", dest, kDefaultPriority, callback.callback(), &pool_, BoundNetLog())); @@ -616,8 +220,8 @@ TEST_F(TransportClientSocketPoolTest, InitHostResolutionFailure) { } TEST_F(TransportClientSocketPoolTest, InitConnectionFailure) { - client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET); + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET); TestCompletionCallback callback; ClientSocketHandle handle; EXPECT_EQ(ERR_IO_PENDING, @@ -760,8 +364,8 @@ TEST_F(TransportClientSocketPoolTest, TwoRequestsCancelOne) { } TEST_F(TransportClientSocketPoolTest, ConnectCancelConnect) { - client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); ClientSocketHandle handle; TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, @@ -861,7 +465,7 @@ class RequestSocketCallback : public TestCompletionCallbackBase { base::Unretained(this))) { } - virtual ~RequestSocketCallback() {} + ~RequestSocketCallback() override {} const CompletionCallback& callback() const { return callback_; } @@ -883,7 +487,8 @@ class RequestSocketCallback : public TestCompletionCallbackBase { within_callback_ = true; scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( HostPortPair("www.google.com", 80), false, false, - OnHostResolutionCallback())); + OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); int rv = handle_->Init("a", dest, LOWEST, callback(), pool_, BoundNetLog()); EXPECT_EQ(OK, rv); @@ -903,7 +508,8 @@ TEST_F(TransportClientSocketPoolTest, RequestTwice) { RequestSocketCallback callback(&handle, &pool_); scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( HostPortPair("www.google.com", 80), false, false, - OnHostResolutionCallback())); + OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); int rv = handle.Init("a", dest, LOWEST, callback.callback(), &pool_, BoundNetLog()); ASSERT_EQ(ERR_IO_PENDING, rv); @@ -920,8 +526,8 @@ TEST_F(TransportClientSocketPoolTest, RequestTwice) { // Make sure that pending requests get serviced after active requests get // cancelled. TEST_F(TransportClientSocketPoolTest, CancelActiveRequestWithPendingRequests) { - client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); // Queue up all the requests EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); @@ -950,8 +556,8 @@ TEST_F(TransportClientSocketPoolTest, CancelActiveRequestWithPendingRequests) { // Make sure that pending requests get serviced after active requests fail. TEST_F(TransportClientSocketPoolTest, FailingActiveRequestWithPendingRequests) { - client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET); + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET); const int kNumRequests = 2 * kMaxSocketsPerGroup + 1; ASSERT_LE(kNumRequests, kMaxSockets); // Otherwise the test will hang. @@ -1022,24 +628,24 @@ TEST_F(TransportClientSocketPoolTest, ResetIdleSocketsOnIPAddressChange) { TEST_F(TransportClientSocketPoolTest, BackupSocketConnect) { // Case 1 tests the first socket stalling, and the backup connecting. - MockClientSocketFactory::ClientSocketType case1_types[] = { + MockTransportClientSocketFactory::ClientSocketType case1_types[] = { // The first socket will not connect. - MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, // The second socket will connect more quickly. - MockClientSocketFactory::MOCK_CLIENT_SOCKET + MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET }; // Case 2 tests the first socket being slow, so that we start the // second connect, but the second connect stalls, and we still // complete the first. - MockClientSocketFactory::ClientSocketType case2_types[] = { + MockTransportClientSocketFactory::ClientSocketType case2_types[] = { // The first socket will connect, although delayed. - MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, // The second socket will not connect. - MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET }; - MockClientSocketFactory::ClientSocketType* cases[2] = { + MockTransportClientSocketFactory::ClientSocketType* cases[2] = { case1_types, case2_types }; @@ -1083,8 +689,8 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketConnect) { // Test the case where a socket took long enough to start the creation // of the backup socket, but then we cancelled the request after that. TEST_F(TransportClientSocketPoolTest, BackupSocketCancel) { - client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET); + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET); enum { CANCEL_BEFORE_WAIT, CANCEL_AFTER_WAIT }; @@ -1126,11 +732,11 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketCancel) { // of the backup socket and never completes, and then the backup // connection fails. TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterStall) { - MockClientSocketFactory::ClientSocketType case_types[] = { + MockTransportClientSocketFactory::ClientSocketType case_types[] = { // The first socket will not connect. - MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, // The second socket will fail immediately. - MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET + MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET }; client_socket_factory_.set_client_socket_types(case_types, 2); @@ -1173,11 +779,11 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterStall) { // of the backup socket and eventually completes, but the backup socket // fails. TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterDelay) { - MockClientSocketFactory::ClientSocketType case_types[] = { + MockTransportClientSocketFactory::ClientSocketType case_types[] = { // The first socket will connect, although delayed. - MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, // The second socket will not connect. - MockClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET + MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET }; client_socket_factory_.set_client_socket_types(case_types, 2); @@ -1228,11 +834,11 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv4FinishesFirst) { &client_socket_factory_, NULL); - MockClientSocketFactory::ClientSocketType case_types[] = { + MockTransportClientSocketFactory::ClientSocketType case_types[] = { // This is the IPv6 socket. - MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, // This is the IPv4 socket. - MockClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET }; client_socket_factory_.set_client_socket_types(case_types, 2); @@ -1271,16 +877,16 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv6FinishesFirst) { &client_socket_factory_, NULL); - MockClientSocketFactory::ClientSocketType case_types[] = { + MockTransportClientSocketFactory::ClientSocketType case_types[] = { // This is the IPv6 socket. - MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, // This is the IPv4 socket. - MockClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET }; client_socket_factory_.set_client_socket_types(case_types, 2); client_socket_factory_.set_delay(base::TimeDelta::FromMilliseconds( - TransportConnectJob::kIPv6FallbackTimerInMs + 50)); + TransportConnectJobHelper::kIPv6FallbackTimerInMs + 50)); // Resolve an AddressList with a IPv6 address first and then a IPv4 address. host_resolver_->rules() @@ -1313,8 +919,8 @@ TEST_F(TransportClientSocketPoolTest, IPv6NoIPv4AddressesToFallbackTo) { &client_socket_factory_, NULL); - client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); // Resolve an AddressList with only IPv6 addresses. host_resolver_->rules() @@ -1347,8 +953,8 @@ TEST_F(TransportClientSocketPoolTest, IPv4HasNoFallback) { &client_socket_factory_, NULL); - client_socket_factory_.set_client_socket_type( - MockClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); // Resolve an AddressList with only IPv4 addresses. host_resolver_->rules()->AddIPLiteralRule("*", "1.1.1.1", std::string()); @@ -1370,6 +976,137 @@ TEST_F(TransportClientSocketPoolTest, IPv4HasNoFallback) { EXPECT_EQ(1, client_socket_factory_.allocation_count()); } +// Test that if TCP FastOpen is enabled, it is set on the socket +// when we have only an IPv4 address. +TEST_F(TransportClientSocketPoolTest, TCPFastOpenOnIPv4WithNoFallback) { + // Create a pool without backup jobs. + ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); + TransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); + // Resolve an AddressList with only IPv4 addresses. + host_resolver_->rules()->AddIPLiteralRule("*", "1.1.1.1", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + // Enable TCP FastOpen in TransportSocketParams. + scoped_refptr<TransportSocketParams> params = CreateParamsForTCPFastOpen(); + handle.Init("a", params, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.socket()->UsingTCPFastOpen()); +} + +// Test that if TCP FastOpen is enabled, it is set on the socket +// when we have only IPv6 addresses. +TEST_F(TransportClientSocketPoolTest, TCPFastOpenOnIPv6WithNoFallback) { + // Create a pool without backup jobs. + ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); + TransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); + // Resolve an AddressList with only IPv6 addresses. + host_resolver_->rules() + ->AddIPLiteralRule("*", "2:abcd::3:4:ff,3:abcd::3:4:ff", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + // Enable TCP FastOpen in TransportSocketParams. + scoped_refptr<TransportSocketParams> params = CreateParamsForTCPFastOpen(); + handle.Init("a", params, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.socket()->UsingTCPFastOpen()); +} + +// Test that if TCP FastOpen is enabled, it does not do anything when there +// is a IPv6 address with fallback to an IPv4 address. This test tests the case +// when the IPv6 connect fails and the IPv4 one succeeds. +TEST_F(TransportClientSocketPoolTest, + NoTCPFastOpenOnIPv6FailureWithIPv4Fallback) { + // Create a pool without backup jobs. + ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); + TransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + MockTransportClientSocketFactory::ClientSocketType case_types[] = { + // This is the IPv6 socket. + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, + // This is the IPv4 socket. + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET + }; + client_socket_factory_.set_client_socket_types(case_types, 2); + // Resolve an AddressList with a IPv6 address first and then a IPv4 address. + host_resolver_->rules() + ->AddIPLiteralRule("*", "2:abcd::3:4:ff,2.2.2.2", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + // Enable TCP FastOpen in TransportSocketParams. + scoped_refptr<TransportSocketParams> params = CreateParamsForTCPFastOpen(); + handle.Init("a", params, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(OK, callback.WaitForResult()); + // Verify that the socket used is connected to the fallback IPv4 address. + IPEndPoint endpoint; + handle.socket()->GetLocalAddress(&endpoint); + EXPECT_EQ(kIPv4AddressSize, endpoint.address().size()); + EXPECT_EQ(2, client_socket_factory_.allocation_count()); + // Verify that TCP FastOpen was not turned on for the socket. + EXPECT_FALSE(handle.socket()->UsingTCPFastOpen()); +} + +// Test that if TCP FastOpen is enabled, it does not do anything when there +// is a IPv6 address with fallback to an IPv4 address. This test tests the case +// when the IPv6 connect succeeds. +TEST_F(TransportClientSocketPoolTest, + NoTCPFastOpenOnIPv6SuccessWithIPv4Fallback) { + // Create a pool without backup jobs. + ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); + TransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + MockTransportClientSocketFactory::ClientSocketType case_types[] = { + // This is the IPv6 socket. + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET, + // This is the IPv4 socket. + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET + }; + client_socket_factory_.set_client_socket_types(case_types, 2); + // Resolve an AddressList with a IPv6 address first and then a IPv4 address. + host_resolver_->rules() + ->AddIPLiteralRule("*", "2:abcd::3:4:ff,2.2.2.2", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + // Enable TCP FastOpen in TransportSocketParams. + scoped_refptr<TransportSocketParams> params = CreateParamsForTCPFastOpen(); + handle.Init("a", params, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(OK, callback.WaitForResult()); + // Verify that the socket used is connected to the IPv6 address. + IPEndPoint endpoint; + handle.socket()->GetLocalAddress(&endpoint); + EXPECT_EQ(kIPv6AddressSize, endpoint.address().size()); + EXPECT_EQ(1, client_socket_factory_.allocation_count()); + // Verify that TCP FastOpen was not turned on for the socket. + EXPECT_FALSE(handle.socket()->UsingTCPFastOpen()); +} + } // namespace } // namespace net diff --git a/chromium/net/socket/transport_client_socket_unittest.cc b/chromium/net/socket/transport_client_socket_unittest.cc index 5548b27b995..d01cbad6dc8 100644 --- a/chromium/net/socket/transport_client_socket_unittest.cc +++ b/chromium/net/socket/transport_client_socket_unittest.cc @@ -47,22 +47,22 @@ class TransportClientSocketTest } // Implement StreamListenSocket::Delegate methods - virtual void DidAccept(StreamListenSocket* server, - scoped_ptr<StreamListenSocket> connection) OVERRIDE { + void DidAccept(StreamListenSocket* server, + scoped_ptr<StreamListenSocket> connection) override { connected_sock_.reset( static_cast<TCPListenSocket*>(connection.release())); } - virtual void DidRead(StreamListenSocket*, const char* str, int len) OVERRIDE { + void DidRead(StreamListenSocket*, const char* str, int len) override { // TODO(dkegel): this might not be long enough to tickle some bugs. connected_sock_->Send(kServerReply, arraysize(kServerReply) - 1, false /* Don't append line feed */); if (close_server_socket_on_next_send_) CloseServerSocket(); } - virtual void DidClose(StreamListenSocket* sock) OVERRIDE {} + void DidClose(StreamListenSocket* sock) override {} // Testcase hooks - virtual void SetUp(); + void SetUp() override; void CloseServerSocket() { // delete the connected_sock_, which will close it. diff --git a/chromium/net/socket/unix_domain_client_socket_posix.cc b/chromium/net/socket/unix_domain_client_socket_posix.cc new file mode 100644 index 00000000000..5adbca9979e --- /dev/null +++ b/chromium/net/socket/unix_domain_client_socket_posix.cc @@ -0,0 +1,171 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/unix_domain_client_socket_posix.h" + +#include <sys/socket.h> +#include <sys/un.h> + +#include "base/logging.h" +#include "base/posix/eintr_wrapper.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" +#include "net/socket/socket_libevent.h" + +namespace net { + +UnixDomainClientSocket::UnixDomainClientSocket(const std::string& socket_path, + bool use_abstract_namespace) + : socket_path_(socket_path), + use_abstract_namespace_(use_abstract_namespace) { +} + +UnixDomainClientSocket::UnixDomainClientSocket( + scoped_ptr<SocketLibevent> socket) + : use_abstract_namespace_(false), + socket_(socket.Pass()) { +} + +UnixDomainClientSocket::~UnixDomainClientSocket() { + Disconnect(); +} + +// static +bool UnixDomainClientSocket::FillAddress(const std::string& socket_path, + bool use_abstract_namespace, + SockaddrStorage* address) { + struct sockaddr_un* socket_addr = + reinterpret_cast<struct sockaddr_un*>(address->addr); + size_t path_max = address->addr_len - offsetof(struct sockaddr_un, sun_path); + // Non abstract namespace pathname should be null-terminated. Abstract + // namespace pathname must start with '\0'. So, the size is always greater + // than socket_path size by 1. + size_t path_size = socket_path.size() + 1; + if (path_size > path_max) + return false; + + memset(socket_addr, 0, address->addr_len); + socket_addr->sun_family = AF_UNIX; + address->addr_len = path_size + offsetof(struct sockaddr_un, sun_path); + if (!use_abstract_namespace) { + memcpy(socket_addr->sun_path, socket_path.c_str(), socket_path.size()); + return true; + } + +#if defined(OS_ANDROID) || defined(OS_LINUX) + // Convert the path given into abstract socket name. It must start with + // the '\0' character, so we are adding it. |addr_len| must specify the + // length of the structure exactly, as potentially the socket name may + // have '\0' characters embedded (although we don't support this). + // Note that addr.sun_path is already zero initialized. + memcpy(socket_addr->sun_path + 1, socket_path.c_str(), socket_path.size()); + return true; +#else + return false; +#endif +} + +int UnixDomainClientSocket::Connect(const CompletionCallback& callback) { + DCHECK(!socket_); + + if (socket_path_.empty()) + return ERR_ADDRESS_INVALID; + + SockaddrStorage address; + if (!FillAddress(socket_path_, use_abstract_namespace_, &address)) + return ERR_ADDRESS_INVALID; + + socket_.reset(new SocketLibevent); + int rv = socket_->Open(AF_UNIX); + DCHECK_NE(ERR_IO_PENDING, rv); + if (rv != OK) + return rv; + + return socket_->Connect(address, callback); +} + +void UnixDomainClientSocket::Disconnect() { + socket_.reset(); +} + +bool UnixDomainClientSocket::IsConnected() const { + return socket_ && socket_->IsConnected(); +} + +bool UnixDomainClientSocket::IsConnectedAndIdle() const { + return socket_ && socket_->IsConnectedAndIdle(); +} + +int UnixDomainClientSocket::GetPeerAddress(IPEndPoint* address) const { + NOTIMPLEMENTED(); + return ERR_NOT_IMPLEMENTED; +} + +int UnixDomainClientSocket::GetLocalAddress(IPEndPoint* address) const { + NOTIMPLEMENTED(); + return ERR_NOT_IMPLEMENTED; +} + +const BoundNetLog& UnixDomainClientSocket::NetLog() const { + return net_log_; +} + +void UnixDomainClientSocket::SetSubresourceSpeculation() { +} + +void UnixDomainClientSocket::SetOmniboxSpeculation() { +} + +bool UnixDomainClientSocket::WasEverUsed() const { + return true; // We don't care. +} + +bool UnixDomainClientSocket::UsingTCPFastOpen() const { + return false; +} + +bool UnixDomainClientSocket::WasNpnNegotiated() const { + return false; +} + +NextProto UnixDomainClientSocket::GetNegotiatedProtocol() const { + return kProtoUnknown; +} + +bool UnixDomainClientSocket::GetSSLInfo(SSLInfo* ssl_info) { + return false; +} + +int UnixDomainClientSocket::Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(socket_); + return socket_->Read(buf, buf_len, callback); +} + +int UnixDomainClientSocket::Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + DCHECK(socket_); + return socket_->Write(buf, buf_len, callback); +} + +int UnixDomainClientSocket::SetReceiveBufferSize(int32 size) { + NOTIMPLEMENTED(); + return ERR_NOT_IMPLEMENTED; +} + +int UnixDomainClientSocket::SetSendBufferSize(int32 size) { + NOTIMPLEMENTED(); + return ERR_NOT_IMPLEMENTED; +} + +SocketDescriptor UnixDomainClientSocket::ReleaseConnectedSocket() { + DCHECK(socket_); + DCHECK(socket_->IsConnected()); + + SocketDescriptor socket_fd = socket_->ReleaseConnectedSocket(); + socket_.reset(); + return socket_fd; +} + +} // namespace net diff --git a/chromium/net/socket/unix_domain_client_socket_posix.h b/chromium/net/socket/unix_domain_client_socket_posix.h new file mode 100644 index 00000000000..2a8bdb625c9 --- /dev/null +++ b/chromium/net/socket/unix_domain_client_socket_posix.h @@ -0,0 +1,87 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_UNIX_DOMAIN_CLIENT_SOCKET_POSIX_H_ +#define NET_SOCKET_UNIX_DOMAIN_CLIENT_SOCKET_POSIX_H_ + +#include <string> + +#include "base/basictypes.h" +#include "base/macros.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/completion_callback.h" +#include "net/base/net_export.h" +#include "net/base/net_log.h" +#include "net/socket/socket_descriptor.h" +#include "net/socket/stream_socket.h" + +namespace net { + +class SocketLibevent; +struct SockaddrStorage; + +// A client socket that uses unix domain socket as the transport layer. +class NET_EXPORT UnixDomainClientSocket : public StreamSocket { + public: + // Builds a client socket with |socket_path|. The caller should call Connect() + // to connect to a server socket. + UnixDomainClientSocket(const std::string& socket_path, + bool use_abstract_namespace); + // Builds a client socket with socket libevent which is already connected. + // UnixDomainServerSocket uses this after it accepts a connection. + explicit UnixDomainClientSocket(scoped_ptr<SocketLibevent> socket); + + ~UnixDomainClientSocket() override; + + // Fills |address| with |socket_path| and its length. For Android or Linux + // platform, this supports abstract namespaces. + static bool FillAddress(const std::string& socket_path, + bool use_abstract_namespace, + SockaddrStorage* address); + + // StreamSocket implementation. + int Connect(const CompletionCallback& callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + const BoundNetLog& NetLog() const override; + void SetSubresourceSpeculation() override; + void SetOmniboxSpeculation() override; + bool WasEverUsed() const override; + bool UsingTCPFastOpen() const override; + bool WasNpnNegotiated() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; + + // Socket implementation. + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int SetReceiveBufferSize(int32 size) override; + int SetSendBufferSize(int32 size) override; + + // Releases ownership of underlying SocketDescriptor to caller. + // Internal state is reset so that this object can be used again. + // Socket must be connected in order to release it. + SocketDescriptor ReleaseConnectedSocket(); + + private: + const std::string socket_path_; + const bool use_abstract_namespace_; + scoped_ptr<SocketLibevent> socket_; + // This net log is just to comply StreamSocket::NetLog(). It throws away + // everything. + BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(UnixDomainClientSocket); +}; + +} // namespace net + +#endif // NET_SOCKET_UNIX_DOMAIN_CLIENT_SOCKET_POSIX_H_ diff --git a/chromium/net/socket/unix_domain_client_socket_posix_unittest.cc b/chromium/net/socket/unix_domain_client_socket_posix_unittest.cc new file mode 100644 index 00000000000..651cd72dd5b --- /dev/null +++ b/chromium/net/socket/unix_domain_client_socket_posix_unittest.cc @@ -0,0 +1,446 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/unix_domain_client_socket_posix.h" + +#include <unistd.h> + +#include "base/bind.h" +#include "base/files/file_path.h" +#include "base/files/scoped_temp_dir.h" +#include "base/memory/scoped_ptr.h" +#include "base/posix/eintr_wrapper.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/test_completion_callback.h" +#include "net/socket/socket_libevent.h" +#include "net/socket/unix_domain_server_socket_posix.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { +namespace { + +const char kSocketFilename[] = "socket_for_testing"; + +bool UserCanConnectCallback( + bool allow_user, const UnixDomainServerSocket::Credentials& credentials) { + // Here peers are running in same process. +#if defined(OS_LINUX) || defined(OS_ANDROID) + EXPECT_EQ(getpid(), credentials.process_id); +#endif + EXPECT_EQ(getuid(), credentials.user_id); + EXPECT_EQ(getgid(), credentials.group_id); + return allow_user; +} + +UnixDomainServerSocket::AuthCallback CreateAuthCallback(bool allow_user) { + return base::Bind(&UserCanConnectCallback, allow_user); +} + +// Connects socket synchronously. +int ConnectSynchronously(StreamSocket* socket) { + TestCompletionCallback connect_callback; + int rv = socket->Connect(connect_callback.callback()); + if (rv == ERR_IO_PENDING) + rv = connect_callback.WaitForResult(); + return rv; +} + +// Reads data from |socket| until it fills |buf| at least up to |min_data_len|. +// Returns length of data read, or a net error. +int ReadSynchronously(StreamSocket* socket, + IOBuffer* buf, + int buf_len, + int min_data_len) { + DCHECK_LE(min_data_len, buf_len); + scoped_refptr<DrainableIOBuffer> read_buf( + new DrainableIOBuffer(buf, buf_len)); + TestCompletionCallback read_callback; + // Iterate reading several times (but not infinite) until it reads at least + // |min_data_len| bytes into |buf|. + for (int retry_count = 10; + retry_count > 0 && (read_buf->BytesConsumed() < min_data_len || + // Try at least once when min_data_len == 0. + min_data_len == 0); + --retry_count) { + int rv = socket->Read( + read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); + EXPECT_GE(read_buf->BytesRemaining(), rv); + if (rv == ERR_IO_PENDING) { + // If |min_data_len| is 0, returns ERR_IO_PENDING to distinguish the case + // when some data has been read. + if (min_data_len == 0) { + // No data has been read because of for-loop condition. + DCHECK_EQ(0, read_buf->BytesConsumed()); + return ERR_IO_PENDING; + } + rv = read_callback.WaitForResult(); + } + EXPECT_NE(ERR_IO_PENDING, rv); + if (rv < 0) + return rv; + read_buf->DidConsume(rv); + } + EXPECT_LE(0, read_buf->BytesRemaining()); + return read_buf->BytesConsumed(); +} + +// Writes data to |socket| until it completes writing |buf| up to |buf_len|. +// Returns length of data written, or a net error. +int WriteSynchronously(StreamSocket* socket, + IOBuffer* buf, + int buf_len) { + scoped_refptr<DrainableIOBuffer> write_buf( + new DrainableIOBuffer(buf, buf_len)); + TestCompletionCallback write_callback; + // Iterate writing several times (but not infinite) until it writes buf fully. + for (int retry_count = 10; + retry_count > 0 && write_buf->BytesRemaining() > 0; + --retry_count) { + int rv = socket->Write(write_buf.get(), + write_buf->BytesRemaining(), + write_callback.callback()); + EXPECT_GE(write_buf->BytesRemaining(), rv); + if (rv == ERR_IO_PENDING) + rv = write_callback.WaitForResult(); + EXPECT_NE(ERR_IO_PENDING, rv); + if (rv < 0) + return rv; + write_buf->DidConsume(rv); + } + EXPECT_LE(0, write_buf->BytesRemaining()); + return write_buf->BytesConsumed(); +} + +class UnixDomainClientSocketTest : public testing::Test { + protected: + UnixDomainClientSocketTest() { + EXPECT_TRUE(temp_dir_.CreateUniqueTempDir()); + socket_path_ = temp_dir_.path().Append(kSocketFilename).value(); + } + + base::ScopedTempDir temp_dir_; + std::string socket_path_; +}; + +TEST_F(UnixDomainClientSocketTest, Connect) { + const bool kUseAbstractNamespace = false; + + UnixDomainServerSocket server_socket(CreateAuthCallback(true), + kUseAbstractNamespace); + EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + + scoped_ptr<StreamSocket> accepted_socket; + TestCompletionCallback accept_callback; + EXPECT_EQ(ERR_IO_PENDING, + server_socket.Accept(&accepted_socket, accept_callback.callback())); + EXPECT_FALSE(accepted_socket); + + UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace); + EXPECT_FALSE(client_socket.IsConnected()); + + EXPECT_EQ(OK, ConnectSynchronously(&client_socket)); + EXPECT_TRUE(client_socket.IsConnected()); + // Server has not yet been notified of the connection. + EXPECT_FALSE(accepted_socket); + + EXPECT_EQ(OK, accept_callback.WaitForResult()); + EXPECT_TRUE(accepted_socket); + EXPECT_TRUE(accepted_socket->IsConnected()); +} + +TEST_F(UnixDomainClientSocketTest, ConnectWithSocketDescriptor) { + const bool kUseAbstractNamespace = false; + + UnixDomainServerSocket server_socket(CreateAuthCallback(true), + kUseAbstractNamespace); + EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + + SocketDescriptor accepted_socket_fd = kInvalidSocket; + TestCompletionCallback accept_callback; + EXPECT_EQ(ERR_IO_PENDING, + server_socket.AcceptSocketDescriptor(&accepted_socket_fd, + accept_callback.callback())); + EXPECT_EQ(kInvalidSocket, accepted_socket_fd); + + UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace); + EXPECT_FALSE(client_socket.IsConnected()); + + EXPECT_EQ(OK, ConnectSynchronously(&client_socket)); + EXPECT_TRUE(client_socket.IsConnected()); + // Server has not yet been notified of the connection. + EXPECT_EQ(kInvalidSocket, accepted_socket_fd); + + EXPECT_EQ(OK, accept_callback.WaitForResult()); + EXPECT_NE(kInvalidSocket, accepted_socket_fd); + + SocketDescriptor client_socket_fd = client_socket.ReleaseConnectedSocket(); + EXPECT_NE(kInvalidSocket, client_socket_fd); + + // Now, re-wrap client_socket_fd in a UnixDomainClientSocket and try a read + // to be sure it hasn't gotten accidentally closed. + SockaddrStorage addr; + ASSERT_TRUE(UnixDomainClientSocket::FillAddress(socket_path_, false, &addr)); + scoped_ptr<SocketLibevent> adopter(new SocketLibevent); + adopter->AdoptConnectedSocket(client_socket_fd, addr); + UnixDomainClientSocket rewrapped_socket(adopter.Pass()); + EXPECT_TRUE(rewrapped_socket.IsConnected()); + + // Try to read data. + const int kReadDataSize = 10; + scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadDataSize)); + TestCompletionCallback read_callback; + EXPECT_EQ(ERR_IO_PENDING, + rewrapped_socket.Read( + read_buffer.get(), kReadDataSize, read_callback.callback())); + + EXPECT_EQ(0, IGNORE_EINTR(close(accepted_socket_fd))); +} + +TEST_F(UnixDomainClientSocketTest, ConnectWithAbstractNamespace) { + const bool kUseAbstractNamespace = true; + + UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace); + EXPECT_FALSE(client_socket.IsConnected()); + +#if defined(OS_ANDROID) || defined(OS_LINUX) + UnixDomainServerSocket server_socket(CreateAuthCallback(true), + kUseAbstractNamespace); + EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + + scoped_ptr<StreamSocket> accepted_socket; + TestCompletionCallback accept_callback; + EXPECT_EQ(ERR_IO_PENDING, + server_socket.Accept(&accepted_socket, accept_callback.callback())); + EXPECT_FALSE(accepted_socket); + + EXPECT_EQ(OK, ConnectSynchronously(&client_socket)); + EXPECT_TRUE(client_socket.IsConnected()); + // Server has not yet beend notified of the connection. + EXPECT_FALSE(accepted_socket); + + EXPECT_EQ(OK, accept_callback.WaitForResult()); + EXPECT_TRUE(accepted_socket); + EXPECT_TRUE(accepted_socket->IsConnected()); +#else + EXPECT_EQ(ERR_ADDRESS_INVALID, ConnectSynchronously(&client_socket)); +#endif +} + +TEST_F(UnixDomainClientSocketTest, ConnectToNonExistentSocket) { + const bool kUseAbstractNamespace = false; + + UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace); + EXPECT_FALSE(client_socket.IsConnected()); + EXPECT_EQ(ERR_FILE_NOT_FOUND, ConnectSynchronously(&client_socket)); +} + +TEST_F(UnixDomainClientSocketTest, + ConnectToNonExistentSocketWithAbstractNamespace) { + const bool kUseAbstractNamespace = true; + + UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace); + EXPECT_FALSE(client_socket.IsConnected()); + + TestCompletionCallback connect_callback; +#if defined(OS_ANDROID) || defined(OS_LINUX) + EXPECT_EQ(ERR_CONNECTION_REFUSED, ConnectSynchronously(&client_socket)); +#else + EXPECT_EQ(ERR_ADDRESS_INVALID, ConnectSynchronously(&client_socket)); +#endif +} + +TEST_F(UnixDomainClientSocketTest, DisconnectFromClient) { + UnixDomainServerSocket server_socket(CreateAuthCallback(true), false); + EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + scoped_ptr<StreamSocket> accepted_socket; + TestCompletionCallback accept_callback; + EXPECT_EQ(ERR_IO_PENDING, + server_socket.Accept(&accepted_socket, accept_callback.callback())); + UnixDomainClientSocket client_socket(socket_path_, false); + EXPECT_EQ(OK, ConnectSynchronously(&client_socket)); + + EXPECT_EQ(OK, accept_callback.WaitForResult()); + EXPECT_TRUE(accepted_socket->IsConnected()); + EXPECT_TRUE(client_socket.IsConnected()); + + // Try to read data. + const int kReadDataSize = 10; + scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadDataSize)); + TestCompletionCallback read_callback; + EXPECT_EQ(ERR_IO_PENDING, + accepted_socket->Read( + read_buffer.get(), kReadDataSize, read_callback.callback())); + + // Disconnect from client side. + client_socket.Disconnect(); + EXPECT_FALSE(client_socket.IsConnected()); + EXPECT_FALSE(accepted_socket->IsConnected()); + + // Connection closed by peer. + EXPECT_EQ(0 /* EOF */, read_callback.WaitForResult()); + // Note that read callback won't be called when the connection is closed + // locally before the peer closes it. SocketLibevent just clears callbacks. +} + +TEST_F(UnixDomainClientSocketTest, DisconnectFromServer) { + UnixDomainServerSocket server_socket(CreateAuthCallback(true), false); + EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + scoped_ptr<StreamSocket> accepted_socket; + TestCompletionCallback accept_callback; + EXPECT_EQ(ERR_IO_PENDING, + server_socket.Accept(&accepted_socket, accept_callback.callback())); + UnixDomainClientSocket client_socket(socket_path_, false); + EXPECT_EQ(OK, ConnectSynchronously(&client_socket)); + + EXPECT_EQ(OK, accept_callback.WaitForResult()); + EXPECT_TRUE(accepted_socket->IsConnected()); + EXPECT_TRUE(client_socket.IsConnected()); + + // Try to read data. + const int kReadDataSize = 10; + scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadDataSize)); + TestCompletionCallback read_callback; + EXPECT_EQ(ERR_IO_PENDING, + client_socket.Read( + read_buffer.get(), kReadDataSize, read_callback.callback())); + + // Disconnect from server side. + accepted_socket->Disconnect(); + EXPECT_FALSE(accepted_socket->IsConnected()); + EXPECT_FALSE(client_socket.IsConnected()); + + // Connection closed by peer. + EXPECT_EQ(0 /* EOF */, read_callback.WaitForResult()); + // Note that read callback won't be called when the connection is closed + // locally before the peer closes it. SocketLibevent just clears callbacks. +} + +TEST_F(UnixDomainClientSocketTest, ReadAfterWrite) { + UnixDomainServerSocket server_socket(CreateAuthCallback(true), false); + EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + scoped_ptr<StreamSocket> accepted_socket; + TestCompletionCallback accept_callback; + EXPECT_EQ(ERR_IO_PENDING, + server_socket.Accept(&accepted_socket, accept_callback.callback())); + UnixDomainClientSocket client_socket(socket_path_, false); + EXPECT_EQ(OK, ConnectSynchronously(&client_socket)); + + EXPECT_EQ(OK, accept_callback.WaitForResult()); + EXPECT_TRUE(accepted_socket->IsConnected()); + EXPECT_TRUE(client_socket.IsConnected()); + + // Send data from client to server. + const int kWriteDataSize = 10; + scoped_refptr<IOBuffer> write_buffer( + new StringIOBuffer(std::string(kWriteDataSize, 'd'))); + EXPECT_EQ( + kWriteDataSize, + WriteSynchronously(&client_socket, write_buffer.get(), kWriteDataSize)); + + // The buffer is bigger than write data size. + const int kReadBufferSize = kWriteDataSize * 2; + scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadBufferSize)); + EXPECT_EQ(kWriteDataSize, + ReadSynchronously(accepted_socket.get(), + read_buffer.get(), + kReadBufferSize, + kWriteDataSize)); + EXPECT_EQ(std::string(write_buffer->data(), kWriteDataSize), + std::string(read_buffer->data(), kWriteDataSize)); + + // Send data from server and client. + EXPECT_EQ(kWriteDataSize, + WriteSynchronously( + accepted_socket.get(), write_buffer.get(), kWriteDataSize)); + + // Read multiple times. + const int kSmallReadBufferSize = kWriteDataSize / 3; + EXPECT_EQ(kSmallReadBufferSize, + ReadSynchronously(&client_socket, + read_buffer.get(), + kSmallReadBufferSize, + kSmallReadBufferSize)); + EXPECT_EQ(std::string(write_buffer->data(), kSmallReadBufferSize), + std::string(read_buffer->data(), kSmallReadBufferSize)); + + EXPECT_EQ(kWriteDataSize - kSmallReadBufferSize, + ReadSynchronously(&client_socket, + read_buffer.get(), + kReadBufferSize, + kWriteDataSize - kSmallReadBufferSize)); + EXPECT_EQ(std::string(write_buffer->data() + kSmallReadBufferSize, + kWriteDataSize - kSmallReadBufferSize), + std::string(read_buffer->data(), + kWriteDataSize - kSmallReadBufferSize)); + + // No more data. + EXPECT_EQ( + ERR_IO_PENDING, + ReadSynchronously(&client_socket, read_buffer.get(), kReadBufferSize, 0)); + + // Disconnect from server side after read-write. + accepted_socket->Disconnect(); + EXPECT_FALSE(accepted_socket->IsConnected()); + EXPECT_FALSE(client_socket.IsConnected()); +} + +TEST_F(UnixDomainClientSocketTest, ReadBeforeWrite) { + UnixDomainServerSocket server_socket(CreateAuthCallback(true), false); + EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + scoped_ptr<StreamSocket> accepted_socket; + TestCompletionCallback accept_callback; + EXPECT_EQ(ERR_IO_PENDING, + server_socket.Accept(&accepted_socket, accept_callback.callback())); + UnixDomainClientSocket client_socket(socket_path_, false); + EXPECT_EQ(OK, ConnectSynchronously(&client_socket)); + + EXPECT_EQ(OK, accept_callback.WaitForResult()); + EXPECT_TRUE(accepted_socket->IsConnected()); + EXPECT_TRUE(client_socket.IsConnected()); + + // Wait for data from client. + const int kWriteDataSize = 10; + const int kReadBufferSize = kWriteDataSize * 2; + const int kSmallReadBufferSize = kWriteDataSize / 3; + // Read smaller than write data size first. + scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadBufferSize)); + TestCompletionCallback read_callback; + EXPECT_EQ( + ERR_IO_PENDING, + accepted_socket->Read( + read_buffer.get(), kSmallReadBufferSize, read_callback.callback())); + + scoped_refptr<IOBuffer> write_buffer( + new StringIOBuffer(std::string(kWriteDataSize, 'd'))); + EXPECT_EQ( + kWriteDataSize, + WriteSynchronously(&client_socket, write_buffer.get(), kWriteDataSize)); + + // First read completed. + int rv = read_callback.WaitForResult(); + EXPECT_LT(0, rv); + EXPECT_LE(rv, kSmallReadBufferSize); + + // Read remaining data. + const int kExpectedRemainingDataSize = kWriteDataSize - rv; + EXPECT_LE(0, kExpectedRemainingDataSize); + EXPECT_EQ(kExpectedRemainingDataSize, + ReadSynchronously(accepted_socket.get(), + read_buffer.get(), + kReadBufferSize, + kExpectedRemainingDataSize)); + // No more data. + EXPECT_EQ(ERR_IO_PENDING, + ReadSynchronously( + accepted_socket.get(), read_buffer.get(), kReadBufferSize, 0)); + + // Disconnect from server side after read-write. + accepted_socket->Disconnect(); + EXPECT_FALSE(accepted_socket->IsConnected()); + EXPECT_FALSE(client_socket.IsConnected()); +} + +} // namespace +} // namespace net diff --git a/chromium/net/socket/unix_domain_listen_socket_posix.cc b/chromium/net/socket/unix_domain_listen_socket_posix.cc new file mode 100644 index 00000000000..3e46439c8b5 --- /dev/null +++ b/chromium/net/socket/unix_domain_listen_socket_posix.cc @@ -0,0 +1,167 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/unix_domain_listen_socket_posix.h" + +#include <errno.h> +#include <sys/socket.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <sys/un.h> +#include <unistd.h> + +#include <cstring> +#include <string> + +#include "base/bind.h" +#include "base/callback.h" +#include "base/posix/eintr_wrapper.h" +#include "base/threading/platform_thread.h" +#include "build/build_config.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" +#include "net/socket/socket_descriptor.h" +#include "net/socket/unix_domain_client_socket_posix.h" + +namespace net { +namespace deprecated { + +namespace { + +int CreateAndBind(const std::string& socket_path, + bool use_abstract_namespace, + SocketDescriptor* socket_fd) { + DCHECK(socket_fd); + + SockaddrStorage address; + if (!UnixDomainClientSocket::FillAddress(socket_path, + use_abstract_namespace, + &address)) { + return ERR_ADDRESS_INVALID; + } + + SocketDescriptor fd = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0); + if (fd == kInvalidSocket) + return errno ? MapSystemError(errno) : ERR_UNEXPECTED; + + if (bind(fd, address.addr, address.addr_len) < 0) { + int rv = MapSystemError(errno); + close(fd); + PLOG(ERROR) << "Could not bind unix domain socket to " << socket_path + << (use_abstract_namespace ? " (with abstract namespace)" : ""); + return rv; + } + + *socket_fd = fd; + return OK; +} + +} // namespace + +// static +scoped_ptr<UnixDomainListenSocket> +UnixDomainListenSocket::CreateAndListenInternal( + const std::string& path, + const std::string& fallback_path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback, + bool use_abstract_namespace) { + SocketDescriptor socket_fd = kInvalidSocket; + int rv = CreateAndBind(path, use_abstract_namespace, &socket_fd); + if (rv != OK && !fallback_path.empty()) + rv = CreateAndBind(fallback_path, use_abstract_namespace, &socket_fd); + if (rv != OK) + return scoped_ptr<UnixDomainListenSocket>(); + scoped_ptr<UnixDomainListenSocket> sock( + new UnixDomainListenSocket(socket_fd, del, auth_callback)); + sock->Listen(); + return sock.Pass(); +} + +// static +scoped_ptr<UnixDomainListenSocket> UnixDomainListenSocket::CreateAndListen( + const std::string& path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback) { + return CreateAndListenInternal(path, "", del, auth_callback, false); +} + +#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) +// static +scoped_ptr<UnixDomainListenSocket> +UnixDomainListenSocket::CreateAndListenWithAbstractNamespace( + const std::string& path, + const std::string& fallback_path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback) { + return + CreateAndListenInternal(path, fallback_path, del, auth_callback, true); +} +#endif + +UnixDomainListenSocket::UnixDomainListenSocket( + SocketDescriptor s, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback) + : StreamListenSocket(s, del), + auth_callback_(auth_callback) {} + +UnixDomainListenSocket::~UnixDomainListenSocket() {} + +void UnixDomainListenSocket::Accept() { + SocketDescriptor conn = StreamListenSocket::AcceptSocket(); + if (conn == kInvalidSocket) + return; + UnixDomainServerSocket::Credentials credentials; + if (!UnixDomainServerSocket::GetPeerCredentials(conn, &credentials) || + !auth_callback_.Run(credentials)) { + if (IGNORE_EINTR(close(conn)) < 0) + LOG(ERROR) << "close() error"; + return; + } + scoped_ptr<UnixDomainListenSocket> sock( + new UnixDomainListenSocket(conn, socket_delegate_, auth_callback_)); + // It's up to the delegate to AddRef if it wants to keep it around. + sock->WatchSocket(WAITING_READ); + socket_delegate_->DidAccept(this, sock.Pass()); +} + +UnixDomainListenSocketFactory::UnixDomainListenSocketFactory( + const std::string& path, + const UnixDomainListenSocket::AuthCallback& auth_callback) + : path_(path), + auth_callback_(auth_callback) {} + +UnixDomainListenSocketFactory::~UnixDomainListenSocketFactory() {} + +scoped_ptr<StreamListenSocket> UnixDomainListenSocketFactory::CreateAndListen( + StreamListenSocket::Delegate* delegate) const { + return UnixDomainListenSocket::CreateAndListen( + path_, delegate, auth_callback_).Pass(); +} + +#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) + +UnixDomainListenSocketWithAbstractNamespaceFactory:: +UnixDomainListenSocketWithAbstractNamespaceFactory( + const std::string& path, + const std::string& fallback_path, + const UnixDomainListenSocket::AuthCallback& auth_callback) + : UnixDomainListenSocketFactory(path, auth_callback), + fallback_path_(fallback_path) {} + +UnixDomainListenSocketWithAbstractNamespaceFactory:: +~UnixDomainListenSocketWithAbstractNamespaceFactory() {} + +scoped_ptr<StreamListenSocket> +UnixDomainListenSocketWithAbstractNamespaceFactory::CreateAndListen( + StreamListenSocket::Delegate* delegate) const { + return UnixDomainListenSocket::CreateAndListenWithAbstractNamespace( + path_, fallback_path_, delegate, auth_callback_); +} + +#endif + +} // namespace deprecated +} // namespace net diff --git a/chromium/net/socket/unix_domain_listen_socket_posix.h b/chromium/net/socket/unix_domain_listen_socket_posix.h new file mode 100644 index 00000000000..82ec342edaa --- /dev/null +++ b/chromium/net/socket/unix_domain_listen_socket_posix.h @@ -0,0 +1,122 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_UNIX_DOMAIN_LISTEN_SOCKET_POSIX_H_ +#define NET_SOCKET_UNIX_DOMAIN_LISTEN_SOCKET_POSIX_H_ + +#include <string> + +#include "base/basictypes.h" +#include "base/callback_forward.h" +#include "base/compiler_specific.h" +#include "base/macros.h" +#include "build/build_config.h" +#include "net/base/net_export.h" +#include "net/socket/stream_listen_socket.h" +#include "net/socket/unix_domain_server_socket_posix.h" + +#if defined(OS_ANDROID) || defined(OS_LINUX) +// Feature only supported on Linux currently. This lets the Unix Domain Socket +// not be backed by the file system. +#define SOCKET_ABSTRACT_NAMESPACE_SUPPORTED +#endif + +namespace net { +namespace deprecated { + +// Unix Domain Socket Implementation. Supports abstract namespaces on Linux. +class NET_EXPORT UnixDomainListenSocket : public StreamListenSocket { + public: + typedef UnixDomainServerSocket::AuthCallback AuthCallback; + + ~UnixDomainListenSocket() override; + + // Note that the returned UnixDomainListenSocket instance does not take + // ownership of |del|. + static scoped_ptr<UnixDomainListenSocket> CreateAndListen( + const std::string& path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback); + +#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) + // Same as above except that the created socket uses the abstract namespace + // which is a Linux-only feature. If |fallback_path| is not empty, + // make the second attempt with the provided fallback name. + static scoped_ptr<UnixDomainListenSocket> + CreateAndListenWithAbstractNamespace( + const std::string& path, + const std::string& fallback_path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback); +#endif + + private: + UnixDomainListenSocket(SocketDescriptor s, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback); + + static scoped_ptr<UnixDomainListenSocket> CreateAndListenInternal( + const std::string& path, + const std::string& fallback_path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback, + bool use_abstract_namespace); + + // StreamListenSocket: + void Accept() override; + + AuthCallback auth_callback_; + + DISALLOW_COPY_AND_ASSIGN(UnixDomainListenSocket); +}; + +// Factory that can be used to instantiate UnixDomainListenSocket. +class NET_EXPORT UnixDomainListenSocketFactory + : public StreamListenSocketFactory { + public: + // Note that this class does not take ownership of the provided delegate. + UnixDomainListenSocketFactory( + const std::string& path, + const UnixDomainListenSocket::AuthCallback& auth_callback); + ~UnixDomainListenSocketFactory() override; + + // StreamListenSocketFactory: + scoped_ptr<StreamListenSocket> CreateAndListen( + StreamListenSocket::Delegate* delegate) const override; + + protected: + const std::string path_; + const UnixDomainListenSocket::AuthCallback auth_callback_; + + private: + DISALLOW_COPY_AND_ASSIGN(UnixDomainListenSocketFactory); +}; + +#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) +// Use this factory to instantiate UnixDomainListenSocket using the abstract +// namespace feature (only supported on Linux). +class NET_EXPORT UnixDomainListenSocketWithAbstractNamespaceFactory + : public UnixDomainListenSocketFactory { + public: + UnixDomainListenSocketWithAbstractNamespaceFactory( + const std::string& path, + const std::string& fallback_path, + const UnixDomainListenSocket::AuthCallback& auth_callback); + ~UnixDomainListenSocketWithAbstractNamespaceFactory() override; + + // UnixDomainListenSocketFactory: + scoped_ptr<StreamListenSocket> CreateAndListen( + StreamListenSocket::Delegate* delegate) const override; + + private: + std::string fallback_path_; + + DISALLOW_COPY_AND_ASSIGN(UnixDomainListenSocketWithAbstractNamespaceFactory); +}; +#endif + +} // namespace deprecated +} // namespace net + +#endif // NET_SOCKET_UNIX_DOMAIN_LISTEN_SOCKET_POSIX_H_ diff --git a/chromium/net/socket/unix_domain_socket_posix_unittest.cc b/chromium/net/socket/unix_domain_listen_socket_posix_unittest.cc index b1857e62e0e..aaf362310c3 100644 --- a/chromium/net/socket/unix_domain_socket_posix_unittest.cc +++ b/chromium/net/socket/unix_domain_listen_socket_posix_unittest.cc @@ -1,7 +1,9 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Copyright 2014 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include "net/socket/unix_domain_listen_socket_posix.h" + #include <errno.h> #include <fcntl.h> #include <poll.h> @@ -19,8 +21,9 @@ #include "base/bind.h" #include "base/callback.h" #include "base/compiler_specific.h" -#include "base/file_util.h" #include "base/files/file_path.h" +#include "base/files/file_util.h" +#include "base/files/scoped_temp_dir.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/message_loop/message_loop.h" @@ -30,16 +33,16 @@ #include "base/threading/platform_thread.h" #include "base/threading/thread.h" #include "net/socket/socket_descriptor.h" -#include "net/socket/unix_domain_socket_posix.h" #include "testing/gtest/include/gtest/gtest.h" using std::queue; using std::string; namespace net { +namespace deprecated { namespace { -const char kSocketFilename[] = "unix_domain_socket_for_testing"; +const char kSocketFilename[] = "socket_for_testing"; const char kInvalidSocketPath[] = "/invalid/path"; const char kMsg[] = "hello"; @@ -52,16 +55,6 @@ enum EventType { EVENT_READ, }; -string MakeSocketPath(const string& socket_file_name) { - base::FilePath temp_dir; - base::GetTempDir(&temp_dir); - return temp_dir.Append(socket_file_name).value(); -} - -string MakeSocketPath() { - return MakeSocketPath(kSocketFilename); -} - class EventManager : public base::RefCounted<EventManager> { public: EventManager() : condition_(&mutex_) {} @@ -101,16 +94,16 @@ class TestListenSocketDelegate : public StreamListenSocket::Delegate { const scoped_refptr<EventManager>& event_manager) : event_manager_(event_manager) {} - virtual void DidAccept(StreamListenSocket* server, - scoped_ptr<StreamListenSocket> connection) OVERRIDE { + void DidAccept(StreamListenSocket* server, + scoped_ptr<StreamListenSocket> connection) override { LOG(ERROR) << __PRETTY_FUNCTION__; connection_ = connection.Pass(); Notify(EVENT_ACCEPT); } - virtual void DidRead(StreamListenSocket* connection, - const char* data, - int len) OVERRIDE { + void DidRead(StreamListenSocket* connection, + const char* data, + int len) override { { base::AutoLock lock(mutex_); DCHECK(len); @@ -119,9 +112,7 @@ class TestListenSocketDelegate : public StreamListenSocket::Delegate { Notify(EVENT_READ); } - virtual void DidClose(StreamListenSocket* sock) OVERRIDE { - Notify(EVENT_CLOSE); - } + void DidClose(StreamListenSocket* sock) override { Notify(EVENT_CLOSE); } void OnListenCompleted() { Notify(EVENT_LISTEN); @@ -145,47 +136,52 @@ class TestListenSocketDelegate : public StreamListenSocket::Delegate { bool UserCanConnectCallback( bool allow_user, const scoped_refptr<EventManager>& event_manager, - uid_t, gid_t) { + const UnixDomainServerSocket::Credentials&) { event_manager->Notify( allow_user ? EVENT_AUTH_GRANTED : EVENT_AUTH_DENIED); return allow_user; } -class UnixDomainSocketTestHelper : public testing::Test { +class UnixDomainListenSocketTestHelper : public testing::Test { public: void CreateAndListen() { - socket_ = UnixDomainSocket::CreateAndListen( + socket_ = UnixDomainListenSocket::CreateAndListen( file_path_.value(), socket_delegate_.get(), MakeAuthCallback()); socket_delegate_->OnListenCompleted(); } protected: - UnixDomainSocketTestHelper(const string& path, bool allow_user) - : file_path_(path), - allow_user_(allow_user) {} + UnixDomainListenSocketTestHelper(const string& path_str, bool allow_user) + : allow_user_(allow_user) { + file_path_ = base::FilePath(path_str); + if (!file_path_.IsAbsolute()) { + EXPECT_TRUE(temp_dir_.CreateUniqueTempDir()); + file_path_ = GetTempSocketPath(file_path_.value()); + } + // Beware that if path_str is an absolute path, this class doesn't delete + // the file. It must be an invalid path and cannot be created by unittests. + } + + base::FilePath GetTempSocketPath(const std::string socket_name) { + DCHECK(temp_dir_.IsValid()); + return temp_dir_.path().Append(socket_name); + } - virtual void SetUp() OVERRIDE { + void SetUp() override { event_manager_ = new EventManager(); socket_delegate_.reset(new TestListenSocketDelegate(event_manager_)); - DeleteSocketFile(); } - virtual void TearDown() OVERRIDE { - DeleteSocketFile(); + void TearDown() override { socket_.reset(); socket_delegate_.reset(); event_manager_ = NULL; } - UnixDomainSocket::AuthCallback MakeAuthCallback() { + UnixDomainListenSocket::AuthCallback MakeAuthCallback() { return base::Bind(&UserCanConnectCallback, allow_user_, event_manager_); } - void DeleteSocketFile() { - ASSERT_FALSE(file_path_.empty()); - base::DeleteFile(file_path_, false /* not recursive */); - } - SocketDescriptor CreateClientSocket() { const SocketDescriptor sock = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0); if (sock < 0) { @@ -199,7 +195,8 @@ class UnixDomainSocketTestHelper : public testing::Test { strncpy(addr.sun_path, file_path_.value().c_str(), sizeof(addr.sun_path)); addr_len = sizeof(sockaddr_un); if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) { - LOG(ERROR) << "connect() error"; + LOG(ERROR) << "connect() error: " << strerror(errno) + << ": path=" << file_path_.value(); return kInvalidSocket; } return sock; @@ -212,43 +209,48 @@ class UnixDomainSocketTestHelper : public testing::Test { thread->StartWithOptions(options); thread->message_loop()->PostTask( FROM_HERE, - base::Bind(&UnixDomainSocketTestHelper::CreateAndListen, + base::Bind(&UnixDomainListenSocketTestHelper::CreateAndListen, base::Unretained(this))); return thread.Pass(); } - const base::FilePath file_path_; + base::ScopedTempDir temp_dir_; + base::FilePath file_path_; const bool allow_user_; scoped_refptr<EventManager> event_manager_; scoped_ptr<TestListenSocketDelegate> socket_delegate_; - scoped_ptr<UnixDomainSocket> socket_; + scoped_ptr<UnixDomainListenSocket> socket_; }; -class UnixDomainSocketTest : public UnixDomainSocketTestHelper { +class UnixDomainListenSocketTest : public UnixDomainListenSocketTestHelper { protected: - UnixDomainSocketTest() - : UnixDomainSocketTestHelper(MakeSocketPath(), true /* allow user */) {} + UnixDomainListenSocketTest() + : UnixDomainListenSocketTestHelper(kSocketFilename, + true /* allow user */) {} }; -class UnixDomainSocketTestWithInvalidPath : public UnixDomainSocketTestHelper { +class UnixDomainListenSocketTestWithInvalidPath + : public UnixDomainListenSocketTestHelper { protected: - UnixDomainSocketTestWithInvalidPath() - : UnixDomainSocketTestHelper(kInvalidSocketPath, true) {} + UnixDomainListenSocketTestWithInvalidPath() + : UnixDomainListenSocketTestHelper(kInvalidSocketPath, true) {} }; -class UnixDomainSocketTestWithForbiddenUser - : public UnixDomainSocketTestHelper { +class UnixDomainListenSocketTestWithForbiddenUser + : public UnixDomainListenSocketTestHelper { protected: - UnixDomainSocketTestWithForbiddenUser() - : UnixDomainSocketTestHelper(MakeSocketPath(), false /* forbid user */) {} + UnixDomainListenSocketTestWithForbiddenUser() + : UnixDomainListenSocketTestHelper(kSocketFilename, + false /* forbid user */) {} }; -TEST_F(UnixDomainSocketTest, CreateAndListen) { +TEST_F(UnixDomainListenSocketTest, CreateAndListen) { CreateAndListen(); EXPECT_FALSE(socket_.get() == NULL); } -TEST_F(UnixDomainSocketTestWithInvalidPath, CreateAndListenWithInvalidPath) { +TEST_F(UnixDomainListenSocketTestWithInvalidPath, + CreateAndListenWithInvalidPath) { CreateAndListen(); EXPECT_TRUE(socket_.get() == NULL); } @@ -256,35 +258,35 @@ TEST_F(UnixDomainSocketTestWithInvalidPath, CreateAndListenWithInvalidPath) { #ifdef SOCKET_ABSTRACT_NAMESPACE_SUPPORTED // Test with an invalid path to make sure that the socket is not backed by a // file. -TEST_F(UnixDomainSocketTestWithInvalidPath, +TEST_F(UnixDomainListenSocketTestWithInvalidPath, CreateAndListenWithAbstractNamespace) { - socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace( + socket_ = UnixDomainListenSocket::CreateAndListenWithAbstractNamespace( file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); EXPECT_FALSE(socket_.get() == NULL); } -TEST_F(UnixDomainSocketTest, TestFallbackName) { - scoped_ptr<UnixDomainSocket> existing_socket = - UnixDomainSocket::CreateAndListenWithAbstractNamespace( +TEST_F(UnixDomainListenSocketTest, TestFallbackName) { + scoped_ptr<UnixDomainListenSocket> existing_socket = + UnixDomainListenSocket::CreateAndListenWithAbstractNamespace( file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); EXPECT_FALSE(existing_socket.get() == NULL); // First, try to bind socket with the same name with no fallback name. socket_ = - UnixDomainSocket::CreateAndListenWithAbstractNamespace( + UnixDomainListenSocket::CreateAndListenWithAbstractNamespace( file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); EXPECT_TRUE(socket_.get() == NULL); // Now with a fallback name. - const char kFallbackSocketName[] = "unix_domain_socket_for_testing_2"; - socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace( + const char kFallbackSocketName[] = "socket_for_testing_2"; + socket_ = UnixDomainListenSocket::CreateAndListenWithAbstractNamespace( file_path_.value(), - MakeSocketPath(kFallbackSocketName), + GetTempSocketPath(kFallbackSocketName).value(), socket_delegate_.get(), MakeAuthCallback()); EXPECT_FALSE(socket_.get() == NULL); } #endif -TEST_F(UnixDomainSocketTest, TestWithClient) { +TEST_F(UnixDomainListenSocketTest, TestWithClient) { const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread(); EventType event = event_manager_->WaitForEvent(); ASSERT_EQ(EVENT_LISTEN, event); @@ -311,7 +313,7 @@ TEST_F(UnixDomainSocketTest, TestWithClient) { ASSERT_EQ(EVENT_CLOSE, event); } -TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) { +TEST_F(UnixDomainListenSocketTestWithForbiddenUser, TestWithForbiddenUser) { const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread(); EventType event = event_manager_->WaitForEvent(); ASSERT_EQ(EVENT_LISTEN, event); @@ -335,4 +337,5 @@ TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) { } } // namespace +} // namespace deprecated } // namespace net diff --git a/chromium/net/socket/unix_domain_server_socket_posix.cc b/chromium/net/socket/unix_domain_server_socket_posix.cc new file mode 100644 index 00000000000..4d6328310ff --- /dev/null +++ b/chromium/net/socket/unix_domain_server_socket_posix.cc @@ -0,0 +1,186 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/unix_domain_server_socket_posix.h" + +#include <errno.h> +#include <sys/socket.h> +#include <sys/un.h> +#include <unistd.h> + +#include "base/logging.h" +#include "net/base/net_errors.h" +#include "net/socket/socket_libevent.h" +#include "net/socket/unix_domain_client_socket_posix.h" + +namespace net { + +namespace { + +// Intended for use as SetterCallbacks in Accept() helper methods. +void SetStreamSocket(scoped_ptr<StreamSocket>* socket, + scoped_ptr<SocketLibevent> accepted_socket) { + socket->reset(new UnixDomainClientSocket(accepted_socket.Pass())); +} + +void SetSocketDescriptor(SocketDescriptor* socket, + scoped_ptr<SocketLibevent> accepted_socket) { + *socket = accepted_socket->ReleaseConnectedSocket(); +} + +} // anonymous namespace + +UnixDomainServerSocket::UnixDomainServerSocket( + const AuthCallback& auth_callback, + bool use_abstract_namespace) + : auth_callback_(auth_callback), + use_abstract_namespace_(use_abstract_namespace) { + DCHECK(!auth_callback_.is_null()); +} + +UnixDomainServerSocket::~UnixDomainServerSocket() { +} + +// static +bool UnixDomainServerSocket::GetPeerCredentials(SocketDescriptor socket, + Credentials* credentials) { +#if defined(OS_LINUX) || defined(OS_ANDROID) + struct ucred user_cred; + socklen_t len = sizeof(user_cred); + if (getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &user_cred, &len) < 0) + return false; + credentials->process_id = user_cred.pid; + credentials->user_id = user_cred.uid; + credentials->group_id = user_cred.gid; + return true; +#else + return getpeereid( + socket, &credentials->user_id, &credentials->group_id) == 0; +#endif +} + +int UnixDomainServerSocket::Listen(const IPEndPoint& address, int backlog) { + NOTIMPLEMENTED(); + return ERR_NOT_IMPLEMENTED; +} + +int UnixDomainServerSocket::ListenWithAddressAndPort( + const std::string& unix_domain_path, + int port_unused, + int backlog) { + DCHECK(!listen_socket_); + + SockaddrStorage address; + if (!UnixDomainClientSocket::FillAddress(unix_domain_path, + use_abstract_namespace_, + &address)) { + return ERR_ADDRESS_INVALID; + } + + scoped_ptr<SocketLibevent> socket(new SocketLibevent); + int rv = socket->Open(AF_UNIX); + DCHECK_NE(ERR_IO_PENDING, rv); + if (rv != OK) + return rv; + + rv = socket->Bind(address); + DCHECK_NE(ERR_IO_PENDING, rv); + if (rv != OK) { + PLOG(ERROR) + << "Could not bind unix domain socket to " << unix_domain_path + << (use_abstract_namespace_ ? " (with abstract namespace)" : ""); + return rv; + } + + rv = socket->Listen(backlog); + DCHECK_NE(ERR_IO_PENDING, rv); + if (rv != OK) + return rv; + + listen_socket_.swap(socket); + return rv; +} + +int UnixDomainServerSocket::GetLocalAddress(IPEndPoint* address) const { + NOTIMPLEMENTED(); + return ERR_NOT_IMPLEMENTED; +} + +int UnixDomainServerSocket::Accept(scoped_ptr<StreamSocket>* socket, + const CompletionCallback& callback) { + DCHECK(socket); + + SetterCallback setter_callback = base::Bind(&SetStreamSocket, socket); + return DoAccept(setter_callback, callback); +} + +int UnixDomainServerSocket::AcceptSocketDescriptor( + SocketDescriptor* socket, + const CompletionCallback& callback) { + DCHECK(socket); + + SetterCallback setter_callback = base::Bind(&SetSocketDescriptor, socket); + return DoAccept(setter_callback, callback); +} + +int UnixDomainServerSocket::DoAccept(const SetterCallback& setter_callback, + const CompletionCallback& callback) { + DCHECK(!setter_callback.is_null()); + DCHECK(!callback.is_null()); + DCHECK(listen_socket_); + DCHECK(!accept_socket_); + + while (true) { + int rv = listen_socket_->Accept( + &accept_socket_, + base::Bind(&UnixDomainServerSocket::AcceptCompleted, + base::Unretained(this), + setter_callback, + callback)); + if (rv != OK) + return rv; + if (AuthenticateAndGetStreamSocket(setter_callback)) + return OK; + // Accept another socket because authentication error should be transparent + // to the caller. + } +} + +void UnixDomainServerSocket::AcceptCompleted( + const SetterCallback& setter_callback, + const CompletionCallback& callback, + int rv) { + if (rv != OK) { + callback.Run(rv); + return; + } + + if (AuthenticateAndGetStreamSocket(setter_callback)) { + callback.Run(OK); + return; + } + + // Accept another socket because authentication error should be transparent + // to the caller. + rv = DoAccept(setter_callback, callback); + if (rv != ERR_IO_PENDING) + callback.Run(rv); +} + +bool UnixDomainServerSocket::AuthenticateAndGetStreamSocket( + const SetterCallback& setter_callback) { + DCHECK(accept_socket_); + + Credentials credentials; + if (!GetPeerCredentials(accept_socket_->socket_fd(), &credentials) || + !auth_callback_.Run(credentials)) { + accept_socket_.reset(); + return false; + } + + setter_callback.Run(accept_socket_.Pass()); + return true; +} + +} // namespace net diff --git a/chromium/net/socket/unix_domain_server_socket_posix.h b/chromium/net/socket/unix_domain_server_socket_posix.h new file mode 100644 index 00000000000..0a26eb3d375 --- /dev/null +++ b/chromium/net/socket/unix_domain_server_socket_posix.h @@ -0,0 +1,91 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_UNIX_DOMAIN_SERVER_SOCKET_POSIX_H_ +#define NET_SOCKET_UNIX_DOMAIN_SERVER_SOCKET_POSIX_H_ + +#include <sys/types.h> + +#include <string> + +#include "base/basictypes.h" +#include "base/callback.h" +#include "base/macros.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/net_export.h" +#include "net/socket/server_socket.h" +#include "net/socket/socket_descriptor.h" + +namespace net { + +class SocketLibevent; + +// Unix Domain Server Socket Implementation. Supports abstract namespaces on +// Linux and Android. +class NET_EXPORT UnixDomainServerSocket : public ServerSocket { + public: + // Credentials of a peer process connected to the socket. + struct NET_EXPORT Credentials { +#if defined(OS_LINUX) || defined(OS_ANDROID) + // Linux/Android API provides more information about the connected peer + // than Windows/OS X. It's useful for permission-based authorization on + // Android. + pid_t process_id; +#endif + uid_t user_id; + gid_t group_id; + }; + + // Callback that returns whether the already connected client, identified by + // its credentials, is allowed to keep the connection open. Note that + // the socket is closed immediately in case the callback returns false. + typedef base::Callback<bool (const Credentials&)> AuthCallback; + + UnixDomainServerSocket(const AuthCallback& auth_callack, + bool use_abstract_namespace); + ~UnixDomainServerSocket() override; + + // Gets credentials of peer to check permissions. + static bool GetPeerCredentials(SocketDescriptor socket_fd, + Credentials* credentials); + + // ServerSocket implementation. + int Listen(const IPEndPoint& address, int backlog) override; + int ListenWithAddressAndPort(const std::string& unix_domain_path, + int port_unused, + int backlog) override; + int GetLocalAddress(IPEndPoint* address) const override; + int Accept(scoped_ptr<StreamSocket>* socket, + const CompletionCallback& callback) override; + + // Accepts an incoming connection on |listen_socket_|, but passes back + // a raw SocketDescriptor instead of a StreamSocket. + int AcceptSocketDescriptor(SocketDescriptor* socket_descriptor, + const CompletionCallback& callback); + + private: + // A callback to wrap the setting of the out-parameter to Accept(). + // This allows the internal machinery of that call to be implemented in + // a manner that's agnostic to the caller's desired output. + typedef base::Callback<void(scoped_ptr<SocketLibevent>)> SetterCallback; + + int DoAccept(const SetterCallback& setter_callback, + const CompletionCallback& callback); + void AcceptCompleted(const SetterCallback& setter_callback, + const CompletionCallback& callback, + int rv); + bool AuthenticateAndGetStreamSocket(const SetterCallback& setter_callback); + + scoped_ptr<SocketLibevent> listen_socket_; + const AuthCallback auth_callback_; + const bool use_abstract_namespace_; + + scoped_ptr<SocketLibevent> accept_socket_; + + DISALLOW_COPY_AND_ASSIGN(UnixDomainServerSocket); +}; + +} // namespace net + +#endif // NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_ diff --git a/chromium/net/socket/unix_domain_server_socket_posix_unittest.cc b/chromium/net/socket/unix_domain_server_socket_posix_unittest.cc new file mode 100644 index 00000000000..bdf1efa29c4 --- /dev/null +++ b/chromium/net/socket/unix_domain_server_socket_posix_unittest.cc @@ -0,0 +1,125 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/unix_domain_server_socket_posix.h" + +#include <vector> + +#include "base/bind.h" +#include "base/files/file_path.h" +#include "base/files/scoped_temp_dir.h" +#include "base/memory/scoped_ptr.h" +#include "base/run_loop.h" +#include "base/stl_util.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/test_completion_callback.h" +#include "net/socket/unix_domain_client_socket_posix.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { +namespace { + +const char kSocketFilename[] = "socket_for_testing"; +const char kInvalidSocketPath[] = "/invalid/path"; + +bool UserCanConnectCallback(bool allow_user, + const UnixDomainServerSocket::Credentials& credentials) { + // Here peers are running in same process. +#if defined(OS_LINUX) || defined(OS_ANDROID) + EXPECT_EQ(getpid(), credentials.process_id); +#endif + EXPECT_EQ(getuid(), credentials.user_id); + EXPECT_EQ(getgid(), credentials.group_id); + return allow_user; +} + +UnixDomainServerSocket::AuthCallback CreateAuthCallback(bool allow_user) { + return base::Bind(&UserCanConnectCallback, allow_user); +} + +class UnixDomainServerSocketTest : public testing::Test { + protected: + UnixDomainServerSocketTest() { + EXPECT_TRUE(temp_dir_.CreateUniqueTempDir()); + socket_path_ = temp_dir_.path().Append(kSocketFilename).value(); + } + + base::ScopedTempDir temp_dir_; + std::string socket_path_; +}; + +TEST_F(UnixDomainServerSocketTest, ListenWithInvalidPath) { + const bool kUseAbstractNamespace = false; + UnixDomainServerSocket server_socket(CreateAuthCallback(true), + kUseAbstractNamespace); + EXPECT_EQ(ERR_FILE_NOT_FOUND, + server_socket.ListenWithAddressAndPort(kInvalidSocketPath, 0, 1)); +} + +TEST_F(UnixDomainServerSocketTest, ListenWithInvalidPathWithAbstractNamespace) { + const bool kUseAbstractNamespace = true; + UnixDomainServerSocket server_socket(CreateAuthCallback(true), + kUseAbstractNamespace); +#if defined(OS_ANDROID) || defined(OS_LINUX) + EXPECT_EQ(OK, + server_socket.ListenWithAddressAndPort(kInvalidSocketPath, 0, 1)); +#else + EXPECT_EQ(ERR_ADDRESS_INVALID, + server_socket.ListenWithAddressAndPort(kInvalidSocketPath, 0, 1)); +#endif +} + +TEST_F(UnixDomainServerSocketTest, ListenAgainAfterFailureWithInvalidPath) { + const bool kUseAbstractNamespace = false; + UnixDomainServerSocket server_socket(CreateAuthCallback(true), + kUseAbstractNamespace); + EXPECT_EQ(ERR_FILE_NOT_FOUND, + server_socket.ListenWithAddressAndPort(kInvalidSocketPath, 0, 1)); + EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); +} + +TEST_F(UnixDomainServerSocketTest, AcceptWithForbiddenUser) { + const bool kUseAbstractNamespace = false; + + UnixDomainServerSocket server_socket(CreateAuthCallback(false), + kUseAbstractNamespace); + EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + + scoped_ptr<StreamSocket> accepted_socket; + TestCompletionCallback accept_callback; + EXPECT_EQ(ERR_IO_PENDING, + server_socket.Accept(&accepted_socket, accept_callback.callback())); + EXPECT_FALSE(accepted_socket); + + UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace); + EXPECT_FALSE(client_socket.IsConnected()); + + // Connect() will return OK before the server rejects the connection. + TestCompletionCallback connect_callback; + int rv = connect_callback.GetResult( + client_socket.Connect(connect_callback.callback())); + ASSERT_EQ(OK, rv); + + // Try to read from the socket. + const int read_buffer_size = 10; + scoped_refptr<IOBuffer> read_buffer(new IOBuffer(read_buffer_size)); + TestCompletionCallback read_callback; + rv = read_callback.GetResult(client_socket.Read( + read_buffer.get(), read_buffer_size, read_callback.callback())); + + // The server should have disconnected gracefully, without sending any data. + ASSERT_EQ(0, rv); + EXPECT_FALSE(client_socket.IsConnected()); + + // The server socket should not have called |accept_callback| or modified + // |accepted_socket|. + EXPECT_FALSE(accept_callback.have_result()); + EXPECT_FALSE(accepted_socket); +} + +// Normal cases including read/write are tested by UnixDomainClientSocketTest. + +} // namespace +} // namespace net diff --git a/chromium/net/socket/unix_domain_socket_posix.cc b/chromium/net/socket/unix_domain_socket_posix.cc deleted file mode 100644 index 3141f7166b2..00000000000 --- a/chromium/net/socket/unix_domain_socket_posix.cc +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/socket/unix_domain_socket_posix.h" - -#include <cstring> -#include <string> - -#include <errno.h> -#include <sys/socket.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <sys/un.h> -#include <unistd.h> - -#include "base/bind.h" -#include "base/callback.h" -#include "base/posix/eintr_wrapper.h" -#include "base/threading/platform_thread.h" -#include "build/build_config.h" -#include "net/base/net_errors.h" -#include "net/base/net_util.h" -#include "net/socket/socket_descriptor.h" - -namespace net { - -namespace { - -bool NoAuthenticationCallback(uid_t, gid_t) { - return true; -} - -bool GetPeerIds(int socket, uid_t* user_id, gid_t* group_id) { -#if defined(OS_LINUX) || defined(OS_ANDROID) - struct ucred user_cred; - socklen_t len = sizeof(user_cred); - if (getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &user_cred, &len) == -1) - return false; - *user_id = user_cred.uid; - *group_id = user_cred.gid; -#else - if (getpeereid(socket, user_id, group_id) == -1) - return false; -#endif - return true; -} - -} // namespace - -// static -UnixDomainSocket::AuthCallback UnixDomainSocket::NoAuthentication() { - return base::Bind(NoAuthenticationCallback); -} - -// static -scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListenInternal( - const std::string& path, - const std::string& fallback_path, - StreamListenSocket::Delegate* del, - const AuthCallback& auth_callback, - bool use_abstract_namespace) { - SocketDescriptor s = CreateAndBind(path, use_abstract_namespace); - if (s == kInvalidSocket && !fallback_path.empty()) - s = CreateAndBind(fallback_path, use_abstract_namespace); - if (s == kInvalidSocket) - return scoped_ptr<UnixDomainSocket>(); - scoped_ptr<UnixDomainSocket> sock( - new UnixDomainSocket(s, del, auth_callback)); - sock->Listen(); - return sock.Pass(); -} - -// static -scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen( - const std::string& path, - StreamListenSocket::Delegate* del, - const AuthCallback& auth_callback) { - return CreateAndListenInternal(path, "", del, auth_callback, false); -} - -#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) -// static -scoped_ptr<UnixDomainSocket> -UnixDomainSocket::CreateAndListenWithAbstractNamespace( - const std::string& path, - const std::string& fallback_path, - StreamListenSocket::Delegate* del, - const AuthCallback& auth_callback) { - return - CreateAndListenInternal(path, fallback_path, del, auth_callback, true); -} -#endif - -UnixDomainSocket::UnixDomainSocket( - SocketDescriptor s, - StreamListenSocket::Delegate* del, - const AuthCallback& auth_callback) - : StreamListenSocket(s, del), - auth_callback_(auth_callback) {} - -UnixDomainSocket::~UnixDomainSocket() {} - -// static -SocketDescriptor UnixDomainSocket::CreateAndBind(const std::string& path, - bool use_abstract_namespace) { - sockaddr_un addr; - static const size_t kPathMax = sizeof(addr.sun_path); - if (use_abstract_namespace + path.size() + 1 /* '\0' */ > kPathMax) - return kInvalidSocket; - const SocketDescriptor s = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0); - if (s == kInvalidSocket) - return kInvalidSocket; - memset(&addr, 0, sizeof(addr)); - addr.sun_family = AF_UNIX; - socklen_t addr_len; - if (use_abstract_namespace) { - // Convert the path given into abstract socket name. It must start with - // the '\0' character, so we are adding it. |addr_len| must specify the - // length of the structure exactly, as potentially the socket name may - // have '\0' characters embedded (although we don't support this). - // Note that addr.sun_path is already zero initialized. - memcpy(addr.sun_path + 1, path.c_str(), path.size()); - addr_len = path.size() + offsetof(struct sockaddr_un, sun_path) + 1; - } else { - memcpy(addr.sun_path, path.c_str(), path.size()); - addr_len = sizeof(sockaddr_un); - } - if (bind(s, reinterpret_cast<sockaddr*>(&addr), addr_len)) { - LOG(ERROR) << "Could not bind unix domain socket to " << path; - if (use_abstract_namespace) - LOG(ERROR) << " (with abstract namespace enabled)"; - if (IGNORE_EINTR(close(s)) < 0) - LOG(ERROR) << "close() error"; - return kInvalidSocket; - } - return s; -} - -void UnixDomainSocket::Accept() { - SocketDescriptor conn = StreamListenSocket::AcceptSocket(); - if (conn == kInvalidSocket) - return; - uid_t user_id; - gid_t group_id; - if (!GetPeerIds(conn, &user_id, &group_id) || - !auth_callback_.Run(user_id, group_id)) { - if (IGNORE_EINTR(close(conn)) < 0) - LOG(ERROR) << "close() error"; - return; - } - scoped_ptr<UnixDomainSocket> sock( - new UnixDomainSocket(conn, socket_delegate_, auth_callback_)); - // It's up to the delegate to AddRef if it wants to keep it around. - sock->WatchSocket(WAITING_READ); - socket_delegate_->DidAccept(this, sock.PassAs<StreamListenSocket>()); -} - -UnixDomainSocketFactory::UnixDomainSocketFactory( - const std::string& path, - const UnixDomainSocket::AuthCallback& auth_callback) - : path_(path), - auth_callback_(auth_callback) {} - -UnixDomainSocketFactory::~UnixDomainSocketFactory() {} - -scoped_ptr<StreamListenSocket> UnixDomainSocketFactory::CreateAndListen( - StreamListenSocket::Delegate* delegate) const { - return UnixDomainSocket::CreateAndListen( - path_, delegate, auth_callback_).PassAs<StreamListenSocket>(); -} - -#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) - -UnixDomainSocketWithAbstractNamespaceFactory:: -UnixDomainSocketWithAbstractNamespaceFactory( - const std::string& path, - const std::string& fallback_path, - const UnixDomainSocket::AuthCallback& auth_callback) - : UnixDomainSocketFactory(path, auth_callback), - fallback_path_(fallback_path) {} - -UnixDomainSocketWithAbstractNamespaceFactory:: -~UnixDomainSocketWithAbstractNamespaceFactory() {} - -scoped_ptr<StreamListenSocket> -UnixDomainSocketWithAbstractNamespaceFactory::CreateAndListen( - StreamListenSocket::Delegate* delegate) const { - return UnixDomainSocket::CreateAndListenWithAbstractNamespace( - path_, fallback_path_, delegate, auth_callback_) - .PassAs<StreamListenSocket>(); -} - -#endif - -} // namespace net diff --git a/chromium/net/socket/unix_domain_socket_posix.h b/chromium/net/socket/unix_domain_socket_posix.h deleted file mode 100644 index 98d0c11a648..00000000000 --- a/chromium/net/socket/unix_domain_socket_posix.h +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_ -#define NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_ - -#include <string> - -#include "base/basictypes.h" -#include "base/callback_forward.h" -#include "base/compiler_specific.h" -#include "build/build_config.h" -#include "net/base/net_export.h" -#include "net/socket/stream_listen_socket.h" - -#if defined(OS_ANDROID) || defined(OS_LINUX) -// Feature only supported on Linux currently. This lets the Unix Domain Socket -// not be backed by the file system. -#define SOCKET_ABSTRACT_NAMESPACE_SUPPORTED -#endif - -namespace net { - -// Unix Domain Socket Implementation. Supports abstract namespaces on Linux. -class NET_EXPORT UnixDomainSocket : public StreamListenSocket { - public: - virtual ~UnixDomainSocket(); - - // Callback that returns whether the already connected client, identified by - // its process |user_id| and |group_id|, is allowed to keep the connection - // open. Note that the socket is closed immediately in case the callback - // returns false. - typedef base::Callback<bool (uid_t user_id, gid_t group_id)> AuthCallback; - - // Returns an authentication callback that always grants access for - // convenience in case you don't want to use authentication. - static AuthCallback NoAuthentication(); - - // Note that the returned UnixDomainSocket instance does not take ownership of - // |del|. - static scoped_ptr<UnixDomainSocket> CreateAndListen( - const std::string& path, - StreamListenSocket::Delegate* del, - const AuthCallback& auth_callback); - -#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) - // Same as above except that the created socket uses the abstract namespace - // which is a Linux-only feature. If |fallback_path| is not empty, - // make the second attempt with the provided fallback name. - static scoped_ptr<UnixDomainSocket> CreateAndListenWithAbstractNamespace( - const std::string& path, - const std::string& fallback_path, - StreamListenSocket::Delegate* del, - const AuthCallback& auth_callback); -#endif - - private: - UnixDomainSocket(SocketDescriptor s, - StreamListenSocket::Delegate* del, - const AuthCallback& auth_callback); - - static scoped_ptr<UnixDomainSocket> CreateAndListenInternal( - const std::string& path, - const std::string& fallback_path, - StreamListenSocket::Delegate* del, - const AuthCallback& auth_callback, - bool use_abstract_namespace); - - static SocketDescriptor CreateAndBind(const std::string& path, - bool use_abstract_namespace); - - // StreamListenSocket: - virtual void Accept() OVERRIDE; - - AuthCallback auth_callback_; - - DISALLOW_COPY_AND_ASSIGN(UnixDomainSocket); -}; - -// Factory that can be used to instantiate UnixDomainSocket. -class NET_EXPORT UnixDomainSocketFactory : public StreamListenSocketFactory { - public: - // Note that this class does not take ownership of the provided delegate. - UnixDomainSocketFactory(const std::string& path, - const UnixDomainSocket::AuthCallback& auth_callback); - virtual ~UnixDomainSocketFactory(); - - // StreamListenSocketFactory: - virtual scoped_ptr<StreamListenSocket> CreateAndListen( - StreamListenSocket::Delegate* delegate) const OVERRIDE; - - protected: - const std::string path_; - const UnixDomainSocket::AuthCallback auth_callback_; - - private: - DISALLOW_COPY_AND_ASSIGN(UnixDomainSocketFactory); -}; - -#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) -// Use this factory to instantiate UnixDomainSocket using the abstract -// namespace feature (only supported on Linux). -class NET_EXPORT UnixDomainSocketWithAbstractNamespaceFactory - : public UnixDomainSocketFactory { - public: - UnixDomainSocketWithAbstractNamespaceFactory( - const std::string& path, - const std::string& fallback_path, - const UnixDomainSocket::AuthCallback& auth_callback); - virtual ~UnixDomainSocketWithAbstractNamespaceFactory(); - - // UnixDomainSocketFactory: - virtual scoped_ptr<StreamListenSocket> CreateAndListen( - StreamListenSocket::Delegate* delegate) const OVERRIDE; - - private: - std::string fallback_path_; - - DISALLOW_COPY_AND_ASSIGN(UnixDomainSocketWithAbstractNamespaceFactory); -}; -#endif - -} // namespace net - -#endif // NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_ diff --git a/chromium/net/socket/websocket_endpoint_lock_manager.cc b/chromium/net/socket/websocket_endpoint_lock_manager.cc new file mode 100644 index 00000000000..e578bb2435b --- /dev/null +++ b/chromium/net/socket/websocket_endpoint_lock_manager.cc @@ -0,0 +1,132 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/websocket_endpoint_lock_manager.h" + +#include <utility> + +#include "base/logging.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" + +namespace net { + +WebSocketEndpointLockManager::Waiter::~Waiter() { + if (next()) { + DCHECK(previous()); + RemoveFromList(); + } +} + +WebSocketEndpointLockManager* WebSocketEndpointLockManager::GetInstance() { + return Singleton<WebSocketEndpointLockManager>::get(); +} + +int WebSocketEndpointLockManager::LockEndpoint(const IPEndPoint& endpoint, + Waiter* waiter) { + LockInfoMap::value_type insert_value(endpoint, LockInfo()); + std::pair<LockInfoMap::iterator, bool> rv = + lock_info_map_.insert(insert_value); + LockInfo& lock_info_in_map = rv.first->second; + if (rv.second) { + DVLOG(3) << "Locking endpoint " << endpoint.ToString(); + lock_info_in_map.queue.reset(new LockInfo::WaiterQueue); + return OK; + } + DVLOG(3) << "Waiting for endpoint " << endpoint.ToString(); + lock_info_in_map.queue->Append(waiter); + return ERR_IO_PENDING; +} + +void WebSocketEndpointLockManager::RememberSocket(StreamSocket* socket, + const IPEndPoint& endpoint) { + LockInfoMap::iterator lock_info_it = lock_info_map_.find(endpoint); + CHECK(lock_info_it != lock_info_map_.end()); + bool inserted = + socket_lock_info_map_.insert(SocketLockInfoMap::value_type( + socket, lock_info_it)).second; + DCHECK(inserted); + DCHECK(!lock_info_it->second.socket); + lock_info_it->second.socket = socket; + DVLOG(3) << "Remembered (StreamSocket*)" << socket << " for " + << endpoint.ToString() << " (" << socket_lock_info_map_.size() + << " socket(s) remembered)"; +} + +void WebSocketEndpointLockManager::UnlockSocket(StreamSocket* socket) { + SocketLockInfoMap::iterator socket_it = socket_lock_info_map_.find(socket); + if (socket_it == socket_lock_info_map_.end()) + return; + + LockInfoMap::iterator lock_info_it = socket_it->second; + + DVLOG(3) << "Unlocking (StreamSocket*)" << socket << " for " + << lock_info_it->first.ToString() << " (" + << socket_lock_info_map_.size() << " socket(s) left)"; + socket_lock_info_map_.erase(socket_it); + DCHECK(socket == lock_info_it->second.socket); + lock_info_it->second.socket = NULL; + UnlockEndpointByIterator(lock_info_it); +} + +void WebSocketEndpointLockManager::UnlockEndpoint(const IPEndPoint& endpoint) { + LockInfoMap::iterator lock_info_it = lock_info_map_.find(endpoint); + if (lock_info_it == lock_info_map_.end()) + return; + + UnlockEndpointByIterator(lock_info_it); +} + +bool WebSocketEndpointLockManager::IsEmpty() const { + return lock_info_map_.empty() && socket_lock_info_map_.empty(); +} + +WebSocketEndpointLockManager::LockInfo::LockInfo() : socket(NULL) {} +WebSocketEndpointLockManager::LockInfo::~LockInfo() { + DCHECK(!socket); +} + +WebSocketEndpointLockManager::LockInfo::LockInfo(const LockInfo& rhs) + : socket(rhs.socket) { + DCHECK(!rhs.queue); +} + +WebSocketEndpointLockManager::WebSocketEndpointLockManager() {} + +WebSocketEndpointLockManager::~WebSocketEndpointLockManager() { + DCHECK(lock_info_map_.empty()); + DCHECK(socket_lock_info_map_.empty()); +} + +void WebSocketEndpointLockManager::UnlockEndpointByIterator( + LockInfoMap::iterator lock_info_it) { + if (lock_info_it->second.socket) + EraseSocket(lock_info_it); + LockInfo::WaiterQueue* queue = lock_info_it->second.queue.get(); + DCHECK(queue); + if (queue->empty()) { + DVLOG(3) << "Unlocking endpoint " << lock_info_it->first.ToString(); + lock_info_map_.erase(lock_info_it); + return; + } + + DVLOG(3) << "Unlocking endpoint " << lock_info_it->first.ToString() + << " and activating next waiter"; + Waiter* next_job = queue->head()->value(); + next_job->RemoveFromList(); + // This must be last to minimise the excitement caused by re-entrancy. + next_job->GotEndpointLock(); +} + +void WebSocketEndpointLockManager::EraseSocket( + LockInfoMap::iterator lock_info_it) { + DVLOG(3) << "Removing (StreamSocket*)" << lock_info_it->second.socket + << " for " << lock_info_it->first.ToString() << " (" + << socket_lock_info_map_.size() << " socket(s) left)"; + size_t erased = socket_lock_info_map_.erase(lock_info_it->second.socket); + DCHECK_EQ(1U, erased); + lock_info_it->second.socket = NULL; +} + +} // namespace net diff --git a/chromium/net/socket/websocket_endpoint_lock_manager.h b/chromium/net/socket/websocket_endpoint_lock_manager.h new file mode 100644 index 00000000000..d5cad508d6a --- /dev/null +++ b/chromium/net/socket/websocket_endpoint_lock_manager.h @@ -0,0 +1,121 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_WEBSOCKET_ENDPOINT_LOCK_MANAGER_H_ +#define NET_SOCKET_WEBSOCKET_ENDPOINT_LOCK_MANAGER_H_ + +#include <map> + +#include "base/containers/linked_list.h" +#include "base/logging.h" +#include "base/macros.h" +#include "base/memory/singleton.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_export.h" +#include "net/socket/websocket_transport_client_socket_pool.h" + +namespace net { + +class StreamSocket; + +class NET_EXPORT_PRIVATE WebSocketEndpointLockManager { + public: + class NET_EXPORT_PRIVATE Waiter : public base::LinkNode<Waiter> { + public: + // If the node is in a list, removes it. + virtual ~Waiter(); + + virtual void GotEndpointLock() = 0; + }; + + static WebSocketEndpointLockManager* GetInstance(); + + // Returns OK if lock was acquired immediately, ERR_IO_PENDING if not. If the + // lock was not acquired, then |waiter->GotEndpointLock()| will be called when + // it is. A Waiter automatically removes itself from the list of waiters when + // its destructor is called. + int LockEndpoint(const IPEndPoint& endpoint, Waiter* waiter); + + // Records the IPEndPoint associated with a particular socket. This is + // necessary because TCPClientSocket refuses to return the PeerAddress after + // the connection is disconnected. The association will be forgotten when + // UnlockSocket() or UnlockEndpoint() is called. The |socket| pointer must not + // be deleted between the call to RememberSocket() and the call to + // UnlockSocket(). + void RememberSocket(StreamSocket* socket, const IPEndPoint& endpoint); + + // Releases the lock on the endpoint that was associated with |socket| by + // RememberSocket(). If appropriate, triggers the next socket connection. + // Should be called exactly once for each |socket| that was passed to + // RememberSocket(). Does nothing if UnlockEndpoint() has been called since + // the call to RememberSocket(). + void UnlockSocket(StreamSocket* socket); + + // Releases the lock on |endpoint|. Does nothing if |endpoint| is not locked. + // Removes any socket association that was recorded with RememberSocket(). If + // appropriate, calls |waiter->GotEndpointLock()|. + void UnlockEndpoint(const IPEndPoint& endpoint); + + // Checks that |lock_info_map_| and |socket_lock_info_map_| are empty. For + // tests. + bool IsEmpty() const; + + private: + struct LockInfo { + typedef base::LinkedList<Waiter> WaiterQueue; + + LockInfo(); + ~LockInfo(); + + // This object must be copyable to be placed in the map, but it cannot be + // copied after |queue| has been assigned to. + LockInfo(const LockInfo& rhs); + + // Not used. + LockInfo& operator=(const LockInfo& rhs); + + // Must be NULL to copy this object into the map. Must be set to non-NULL + // after the object is inserted into the map then point to the same list + // until this object is deleted. + scoped_ptr<WaiterQueue> queue; + + // This pointer is only used to identify the last instance of StreamSocket + // that was passed to RememberSocket() for this endpoint. It should only be + // compared with other pointers. It is never dereferenced and not owned. It + // is non-NULL if RememberSocket() has been called for this endpoint since + // the last call to UnlockSocket() or UnlockEndpoint(). + StreamSocket* socket; + }; + + // SocketLockInfoMap requires std::map iterator semantics for LockInfoMap + // (ie. that the iterator will remain valid as long as the entry is not + // deleted). + typedef std::map<IPEndPoint, LockInfo> LockInfoMap; + typedef std::map<StreamSocket*, LockInfoMap::iterator> SocketLockInfoMap; + + WebSocketEndpointLockManager(); + ~WebSocketEndpointLockManager(); + + void UnlockEndpointByIterator(LockInfoMap::iterator lock_info_it); + void EraseSocket(LockInfoMap::iterator lock_info_it); + + // If an entry is present in the map for a particular endpoint, then that + // endpoint is locked. If LockInfo.queue is non-empty, then one or more + // Waiters are waiting for the lock. + LockInfoMap lock_info_map_; + + // Store sockets remembered by RememberSocket() and not yet unlocked by + // UnlockSocket() or UnlockEndpoint(). Every entry in this map always + // references a live entry in lock_info_map_, and the LockInfo::socket member + // is non-NULL if and only if there is an entry in this map for the socket. + SocketLockInfoMap socket_lock_info_map_; + + friend struct DefaultSingletonTraits<WebSocketEndpointLockManager>; + + DISALLOW_COPY_AND_ASSIGN(WebSocketEndpointLockManager); +}; + +} // namespace net + +#endif // NET_SOCKET_WEBSOCKET_ENDPOINT_LOCK_MANAGER_H_ diff --git a/chromium/net/socket/websocket_endpoint_lock_manager_unittest.cc b/chromium/net/socket/websocket_endpoint_lock_manager_unittest.cc new file mode 100644 index 00000000000..1626aa90201 --- /dev/null +++ b/chromium/net/socket/websocket_endpoint_lock_manager_unittest.cc @@ -0,0 +1,224 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/websocket_endpoint_lock_manager.h" + +#include "net/base/net_errors.h" +#include "net/socket/next_proto.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/stream_socket.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +// A StreamSocket implementation with no functionality at all. +// TODO(ricea): If you need to use this in another file, please move it to +// socket_test_util.h. +class FakeStreamSocket : public StreamSocket { + public: + FakeStreamSocket() {} + + // StreamSocket implementation + int Connect(const CompletionCallback& callback) override { + return ERR_FAILED; + } + + void Disconnect() override { return; } + + bool IsConnected() const override { return false; } + + bool IsConnectedAndIdle() const override { return false; } + + int GetPeerAddress(IPEndPoint* address) const override { return ERR_FAILED; } + + int GetLocalAddress(IPEndPoint* address) const override { return ERR_FAILED; } + + const BoundNetLog& NetLog() const override { return bound_net_log_; } + + void SetSubresourceSpeculation() override { return; } + void SetOmniboxSpeculation() override { return; } + + bool WasEverUsed() const override { return false; } + + bool UsingTCPFastOpen() const override { return false; } + + bool WasNpnNegotiated() const override { return false; } + + NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } + + bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } + + // Socket implementation + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override { + return ERR_FAILED; + } + + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override { + return ERR_FAILED; + } + + int SetReceiveBufferSize(int32 size) override { return ERR_FAILED; } + + int SetSendBufferSize(int32 size) override { return ERR_FAILED; } + + private: + BoundNetLog bound_net_log_; + + DISALLOW_COPY_AND_ASSIGN(FakeStreamSocket); +}; + +class FakeWaiter : public WebSocketEndpointLockManager::Waiter { + public: + FakeWaiter() : called_(false) {} + + void GotEndpointLock() override { + CHECK(!called_); + called_ = true; + } + + bool called() const { return called_; } + + private: + bool called_; +}; + +class WebSocketEndpointLockManagerTest : public ::testing::Test { + protected: + WebSocketEndpointLockManagerTest() + : instance_(WebSocketEndpointLockManager::GetInstance()) {} + ~WebSocketEndpointLockManagerTest() override { + // If this check fails then subsequent tests may fail. + CHECK(instance_->IsEmpty()); + } + + WebSocketEndpointLockManager* instance() const { return instance_; } + + IPEndPoint DummyEndpoint() { + IPAddressNumber ip_address_number; + CHECK(ParseIPLiteralToNumber("127.0.0.1", &ip_address_number)); + return IPEndPoint(ip_address_number, 80); + } + + void UnlockDummyEndpoint(int times) { + for (int i = 0; i < times; ++i) { + instance()->UnlockEndpoint(DummyEndpoint()); + } + } + + WebSocketEndpointLockManager* const instance_; +}; + +TEST_F(WebSocketEndpointLockManagerTest, GetInstanceWorks) { + // All the work is done by the test framework. +} + +TEST_F(WebSocketEndpointLockManagerTest, LockEndpointReturnsOkOnce) { + FakeWaiter waiters[2]; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0])); + EXPECT_EQ(ERR_IO_PENDING, + instance()->LockEndpoint(DummyEndpoint(), &waiters[1])); + + UnlockDummyEndpoint(2); +} + +TEST_F(WebSocketEndpointLockManagerTest, GotEndpointLockNotCalledOnOk) { + FakeWaiter waiter; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiter)); + EXPECT_FALSE(waiter.called()); + + UnlockDummyEndpoint(1); +} + +TEST_F(WebSocketEndpointLockManagerTest, GotEndpointLockNotCalledImmediately) { + FakeWaiter waiters[2]; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0])); + EXPECT_EQ(ERR_IO_PENDING, + instance()->LockEndpoint(DummyEndpoint(), &waiters[1])); + EXPECT_FALSE(waiters[1].called()); + + UnlockDummyEndpoint(2); +} + +TEST_F(WebSocketEndpointLockManagerTest, GotEndpointLockCalledWhenUnlocked) { + FakeWaiter waiters[2]; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0])); + EXPECT_EQ(ERR_IO_PENDING, + instance()->LockEndpoint(DummyEndpoint(), &waiters[1])); + instance()->UnlockEndpoint(DummyEndpoint()); + EXPECT_TRUE(waiters[1].called()); + + UnlockDummyEndpoint(1); +} + +TEST_F(WebSocketEndpointLockManagerTest, + EndpointUnlockedIfWaiterAlreadyDeleted) { + FakeWaiter first_lock_holder; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &first_lock_holder)); + + { + FakeWaiter short_lived_waiter; + EXPECT_EQ(ERR_IO_PENDING, + instance()->LockEndpoint(DummyEndpoint(), &short_lived_waiter)); + } + + instance()->UnlockEndpoint(DummyEndpoint()); + + FakeWaiter second_lock_holder; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &second_lock_holder)); + + UnlockDummyEndpoint(1); +} + +TEST_F(WebSocketEndpointLockManagerTest, RememberSocketWorks) { + FakeWaiter waiters[2]; + FakeStreamSocket dummy_socket; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0])); + EXPECT_EQ(ERR_IO_PENDING, + instance()->LockEndpoint(DummyEndpoint(), &waiters[1])); + + instance()->RememberSocket(&dummy_socket, DummyEndpoint()); + instance()->UnlockSocket(&dummy_socket); + EXPECT_TRUE(waiters[1].called()); + + UnlockDummyEndpoint(1); +} + +// UnlockEndpoint() should cause any sockets remembered for this endpoint +// to be forgotten. +TEST_F(WebSocketEndpointLockManagerTest, SocketAssociationForgottenOnUnlock) { + FakeWaiter waiter; + FakeStreamSocket dummy_socket; + + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiter)); + instance()->RememberSocket(&dummy_socket, DummyEndpoint()); + instance()->UnlockEndpoint(DummyEndpoint()); + EXPECT_TRUE(instance()->IsEmpty()); +} + +// When ownership of the endpoint is passed to a new waiter, the new waiter can +// call RememberSocket() again. +TEST_F(WebSocketEndpointLockManagerTest, NextWaiterCanCallRememberSocketAgain) { + FakeWaiter waiters[2]; + FakeStreamSocket dummy_sockets[2]; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0])); + EXPECT_EQ(ERR_IO_PENDING, + instance()->LockEndpoint(DummyEndpoint(), &waiters[1])); + + instance()->RememberSocket(&dummy_sockets[0], DummyEndpoint()); + instance()->UnlockEndpoint(DummyEndpoint()); + EXPECT_TRUE(waiters[1].called()); + instance()->RememberSocket(&dummy_sockets[1], DummyEndpoint()); + + UnlockDummyEndpoint(1); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/socket/websocket_transport_client_socket_pool.cc b/chromium/net/socket/websocket_transport_client_socket_pool.cc new file mode 100644 index 00000000000..15ec028cb18 --- /dev/null +++ b/chromium/net/socket/websocket_transport_client_socket_pool.cc @@ -0,0 +1,651 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/websocket_transport_client_socket_pool.h" + +#include <algorithm> + +#include "base/compiler_specific.h" +#include "base/logging.h" +#include "base/numerics/safe_conversions.h" +#include "base/strings/string_util.h" +#include "base/time/time.h" +#include "base/values.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_base.h" +#include "net/socket/websocket_endpoint_lock_manager.h" +#include "net/socket/websocket_transport_connect_sub_job.h" + +namespace net { + +namespace { + +using base::TimeDelta; + +// TODO(ricea): For now, we implement a global timeout for compatability with +// TransportConnectJob. Since WebSocketTransportConnectJob controls the address +// selection process more tightly, it could do something smarter here. +const int kTransportConnectJobTimeoutInSeconds = 240; // 4 minutes. + +} // namespace + +WebSocketTransportConnectJob::WebSocketTransportConnectJob( + const std::string& group_name, + RequestPriority priority, + const scoped_refptr<TransportSocketParams>& params, + TimeDelta timeout_duration, + const CompletionCallback& callback, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + ClientSocketHandle* handle, + Delegate* delegate, + NetLog* pool_net_log, + const BoundNetLog& request_net_log) + : ConnectJob(group_name, + timeout_duration, + priority, + delegate, + BoundNetLog::Make(pool_net_log, NetLog::SOURCE_CONNECT_JOB)), + helper_(params, client_socket_factory, host_resolver, &connect_timing_), + race_result_(TransportConnectJobHelper::CONNECTION_LATENCY_UNKNOWN), + handle_(handle), + callback_(callback), + request_net_log_(request_net_log), + had_ipv4_(false), + had_ipv6_(false) { + helper_.SetOnIOComplete(this); +} + +WebSocketTransportConnectJob::~WebSocketTransportConnectJob() {} + +LoadState WebSocketTransportConnectJob::GetLoadState() const { + LoadState load_state = LOAD_STATE_RESOLVING_HOST; + if (ipv6_job_) + load_state = ipv6_job_->GetLoadState(); + // This method should return LOAD_STATE_CONNECTING in preference to + // LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET when possible because "waiting for + // available socket" implies that nothing is happening. + if (ipv4_job_ && load_state != LOAD_STATE_CONNECTING) + load_state = ipv4_job_->GetLoadState(); + return load_state; +} + +int WebSocketTransportConnectJob::DoResolveHost() { + return helper_.DoResolveHost(priority(), net_log()); +} + +int WebSocketTransportConnectJob::DoResolveHostComplete(int result) { + return helper_.DoResolveHostComplete(result, net_log()); +} + +int WebSocketTransportConnectJob::DoTransportConnect() { + AddressList ipv4_addresses; + AddressList ipv6_addresses; + int result = ERR_UNEXPECTED; + helper_.set_next_state( + TransportConnectJobHelper::STATE_TRANSPORT_CONNECT_COMPLETE); + + for (AddressList::const_iterator it = helper_.addresses().begin(); + it != helper_.addresses().end(); + ++it) { + switch (it->GetFamily()) { + case ADDRESS_FAMILY_IPV4: + ipv4_addresses.push_back(*it); + break; + + case ADDRESS_FAMILY_IPV6: + ipv6_addresses.push_back(*it); + break; + + default: + DVLOG(1) << "Unexpected ADDRESS_FAMILY: " << it->GetFamily(); + break; + } + } + + if (!ipv4_addresses.empty()) { + had_ipv4_ = true; + ipv4_job_.reset(new WebSocketTransportConnectSubJob( + ipv4_addresses, this, SUB_JOB_IPV4)); + } + + if (!ipv6_addresses.empty()) { + had_ipv6_ = true; + ipv6_job_.reset(new WebSocketTransportConnectSubJob( + ipv6_addresses, this, SUB_JOB_IPV6)); + result = ipv6_job_->Start(); + switch (result) { + case OK: + SetSocket(ipv6_job_->PassSocket()); + race_result_ = + had_ipv4_ + ? TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_RACEABLE + : TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_SOLO; + return result; + + case ERR_IO_PENDING: + if (ipv4_job_) { + // This use of base::Unretained is safe because |fallback_timer_| is + // owned by this object. + fallback_timer_.Start( + FROM_HERE, + TimeDelta::FromMilliseconds( + TransportConnectJobHelper::kIPv6FallbackTimerInMs), + base::Bind(&WebSocketTransportConnectJob::StartIPv4JobAsync, + base::Unretained(this))); + } + return result; + + default: + ipv6_job_.reset(); + } + } + + DCHECK(!ipv6_job_); + if (ipv4_job_) { + result = ipv4_job_->Start(); + if (result == OK) { + SetSocket(ipv4_job_->PassSocket()); + race_result_ = + had_ipv6_ + ? TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_WINS_RACE + : TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_NO_RACE; + } + } + + return result; +} + +int WebSocketTransportConnectJob::DoTransportConnectComplete(int result) { + if (result == OK) + helper_.HistogramDuration(race_result_); + return result; +} + +void WebSocketTransportConnectJob::OnSubJobComplete( + int result, + WebSocketTransportConnectSubJob* job) { + if (result == OK) { + switch (job->type()) { + case SUB_JOB_IPV4: + race_result_ = + had_ipv6_ + ? TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_WINS_RACE + : TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_NO_RACE; + break; + + case SUB_JOB_IPV6: + race_result_ = + had_ipv4_ + ? TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_RACEABLE + : TransportConnectJobHelper::CONNECTION_LATENCY_IPV6_SOLO; + break; + } + SetSocket(job->PassSocket()); + + // Make sure all connections are cancelled even if this object fails to be + // deleted. + ipv4_job_.reset(); + ipv6_job_.reset(); + } else { + switch (job->type()) { + case SUB_JOB_IPV4: + ipv4_job_.reset(); + break; + + case SUB_JOB_IPV6: + ipv6_job_.reset(); + if (ipv4_job_ && !ipv4_job_->started()) { + fallback_timer_.Stop(); + result = ipv4_job_->Start(); + if (result != ERR_IO_PENDING) { + OnSubJobComplete(result, ipv4_job_.get()); + return; + } + } + break; + } + if (ipv4_job_ || ipv6_job_) + return; + } + helper_.OnIOComplete(this, result); +} + +void WebSocketTransportConnectJob::StartIPv4JobAsync() { + DCHECK(ipv4_job_); + int result = ipv4_job_->Start(); + if (result != ERR_IO_PENDING) + OnSubJobComplete(result, ipv4_job_.get()); +} + +int WebSocketTransportConnectJob::ConnectInternal() { + return helper_.DoConnectInternal(this); +} + +WebSocketTransportClientSocketPool::WebSocketTransportClientSocketPool( + int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + HostResolver* host_resolver, + ClientSocketFactory* client_socket_factory, + NetLog* net_log) + : TransportClientSocketPool(max_sockets, + max_sockets_per_group, + histograms, + host_resolver, + client_socket_factory, + net_log), + connect_job_delegate_(this), + histograms_(histograms), + pool_net_log_(net_log), + client_socket_factory_(client_socket_factory), + host_resolver_(host_resolver), + max_sockets_(max_sockets), + handed_out_socket_count_(0), + flushing_(false), + weak_factory_(this) {} + +WebSocketTransportClientSocketPool::~WebSocketTransportClientSocketPool() { + // Clean up any pending connect jobs. + FlushWithError(ERR_ABORTED); + DCHECK(pending_connects_.empty()); + DCHECK_EQ(0, handed_out_socket_count_); + DCHECK(stalled_request_queue_.empty()); + DCHECK(stalled_request_map_.empty()); +} + +// static +void WebSocketTransportClientSocketPool::UnlockEndpoint( + ClientSocketHandle* handle) { + DCHECK(handle->is_initialized()); + DCHECK(handle->socket()); + IPEndPoint address; + if (handle->socket()->GetPeerAddress(&address) == OK) + WebSocketEndpointLockManager::GetInstance()->UnlockEndpoint(address); +} + +int WebSocketTransportClientSocketPool::RequestSocket( + const std::string& group_name, + const void* params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& request_net_log) { + DCHECK(params); + const scoped_refptr<TransportSocketParams>& casted_params = + *static_cast<const scoped_refptr<TransportSocketParams>*>(params); + + NetLogTcpClientSocketPoolRequestedSocket(request_net_log, &casted_params); + + CHECK(!callback.is_null()); + CHECK(handle); + + request_net_log.BeginEvent(NetLog::TYPE_SOCKET_POOL); + + if (ReachedMaxSocketsLimit() && !casted_params->ignore_limits()) { + request_net_log.AddEvent(NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS); + // TODO(ricea): Use emplace_back when C++11 becomes allowed. + StalledRequest request( + casted_params, priority, handle, callback, request_net_log); + stalled_request_queue_.push_back(request); + StalledRequestQueue::iterator iterator = stalled_request_queue_.end(); + --iterator; + DCHECK_EQ(handle, iterator->handle); + // Because StalledRequestQueue is a std::list, its iterators are guaranteed + // to remain valid as long as the elements are not removed. As long as + // stalled_request_queue_ and stalled_request_map_ are updated in sync, it + // is safe to dereference an iterator in stalled_request_map_ to find the + // corresponding list element. + stalled_request_map_.insert( + StalledRequestMap::value_type(handle, iterator)); + return ERR_IO_PENDING; + } + + scoped_ptr<WebSocketTransportConnectJob> connect_job( + new WebSocketTransportConnectJob(group_name, + priority, + casted_params, + ConnectionTimeout(), + callback, + client_socket_factory_, + host_resolver_, + handle, + &connect_job_delegate_, + pool_net_log_, + request_net_log)); + + int rv = connect_job->Connect(); + // Regardless of the outcome of |connect_job|, it will always be bound to + // |handle|, since this pool uses early-binding. So the binding is logged + // here, without waiting for the result. + request_net_log.AddEvent( + NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB, + connect_job->net_log().source().ToEventParametersCallback()); + if (rv == OK) { + HandOutSocket(connect_job->PassSocket(), + connect_job->connect_timing(), + handle, + request_net_log); + request_net_log.EndEvent(NetLog::TYPE_SOCKET_POOL); + } else if (rv == ERR_IO_PENDING) { + // TODO(ricea): Implement backup job timer? + AddJob(handle, connect_job.Pass()); + } else { + scoped_ptr<StreamSocket> error_socket; + connect_job->GetAdditionalErrorState(handle); + error_socket = connect_job->PassSocket(); + if (error_socket) { + HandOutSocket(error_socket.Pass(), + connect_job->connect_timing(), + handle, + request_net_log); + } + } + + if (rv != ERR_IO_PENDING) { + request_net_log.EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, rv); + } + + return rv; +} + +void WebSocketTransportClientSocketPool::RequestSockets( + const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) { + NOTIMPLEMENTED(); +} + +void WebSocketTransportClientSocketPool::CancelRequest( + const std::string& group_name, + ClientSocketHandle* handle) { + DCHECK(!handle->is_initialized()); + if (DeleteStalledRequest(handle)) + return; + scoped_ptr<StreamSocket> socket = handle->PassSocket(); + if (socket) + ReleaseSocket(handle->group_name(), socket.Pass(), handle->id()); + if (!DeleteJob(handle)) + pending_callbacks_.erase(handle); + if (!ReachedMaxSocketsLimit() && !stalled_request_queue_.empty()) + ActivateStalledRequest(); +} + +void WebSocketTransportClientSocketPool::ReleaseSocket( + const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) { + WebSocketEndpointLockManager::GetInstance()->UnlockSocket(socket.get()); + CHECK_GT(handed_out_socket_count_, 0); + --handed_out_socket_count_; + if (!ReachedMaxSocketsLimit() && !stalled_request_queue_.empty()) + ActivateStalledRequest(); +} + +void WebSocketTransportClientSocketPool::FlushWithError(int error) { + // Sockets which are in LOAD_STATE_CONNECTING are in danger of unlocking + // sockets waiting for the endpoint lock. If they connected synchronously, + // then OnConnectJobComplete(). The |flushing_| flag tells this object to + // ignore spurious calls to OnConnectJobComplete(). It is safe to ignore those + // calls because this method will delete the jobs and call their callbacks + // anyway. + flushing_ = true; + for (PendingConnectsMap::iterator it = pending_connects_.begin(); + it != pending_connects_.end(); + ++it) { + InvokeUserCallbackLater( + it->second->handle(), it->second->callback(), error); + delete it->second, it->second = NULL; + } + pending_connects_.clear(); + for (StalledRequestQueue::iterator it = stalled_request_queue_.begin(); + it != stalled_request_queue_.end(); + ++it) { + InvokeUserCallbackLater(it->handle, it->callback, error); + } + stalled_request_map_.clear(); + stalled_request_queue_.clear(); + flushing_ = false; +} + +void WebSocketTransportClientSocketPool::CloseIdleSockets() { + // We have no idle sockets. +} + +int WebSocketTransportClientSocketPool::IdleSocketCount() const { + return 0; +} + +int WebSocketTransportClientSocketPool::IdleSocketCountInGroup( + const std::string& group_name) const { + return 0; +} + +LoadState WebSocketTransportClientSocketPool::GetLoadState( + const std::string& group_name, + const ClientSocketHandle* handle) const { + if (stalled_request_map_.find(handle) != stalled_request_map_.end()) + return LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET; + if (pending_callbacks_.count(handle)) + return LOAD_STATE_CONNECTING; + return LookupConnectJob(handle)->GetLoadState(); +} + +base::DictionaryValue* WebSocketTransportClientSocketPool::GetInfoAsValue( + const std::string& name, + const std::string& type, + bool include_nested_pools) const { + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetString("name", name); + dict->SetString("type", type); + dict->SetInteger("handed_out_socket_count", handed_out_socket_count_); + dict->SetInteger("connecting_socket_count", pending_connects_.size()); + dict->SetInteger("idle_socket_count", 0); + dict->SetInteger("max_socket_count", max_sockets_); + dict->SetInteger("max_sockets_per_group", max_sockets_); + dict->SetInteger("pool_generation_number", 0); + return dict; +} + +TimeDelta WebSocketTransportClientSocketPool::ConnectionTimeout() const { + return TimeDelta::FromSeconds(kTransportConnectJobTimeoutInSeconds); +} + +ClientSocketPoolHistograms* WebSocketTransportClientSocketPool::histograms() + const { + return histograms_; +} + +bool WebSocketTransportClientSocketPool::IsStalled() const { + return !stalled_request_queue_.empty(); +} + +void WebSocketTransportClientSocketPool::OnConnectJobComplete( + int result, + WebSocketTransportConnectJob* job) { + DCHECK_NE(ERR_IO_PENDING, result); + + scoped_ptr<StreamSocket> socket = job->PassSocket(); + + // See comment in FlushWithError. + if (flushing_) { + WebSocketEndpointLockManager::GetInstance()->UnlockSocket(socket.get()); + return; + } + + BoundNetLog request_net_log = job->request_net_log(); + CompletionCallback callback = job->callback(); + LoadTimingInfo::ConnectTiming connect_timing = job->connect_timing(); + + ClientSocketHandle* const handle = job->handle(); + bool handed_out_socket = false; + + if (result == OK) { + DCHECK(socket.get()); + handed_out_socket = true; + HandOutSocket(socket.Pass(), connect_timing, handle, request_net_log); + request_net_log.EndEvent(NetLog::TYPE_SOCKET_POOL); + } else { + // If we got a socket, it must contain error information so pass that + // up so that the caller can retrieve it. + job->GetAdditionalErrorState(handle); + if (socket.get()) { + handed_out_socket = true; + HandOutSocket(socket.Pass(), connect_timing, handle, request_net_log); + } + request_net_log.EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, result); + } + bool delete_succeeded = DeleteJob(handle); + DCHECK(delete_succeeded); + if (!handed_out_socket && !stalled_request_queue_.empty() && + !ReachedMaxSocketsLimit()) + ActivateStalledRequest(); + InvokeUserCallbackLater(handle, callback, result); +} + +void WebSocketTransportClientSocketPool::InvokeUserCallbackLater( + ClientSocketHandle* handle, + const CompletionCallback& callback, + int rv) { + DCHECK(!pending_callbacks_.count(handle)); + pending_callbacks_.insert(handle); + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&WebSocketTransportClientSocketPool::InvokeUserCallback, + weak_factory_.GetWeakPtr(), + handle, + callback, + rv)); +} + +void WebSocketTransportClientSocketPool::InvokeUserCallback( + ClientSocketHandle* handle, + const CompletionCallback& callback, + int rv) { + if (pending_callbacks_.erase(handle)) + callback.Run(rv); +} + +bool WebSocketTransportClientSocketPool::ReachedMaxSocketsLimit() const { + return handed_out_socket_count_ >= max_sockets_ || + base::checked_cast<int>(pending_connects_.size()) >= + max_sockets_ - handed_out_socket_count_; +} + +void WebSocketTransportClientSocketPool::HandOutSocket( + scoped_ptr<StreamSocket> socket, + const LoadTimingInfo::ConnectTiming& connect_timing, + ClientSocketHandle* handle, + const BoundNetLog& net_log) { + DCHECK(socket); + handle->SetSocket(socket.Pass()); + DCHECK_EQ(ClientSocketHandle::UNUSED, handle->reuse_type()); + DCHECK_EQ(0, handle->idle_time().InMicroseconds()); + handle->set_pool_id(0); + handle->set_connect_timing(connect_timing); + + net_log.AddEvent( + NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET, + handle->socket()->NetLog().source().ToEventParametersCallback()); + + ++handed_out_socket_count_; +} + +void WebSocketTransportClientSocketPool::AddJob( + ClientSocketHandle* handle, + scoped_ptr<WebSocketTransportConnectJob> connect_job) { + bool inserted = + pending_connects_.insert(PendingConnectsMap::value_type( + handle, connect_job.release())).second; + DCHECK(inserted); +} + +bool WebSocketTransportClientSocketPool::DeleteJob(ClientSocketHandle* handle) { + PendingConnectsMap::iterator it = pending_connects_.find(handle); + if (it == pending_connects_.end()) + return false; + // Deleting a ConnectJob which holds an endpoint lock can lead to a different + // ConnectJob proceeding to connect. If the connect proceeds synchronously + // (usually because of a failure) then it can trigger that job to be + // deleted. |it| remains valid because std::map guarantees that erase() does + // not invalid iterators to other entries. + delete it->second, it->second = NULL; + DCHECK(pending_connects_.find(handle) == it); + pending_connects_.erase(it); + return true; +} + +const WebSocketTransportConnectJob* +WebSocketTransportClientSocketPool::LookupConnectJob( + const ClientSocketHandle* handle) const { + PendingConnectsMap::const_iterator it = pending_connects_.find(handle); + CHECK(it != pending_connects_.end()); + return it->second; +} + +void WebSocketTransportClientSocketPool::ActivateStalledRequest() { + DCHECK(!stalled_request_queue_.empty()); + DCHECK(!ReachedMaxSocketsLimit()); + // Usually we will only be able to activate one stalled request at a time, + // however if all the connects fail synchronously for some reason, we may be + // able to clear the whole queue at once. + while (!stalled_request_queue_.empty() && !ReachedMaxSocketsLimit()) { + StalledRequest request(stalled_request_queue_.front()); + stalled_request_queue_.pop_front(); + stalled_request_map_.erase(request.handle); + int rv = RequestSocket("ignored", + &request.params, + request.priority, + request.handle, + request.callback, + request.net_log); + // ActivateStalledRequest() never returns synchronously, so it is never + // called re-entrantly. + if (rv != ERR_IO_PENDING) + InvokeUserCallbackLater(request.handle, request.callback, rv); + } +} + +bool WebSocketTransportClientSocketPool::DeleteStalledRequest( + ClientSocketHandle* handle) { + StalledRequestMap::iterator it = stalled_request_map_.find(handle); + if (it == stalled_request_map_.end()) + return false; + stalled_request_queue_.erase(it->second); + stalled_request_map_.erase(it); + return true; +} + +WebSocketTransportClientSocketPool::ConnectJobDelegate::ConnectJobDelegate( + WebSocketTransportClientSocketPool* owner) + : owner_(owner) {} + +WebSocketTransportClientSocketPool::ConnectJobDelegate::~ConnectJobDelegate() {} + +void +WebSocketTransportClientSocketPool::ConnectJobDelegate::OnConnectJobComplete( + int result, + ConnectJob* job) { + owner_->OnConnectJobComplete(result, + static_cast<WebSocketTransportConnectJob*>(job)); +} + +WebSocketTransportClientSocketPool::StalledRequest::StalledRequest( + const scoped_refptr<TransportSocketParams>& params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) + : params(params), + priority(priority), + handle(handle), + callback(callback), + net_log(net_log) {} + +WebSocketTransportClientSocketPool::StalledRequest::~StalledRequest() {} + +} // namespace net diff --git a/chromium/net/socket/websocket_transport_client_socket_pool.h b/chromium/net/socket/websocket_transport_client_socket_pool.h new file mode 100644 index 00000000000..f0a94be417f --- /dev/null +++ b/chromium/net/socket/websocket_transport_client_socket_pool.h @@ -0,0 +1,246 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ +#define NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ + +#include <list> +#include <map> +#include <set> +#include <string> + +#include "base/basictypes.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/weak_ptr.h" +#include "base/time/time.h" +#include "base/timer/timer.h" +#include "net/base/net_export.h" +#include "net/base/net_log.h" +#include "net/socket/client_socket_pool.h" +#include "net/socket/client_socket_pool_base.h" +#include "net/socket/transport_client_socket_pool.h" + +namespace net { + +class ClientSocketFactory; +class ClientSocketPoolHistograms; +class HostResolver; +class NetLog; +class WebSocketEndpointLockManager; +class WebSocketTransportConnectSubJob; + +// WebSocketTransportConnectJob handles the host resolution necessary for socket +// creation and the TCP connect. WebSocketTransportConnectJob also has fallback +// logic for IPv6 connect() timeouts (which may happen due to networks / routers +// with broken IPv6 support). Those timeouts take 20s, so rather than make the +// user wait 20s for the timeout to fire, we use a fallback timer +// (kIPv6FallbackTimerInMs) and start a connect() to an IPv4 address if the +// timer fires. Then we race the IPv4 connect(s) against the IPv6 connect(s) and +// use the socket that completes successfully first or fails last. +class NET_EXPORT_PRIVATE WebSocketTransportConnectJob : public ConnectJob { + public: + WebSocketTransportConnectJob( + const std::string& group_name, + RequestPriority priority, + const scoped_refptr<TransportSocketParams>& params, + base::TimeDelta timeout_duration, + const CompletionCallback& callback, + ClientSocketFactory* client_socket_factory, + HostResolver* host_resolver, + ClientSocketHandle* handle, + Delegate* delegate, + NetLog* pool_net_log, + const BoundNetLog& request_net_log); + ~WebSocketTransportConnectJob() override; + + // Unlike normal socket pools, the WebSocketTransportClientPool uses + // early-binding of sockets. + ClientSocketHandle* handle() const { return handle_; } + + // Stash the callback from RequestSocket() here for convenience. + const CompletionCallback& callback() const { return callback_; } + + const BoundNetLog& request_net_log() const { return request_net_log_; } + + // ConnectJob methods. + LoadState GetLoadState() const override; + + private: + friend class WebSocketTransportConnectSubJob; + friend class TransportConnectJobHelper; + friend class WebSocketEndpointLockManager; + + // Although it is not strictly necessary, it makes the code simpler if each + // subjob knows what type it is. + enum SubJobType { SUB_JOB_IPV4, SUB_JOB_IPV6 }; + + int DoResolveHost(); + int DoResolveHostComplete(int result); + int DoTransportConnect(); + int DoTransportConnectComplete(int result); + + // Called back from a SubJob when it completes. + void OnSubJobComplete(int result, WebSocketTransportConnectSubJob* job); + + // Called from |fallback_timer_|. + void StartIPv4JobAsync(); + + // Begins the host resolution and the TCP connect. Returns OK on success + // and ERR_IO_PENDING if it cannot immediately service the request. + // Otherwise, it returns a net error code. + int ConnectInternal() override; + + TransportConnectJobHelper helper_; + + // The addresses are divided into IPv4 and IPv6, which are performed partially + // in parallel. If the list of IPv6 addresses is non-empty, then the IPv6 jobs + // go first, followed after |kIPv6FallbackTimerInMs| by the IPv4 + // addresses. First sub-job to establish a connection wins. + scoped_ptr<WebSocketTransportConnectSubJob> ipv4_job_; + scoped_ptr<WebSocketTransportConnectSubJob> ipv6_job_; + + base::OneShotTimer<WebSocketTransportConnectJob> fallback_timer_; + TransportConnectJobHelper::ConnectionLatencyHistogram race_result_; + ClientSocketHandle* const handle_; + CompletionCallback callback_; + BoundNetLog request_net_log_; + + bool had_ipv4_; + bool had_ipv6_; + + DISALLOW_COPY_AND_ASSIGN(WebSocketTransportConnectJob); +}; + +class NET_EXPORT_PRIVATE WebSocketTransportClientSocketPool + : public TransportClientSocketPool { + public: + WebSocketTransportClientSocketPool(int max_sockets, + int max_sockets_per_group, + ClientSocketPoolHistograms* histograms, + HostResolver* host_resolver, + ClientSocketFactory* client_socket_factory, + NetLog* net_log); + + ~WebSocketTransportClientSocketPool() override; + + // Allow another connection to be started to the IPEndPoint that this |handle| + // is connected to. Used when the WebSocket handshake completes successfully. + // This only works if the socket is connected, however the caller does not + // need to explicitly check for this. Instead, ensure that dead sockets are + // returned to ReleaseSocket() in a timely fashion. + static void UnlockEndpoint(ClientSocketHandle* handle); + + // ClientSocketPool implementation. + int RequestSocket(const std::string& group_name, + const void* resolve_info, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) override; + void RequestSockets(const std::string& group_name, + const void* params, + int num_sockets, + const BoundNetLog& net_log) override; + void CancelRequest(const std::string& group_name, + ClientSocketHandle* handle) override; + void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) override; + void FlushWithError(int error) override; + void CloseIdleSockets() override; + int IdleSocketCount() const override; + int IdleSocketCountInGroup(const std::string& group_name) const override; + LoadState GetLoadState(const std::string& group_name, + const ClientSocketHandle* handle) const override; + base::DictionaryValue* GetInfoAsValue( + const std::string& name, + const std::string& type, + bool include_nested_pools) const override; + base::TimeDelta ConnectionTimeout() const override; + ClientSocketPoolHistograms* histograms() const override; + + // HigherLayeredPool implementation. + bool IsStalled() const override; + + private: + class ConnectJobDelegate : public ConnectJob::Delegate { + public: + explicit ConnectJobDelegate(WebSocketTransportClientSocketPool* owner); + ~ConnectJobDelegate() override; + + void OnConnectJobComplete(int result, ConnectJob* job) override; + + private: + WebSocketTransportClientSocketPool* owner_; + + DISALLOW_COPY_AND_ASSIGN(ConnectJobDelegate); + }; + + // Store the arguments from a call to RequestSocket() that has stalled so we + // can replay it when there are available socket slots. + struct StalledRequest { + StalledRequest(const scoped_refptr<TransportSocketParams>& params, + RequestPriority priority, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log); + ~StalledRequest(); + const scoped_refptr<TransportSocketParams> params; + const RequestPriority priority; + ClientSocketHandle* const handle; + const CompletionCallback callback; + const BoundNetLog net_log; + }; + friend class ConnectJobDelegate; + typedef std::map<const ClientSocketHandle*, WebSocketTransportConnectJob*> + PendingConnectsMap; + // This is a list so that we can remove requests from the middle, and also + // so that iterators are not invalidated unless the corresponding request is + // removed. + typedef std::list<StalledRequest> StalledRequestQueue; + typedef std::map<const ClientSocketHandle*, StalledRequestQueue::iterator> + StalledRequestMap; + + void OnConnectJobComplete(int result, WebSocketTransportConnectJob* job); + void InvokeUserCallbackLater(ClientSocketHandle* handle, + const CompletionCallback& callback, + int rv); + void InvokeUserCallback(ClientSocketHandle* handle, + const CompletionCallback& callback, + int rv); + bool ReachedMaxSocketsLimit() const; + void HandOutSocket(scoped_ptr<StreamSocket> socket, + const LoadTimingInfo::ConnectTiming& connect_timing, + ClientSocketHandle* handle, + const BoundNetLog& net_log); + void AddJob(ClientSocketHandle* handle, + scoped_ptr<WebSocketTransportConnectJob> connect_job); + bool DeleteJob(ClientSocketHandle* handle); + const WebSocketTransportConnectJob* LookupConnectJob( + const ClientSocketHandle* handle) const; + void ActivateStalledRequest(); + bool DeleteStalledRequest(ClientSocketHandle* handle); + + ConnectJobDelegate connect_job_delegate_; + std::set<const ClientSocketHandle*> pending_callbacks_; + PendingConnectsMap pending_connects_; + StalledRequestQueue stalled_request_queue_; + StalledRequestMap stalled_request_map_; + ClientSocketPoolHistograms* const histograms_; + NetLog* const pool_net_log_; + ClientSocketFactory* const client_socket_factory_; + HostResolver* const host_resolver_; + const int max_sockets_; + int handed_out_socket_count_; + bool flushing_; + + base::WeakPtrFactory<WebSocketTransportClientSocketPool> weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(WebSocketTransportClientSocketPool); +}; + +} // namespace net + +#endif // NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ diff --git a/chromium/net/socket/websocket_transport_client_socket_pool_unittest.cc b/chromium/net/socket/websocket_transport_client_socket_pool_unittest.cc new file mode 100644 index 00000000000..2189181b9fc --- /dev/null +++ b/chromium/net/socket/websocket_transport_client_socket_pool_unittest.cc @@ -0,0 +1,1143 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/websocket_transport_client_socket_pool.h" + +#include <queue> +#include <vector> + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/callback.h" +#include "base/macros.h" +#include "base/message_loop/message_loop.h" +#include "base/run_loop.h" +#include "base/strings/stringprintf.h" +#include "base/time/time.h" +#include "net/base/capturing_net_log.h" +#include "net/base/ip_endpoint.h" +#include "net/base/load_timing_info.h" +#include "net/base/load_timing_info_test_util.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" +#include "net/base/test_completion_callback.h" +#include "net/dns/mock_host_resolver.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_histograms.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/stream_socket.h" +#include "net/socket/transport_client_socket_pool_test_util.h" +#include "net/socket/websocket_endpoint_lock_manager.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +const int kMaxSockets = 32; +const int kMaxSocketsPerGroup = 6; +const RequestPriority kDefaultPriority = LOW; + +// RunLoop doesn't support this natively but it is easy to emulate. +void RunLoopForTimePeriod(base::TimeDelta period) { + base::RunLoop run_loop; + base::Closure quit_closure(run_loop.QuitClosure()); + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, quit_closure, period); + run_loop.Run(); +} + +class WebSocketTransportClientSocketPoolTest : public testing::Test { + protected: + WebSocketTransportClientSocketPoolTest() + : params_(new TransportSocketParams( + HostPortPair("www.google.com", 80), + false, + false, + OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), + histograms_(new ClientSocketPoolHistograms("TCPUnitTest")), + host_resolver_(new MockHostResolver), + client_socket_factory_(&net_log_), + pool_(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL) {} + + ~WebSocketTransportClientSocketPoolTest() override { + ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE); + EXPECT_TRUE(WebSocketEndpointLockManager::GetInstance()->IsEmpty()); + } + + int StartRequest(const std::string& group_name, RequestPriority priority) { + scoped_refptr<TransportSocketParams> params( + new TransportSocketParams( + HostPortPair("www.google.com", 80), + false, + false, + OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); + return test_base_.StartRequestUsingPool( + &pool_, group_name, priority, params); + } + + int GetOrderOfRequest(size_t index) { + return test_base_.GetOrderOfRequest(index); + } + + bool ReleaseOneConnection(ClientSocketPoolTest::KeepAlive keep_alive) { + return test_base_.ReleaseOneConnection(keep_alive); + } + + void ReleaseAllConnections(ClientSocketPoolTest::KeepAlive keep_alive) { + test_base_.ReleaseAllConnections(keep_alive); + } + + TestSocketRequest* request(int i) { return test_base_.request(i); } + + ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); } + size_t completion_count() const { return test_base_.completion_count(); } + + CapturingNetLog net_log_; + scoped_refptr<TransportSocketParams> params_; + scoped_ptr<ClientSocketPoolHistograms> histograms_; + scoped_ptr<MockHostResolver> host_resolver_; + MockTransportClientSocketFactory client_socket_factory_; + WebSocketTransportClientSocketPool pool_; + ClientSocketPoolTest test_base_; + + private: + DISALLOW_COPY_AND_ASSIGN(WebSocketTransportClientSocketPoolTest); +}; + +TEST_F(WebSocketTransportClientSocketPoolTest, Basic) { + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = handle.Init( + "a", params_, LOW, callback.callback(), &pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + TestLoadTimingInfoConnectedNotReused(handle); +} + +// Make sure that WebSocketTransportConnectJob passes on its priority to its +// HostResolver request on Init. +TEST_F(WebSocketTransportClientSocketPoolTest, SetResolvePriorityOnInit) { + for (int i = MINIMUM_PRIORITY; i <= MAXIMUM_PRIORITY; ++i) { + RequestPriority priority = static_cast<RequestPriority>(i); + TestCompletionCallback callback; + ClientSocketHandle handle; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + priority, + callback.callback(), + &pool_, + BoundNetLog())); + EXPECT_EQ(priority, host_resolver_->last_request_priority()); + } +} + +TEST_F(WebSocketTransportClientSocketPoolTest, InitHostResolutionFailure) { + host_resolver_->rules()->AddSimulatedFailure("unresolvable.host.name"); + TestCompletionCallback callback; + ClientSocketHandle handle; + HostPortPair host_port_pair("unresolvable.host.name", 80); + scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( + host_port_pair, false, false, OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + dest, + kDefaultPriority, + callback.callback(), + &pool_, + BoundNetLog())); + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, callback.WaitForResult()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, InitConnectionFailure) { + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET); + TestCompletionCallback callback; + ClientSocketHandle handle; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + &pool_, + BoundNetLog())); + EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); + + // Make the host resolutions complete synchronously this time. + host_resolver_->set_synchronous_mode(true); + EXPECT_EQ(ERR_CONNECTION_FAILED, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + &pool_, + BoundNetLog())); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, PendingRequestsFinishFifo) { + // First request finishes asynchronously. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, request(0)->WaitForResult()); + + // Make all subsequent host resolutions complete synchronously. + host_resolver_->set_synchronous_mode(true); + + // Rest of them wait for the first socket to be released. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE); + + EXPECT_EQ(6, client_socket_factory_.allocation_count()); + + // One initial asynchronous request and then 5 pending requests. + EXPECT_EQ(6U, completion_count()); + + // The requests finish in FIFO order. + EXPECT_EQ(1, GetOrderOfRequest(1)); + EXPECT_EQ(2, GetOrderOfRequest(2)); + EXPECT_EQ(3, GetOrderOfRequest(3)); + EXPECT_EQ(4, GetOrderOfRequest(4)); + EXPECT_EQ(5, GetOrderOfRequest(5)); + EXPECT_EQ(6, GetOrderOfRequest(6)); + + // Make sure we test order of all requests made. + EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(7)); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, PendingRequests_NoKeepAlive) { + // First request finishes asynchronously. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, request(0)->WaitForResult()); + + // Make all subsequent host resolutions complete synchronously. + host_resolver_->set_synchronous_mode(true); + + // Rest of them wait for the first socket to be released. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE); + + // The pending requests should finish successfully. + EXPECT_EQ(OK, request(1)->WaitForResult()); + EXPECT_EQ(OK, request(2)->WaitForResult()); + EXPECT_EQ(OK, request(3)->WaitForResult()); + EXPECT_EQ(OK, request(4)->WaitForResult()); + EXPECT_EQ(OK, request(5)->WaitForResult()); + + EXPECT_EQ(static_cast<int>(requests()->size()), + client_socket_factory_.allocation_count()); + + // First asynchronous request, and then last 5 pending requests. + EXPECT_EQ(6U, completion_count()); +} + +// This test will start up a RequestSocket() and then immediately Cancel() it. +// The pending host resolution will eventually complete, and destroy the +// ClientSocketPool which will crash if the group was not cleared properly. +TEST_F(WebSocketTransportClientSocketPoolTest, CancelRequestClearGroup) { + TestCompletionCallback callback; + ClientSocketHandle handle; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + &pool_, + BoundNetLog())); + handle.Reset(); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, TwoRequestsCancelOne) { + ClientSocketHandle handle; + TestCompletionCallback callback; + ClientSocketHandle handle2; + TestCompletionCallback callback2; + + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + &pool_, + BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle2.Init("a", + params_, + kDefaultPriority, + callback2.callback(), + &pool_, + BoundNetLog())); + + handle.Reset(); + + EXPECT_EQ(OK, callback2.WaitForResult()); + handle2.Reset(); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, ConnectCancelConnect) { + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + kDefaultPriority, + callback.callback(), + &pool_, + BoundNetLog())); + + handle.Reset(); + + TestCompletionCallback callback2; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", + params_, + kDefaultPriority, + callback2.callback(), + &pool_, + BoundNetLog())); + + host_resolver_->set_synchronous_mode(true); + // At this point, handle has two ConnectingSockets out for it. Due to the + // setting the mock resolver into synchronous mode, the host resolution for + // both will return in the same loop of the MessageLoop. The client socket + // is a pending socket, so the Connect() will asynchronously complete on the + // next loop of the MessageLoop. That means that the first + // ConnectingSocket will enter OnIOComplete, and then the second one will. + // If the first one is not cancelled, it will advance the load state, and + // then the second one will crash. + + EXPECT_EQ(OK, callback2.WaitForResult()); + EXPECT_FALSE(callback.have_result()); + + handle.Reset(); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, CancelRequest) { + // First request finishes asynchronously. + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, request(0)->WaitForResult()); + + // Make all subsequent host resolutions complete synchronously. + host_resolver_->set_synchronous_mode(true); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + // Cancel a request. + const size_t index_to_cancel = 2; + EXPECT_FALSE(request(index_to_cancel)->handle()->is_initialized()); + request(index_to_cancel)->handle()->Reset(); + + ReleaseAllConnections(ClientSocketPoolTest::KEEP_ALIVE); + + EXPECT_EQ(5, client_socket_factory_.allocation_count()); + + EXPECT_EQ(1, GetOrderOfRequest(1)); + EXPECT_EQ(2, GetOrderOfRequest(2)); + EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, + GetOrderOfRequest(3)); // Canceled request. + EXPECT_EQ(3, GetOrderOfRequest(4)); + EXPECT_EQ(4, GetOrderOfRequest(5)); + EXPECT_EQ(5, GetOrderOfRequest(6)); + + // Make sure we test order of all requests made. + EXPECT_EQ(ClientSocketPoolTest::kIndexOutOfBounds, GetOrderOfRequest(7)); +} + +class RequestSocketCallback : public TestCompletionCallbackBase { + public: + RequestSocketCallback(ClientSocketHandle* handle, + WebSocketTransportClientSocketPool* pool) + : handle_(handle), + pool_(pool), + within_callback_(false), + callback_(base::Bind(&RequestSocketCallback::OnComplete, + base::Unretained(this))) {} + + ~RequestSocketCallback() override {} + + const CompletionCallback& callback() const { return callback_; } + + private: + void OnComplete(int result) { + SetResult(result); + ASSERT_EQ(OK, result); + + if (!within_callback_) { + // Don't allow reuse of the socket. Disconnect it and then release it and + // run through the MessageLoop once to get it completely released. + handle_->socket()->Disconnect(); + handle_->Reset(); + { + base::MessageLoop::ScopedNestableTaskAllower allow( + base::MessageLoop::current()); + base::MessageLoop::current()->RunUntilIdle(); + } + within_callback_ = true; + scoped_refptr<TransportSocketParams> dest( + new TransportSocketParams( + HostPortPair("www.google.com", 80), + false, + false, + OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); + int rv = + handle_->Init("a", dest, LOWEST, callback(), pool_, BoundNetLog()); + EXPECT_EQ(OK, rv); + } + } + + ClientSocketHandle* const handle_; + WebSocketTransportClientSocketPool* const pool_; + bool within_callback_; + CompletionCallback callback_; + + DISALLOW_COPY_AND_ASSIGN(RequestSocketCallback); +}; + +TEST_F(WebSocketTransportClientSocketPoolTest, RequestTwice) { + ClientSocketHandle handle; + RequestSocketCallback callback(&handle, &pool_); + scoped_refptr<TransportSocketParams> dest( + new TransportSocketParams( + HostPortPair("www.google.com", 80), + false, + false, + OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); + int rv = handle.Init( + "a", dest, LOWEST, callback.callback(), &pool_, BoundNetLog()); + ASSERT_EQ(ERR_IO_PENDING, rv); + + // The callback is going to request "www.google.com". We want it to complete + // synchronously this time. + host_resolver_->set_synchronous_mode(true); + + EXPECT_EQ(OK, callback.WaitForResult()); + + handle.Reset(); +} + +// Make sure that pending requests get serviced after active requests get +// cancelled. +TEST_F(WebSocketTransportClientSocketPoolTest, + CancelActiveRequestWithPendingRequests) { + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET); + + // Queue up all the requests + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + // Now, kMaxSocketsPerGroup requests should be active. Let's cancel them. + ASSERT_LE(kMaxSocketsPerGroup, static_cast<int>(requests()->size())); + for (int i = 0; i < kMaxSocketsPerGroup; i++) + request(i)->handle()->Reset(); + + // Let's wait for the rest to complete now. + for (size_t i = kMaxSocketsPerGroup; i < requests()->size(); ++i) { + EXPECT_EQ(OK, request(i)->WaitForResult()); + request(i)->handle()->Reset(); + } + + EXPECT_EQ(requests()->size() - kMaxSocketsPerGroup, completion_count()); +} + +// Make sure that pending requests get serviced after active requests fail. +TEST_F(WebSocketTransportClientSocketPoolTest, + FailingActiveRequestWithPendingRequests) { + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET); + + const int kNumRequests = 2 * kMaxSocketsPerGroup + 1; + ASSERT_LE(kNumRequests, kMaxSockets); // Otherwise the test will hang. + + // Queue up all the requests + for (int i = 0; i < kNumRequests; i++) + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + for (int i = 0; i < kNumRequests; i++) + EXPECT_EQ(ERR_CONNECTION_FAILED, request(i)->WaitForResult()); +} + +// The lock on the endpoint is released when a ClientSocketHandle is reset. +TEST_F(WebSocketTransportClientSocketPoolTest, LockReleasedOnHandleReset) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, request(0)->WaitForResult()); + EXPECT_FALSE(request(1)->handle()->is_initialized()); + request(0)->handle()->Reset(); + base::RunLoop().RunUntilIdle(); + EXPECT_TRUE(request(1)->handle()->is_initialized()); +} + +// The lock on the endpoint is released when a ClientSocketHandle is deleted. +TEST_F(WebSocketTransportClientSocketPoolTest, LockReleasedOnHandleDelete) { + TestCompletionCallback callback; + scoped_ptr<ClientSocketHandle> handle(new ClientSocketHandle); + int rv = handle->Init( + "a", params_, LOW, callback.callback(), &pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_FALSE(request(0)->handle()->is_initialized()); + handle.reset(); + base::RunLoop().RunUntilIdle(); + EXPECT_TRUE(request(0)->handle()->is_initialized()); +} + +// A new connection is performed when the lock on the previous connection is +// explicitly released. +TEST_F(WebSocketTransportClientSocketPoolTest, + ConnectionProceedsOnExplicitRelease) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(OK, request(0)->WaitForResult()); + EXPECT_FALSE(request(1)->handle()->is_initialized()); + WebSocketTransportClientSocketPool::UnlockEndpoint(request(0)->handle()); + base::RunLoop().RunUntilIdle(); + EXPECT_TRUE(request(1)->handle()->is_initialized()); +} + +// A connection which is cancelled before completion does not block subsequent +// connections. +TEST_F(WebSocketTransportClientSocketPoolTest, + CancelDuringConnectionReleasesLock) { + MockTransportClientSocketFactory::ClientSocketType case_types[] = { + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET}; + + client_socket_factory_.set_client_socket_types(case_types, + arraysize(case_types)); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + base::RunLoop().RunUntilIdle(); + pool_.CancelRequest("a", request(0)->handle()); + EXPECT_EQ(OK, request(1)->WaitForResult()); +} + +// Test the case of the IPv6 address stalling, and falling back to the IPv4 +// socket which finishes first. +TEST_F(WebSocketTransportClientSocketPoolTest, + IPv6FallbackSocketIPv4FinishesFirst) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + MockTransportClientSocketFactory::ClientSocketType case_types[] = { + // This is the IPv6 socket. + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, + // This is the IPv4 socket. + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET}; + + client_socket_factory_.set_client_socket_types(case_types, 2); + + // Resolve an AddressList with an IPv6 address first and then an IPv4 address. + host_resolver_->rules()->AddIPLiteralRule( + "*", "2:abcd::3:4:ff,2.2.2.2", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + IPEndPoint endpoint; + handle.socket()->GetLocalAddress(&endpoint); + EXPECT_EQ(kIPv4AddressSize, endpoint.address().size()); + EXPECT_EQ(2, client_socket_factory_.allocation_count()); +} + +// Test the case of the IPv6 address being slow, thus falling back to trying to +// connect to the IPv4 address, but having the connect to the IPv6 address +// finish first. +TEST_F(WebSocketTransportClientSocketPoolTest, + IPv6FallbackSocketIPv6FinishesFirst) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + MockTransportClientSocketFactory::ClientSocketType case_types[] = { + // This is the IPv6 socket. + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, + // This is the IPv4 socket. + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET}; + + client_socket_factory_.set_client_socket_types(case_types, 2); + client_socket_factory_.set_delay(base::TimeDelta::FromMilliseconds( + TransportConnectJobHelper::kIPv6FallbackTimerInMs + 50)); + + // Resolve an AddressList with an IPv6 address first and then an IPv4 address. + host_resolver_->rules()->AddIPLiteralRule( + "*", "2:abcd::3:4:ff,2.2.2.2", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + IPEndPoint endpoint; + handle.socket()->GetLocalAddress(&endpoint); + EXPECT_EQ(kIPv6AddressSize, endpoint.address().size()); + EXPECT_EQ(2, client_socket_factory_.allocation_count()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, + IPv6NoIPv4AddressesToFallbackTo) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); + + // Resolve an AddressList with only IPv6 addresses. + host_resolver_->rules()->AddIPLiteralRule( + "*", "2:abcd::3:4:ff,3:abcd::3:4:ff", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + IPEndPoint endpoint; + handle.socket()->GetLocalAddress(&endpoint); + EXPECT_EQ(kIPv6AddressSize, endpoint.address().size()); + EXPECT_EQ(1, client_socket_factory_.allocation_count()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, IPv4HasNoFallback) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); + + // Resolve an AddressList with only IPv4 addresses. + host_resolver_->rules()->AddIPLiteralRule("*", "1.1.1.1", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.is_initialized()); + EXPECT_FALSE(handle.socket()); + + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(handle.is_initialized()); + EXPECT_TRUE(handle.socket()); + IPEndPoint endpoint; + handle.socket()->GetLocalAddress(&endpoint); + EXPECT_EQ(kIPv4AddressSize, endpoint.address().size()); + EXPECT_EQ(1, client_socket_factory_.allocation_count()); +} + +// If all IPv6 addresses fail to connect synchronously, then IPv4 connections +// proceeed immediately. +TEST_F(WebSocketTransportClientSocketPoolTest, IPv6InstantFail) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + MockTransportClientSocketFactory::ClientSocketType case_types[] = { + // First IPv6 socket. + MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET, + // Second IPv6 socket. + MockTransportClientSocketFactory::MOCK_FAILING_CLIENT_SOCKET, + // This is the IPv4 socket. + MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET}; + + client_socket_factory_.set_client_socket_types(case_types, + arraysize(case_types)); + + // Resolve an AddressList with two IPv6 addresses and then an IPv4 address. + host_resolver_->rules()->AddIPLiteralRule( + "*", "2:abcd::3:4:ff,2:abcd::3:5:ff,2.2.2.2", std::string()); + host_resolver_->set_synchronous_mode(true); + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(OK, rv); + ASSERT_TRUE(handle.socket()); + + IPEndPoint endpoint; + handle.socket()->GetPeerAddress(&endpoint); + EXPECT_EQ("2.2.2.2", endpoint.ToStringWithoutPort()); +} + +// If all IPv6 addresses fail before the IPv4 fallback timeout, then the IPv4 +// connections proceed immediately. +TEST_F(WebSocketTransportClientSocketPoolTest, IPv6RapidFail) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + MockTransportClientSocketFactory::ClientSocketType case_types[] = { + // First IPv6 socket. + MockTransportClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET, + // Second IPv6 socket. + MockTransportClientSocketFactory::MOCK_PENDING_FAILING_CLIENT_SOCKET, + // This is the IPv4 socket. + MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET}; + + client_socket_factory_.set_client_socket_types(case_types, + arraysize(case_types)); + + // Resolve an AddressList with two IPv6 addresses and then an IPv4 address. + host_resolver_->rules()->AddIPLiteralRule( + "*", "2:abcd::3:4:ff,2:abcd::3:5:ff,2.2.2.2", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle.socket()); + + base::Time start(base::Time::NowFromSystemTime()); + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_LT(base::Time::NowFromSystemTime() - start, + base::TimeDelta::FromMilliseconds( + TransportConnectJobHelper::kIPv6FallbackTimerInMs)); + ASSERT_TRUE(handle.socket()); + + IPEndPoint endpoint; + handle.socket()->GetPeerAddress(&endpoint); + EXPECT_EQ("2.2.2.2", endpoint.ToStringWithoutPort()); +} + +// If two sockets connect successfully, the one which connected first wins (this +// can only happen if the sockets are different types, since sockets of the same +// type do not race). +TEST_F(WebSocketTransportClientSocketPoolTest, FirstSuccessWins) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_TRIGGERABLE_CLIENT_SOCKET); + + // Resolve an AddressList with an IPv6 addresses and an IPv4 address. + host_resolver_->rules()->AddIPLiteralRule( + "*", "2:abcd::3:4:ff,2.2.2.2", std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + ASSERT_FALSE(handle.socket()); + + base::Closure ipv6_connect_trigger = + client_socket_factory_.WaitForTriggerableSocketCreation(); + base::Closure ipv4_connect_trigger = + client_socket_factory_.WaitForTriggerableSocketCreation(); + + ipv4_connect_trigger.Run(); + ipv6_connect_trigger.Run(); + + EXPECT_EQ(OK, callback.WaitForResult()); + ASSERT_TRUE(handle.socket()); + + IPEndPoint endpoint; + handle.socket()->GetPeerAddress(&endpoint); + EXPECT_EQ("2.2.2.2", endpoint.ToStringWithoutPort()); +} + +// We should not report failure until all connections have failed. +TEST_F(WebSocketTransportClientSocketPoolTest, LastFailureWins) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_DELAYED_FAILING_CLIENT_SOCKET); + base::TimeDelta delay = base::TimeDelta::FromMilliseconds( + TransportConnectJobHelper::kIPv6FallbackTimerInMs / 3); + client_socket_factory_.set_delay(delay); + + // Resolve an AddressList with 4 IPv6 addresses and 2 IPv4 addresses. + host_resolver_->rules()->AddIPLiteralRule("*", + "1:abcd::3:4:ff,2:abcd::3:4:ff," + "3:abcd::3:4:ff,4:abcd::3:4:ff," + "1.1.1.1,2.2.2.2", + std::string()); + + // Expected order of events: + // After 100ms: Connect to 1:abcd::3:4:ff times out + // After 200ms: Connect to 2:abcd::3:4:ff times out + // After 300ms: Connect to 3:abcd::3:4:ff times out, IPv4 fallback starts + // After 400ms: Connect to 4:abcd::3:4:ff and 1.1.1.1 time out + // After 500ms: Connect to 2.2.2.2 times out + + TestCompletionCallback callback; + ClientSocketHandle handle; + base::Time start(base::Time::NowFromSystemTime()); + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); + + EXPECT_GE(base::Time::NowFromSystemTime() - start, delay * 5); +} + +// Global timeout for all connects applies. This test is disabled by default +// because it takes 4 minutes. Run with --gtest_also_run_disabled_tests if you +// want to run it. +TEST_F(WebSocketTransportClientSocketPoolTest, DISABLED_OverallTimeoutApplies) { + WebSocketTransportClientSocketPool pool(kMaxSockets, + kMaxSocketsPerGroup, + histograms_.get(), + host_resolver_.get(), + &client_socket_factory_, + NULL); + const base::TimeDelta connect_job_timeout = pool.ConnectionTimeout(); + + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_DELAYED_FAILING_CLIENT_SOCKET); + client_socket_factory_.set_delay(base::TimeDelta::FromSeconds(1) + + connect_job_timeout / 6); + + // Resolve an AddressList with 6 IPv6 addresses and 6 IPv4 addresses. + host_resolver_->rules()->AddIPLiteralRule("*", + "1:abcd::3:4:ff,2:abcd::3:4:ff," + "3:abcd::3:4:ff,4:abcd::3:4:ff," + "5:abcd::3:4:ff,6:abcd::3:4:ff," + "1.1.1.1,2.2.2.2,3.3.3.3," + "4.4.4.4,5.5.5.5,6.6.6.6", + std::string()); + + TestCompletionCallback callback; + ClientSocketHandle handle; + + int rv = + handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + EXPECT_EQ(ERR_TIMED_OUT, callback.WaitForResult()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, MaxSocketsEnforced) { + host_resolver_->set_synchronous_mode(true); + for (int i = 0; i < kMaxSockets; ++i) { + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + WebSocketTransportClientSocketPool::UnlockEndpoint(request(i)->handle()); + } + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, MaxSocketsEnforcedWhenPending) { + for (int i = 0; i < kMaxSockets + 1; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + // Now there are 32 sockets waiting to connect, and one stalled. + for (int i = 0; i < kMaxSockets; ++i) { + base::RunLoop().RunUntilIdle(); + EXPECT_TRUE(request(i)->handle()->is_initialized()); + EXPECT_TRUE(request(i)->handle()->socket()); + WebSocketTransportClientSocketPool::UnlockEndpoint(request(i)->handle()); + } + // Now there are 32 sockets connected, and one stalled. + base::RunLoop().RunUntilIdle(); + EXPECT_FALSE(request(kMaxSockets)->handle()->is_initialized()); + EXPECT_FALSE(request(kMaxSockets)->handle()->socket()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, StalledSocketReleased) { + host_resolver_->set_synchronous_mode(true); + for (int i = 0; i < kMaxSockets; ++i) { + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + WebSocketTransportClientSocketPool::UnlockEndpoint(request(i)->handle()); + } + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + ReleaseOneConnection(ClientSocketPoolTest::NO_KEEP_ALIVE); + EXPECT_TRUE(request(kMaxSockets)->handle()->is_initialized()); + EXPECT_TRUE(request(kMaxSockets)->handle()->socket()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, IsStalledTrueWhenStalled) { + for (int i = 0; i < kMaxSockets + 1; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + EXPECT_EQ(OK, request(0)->WaitForResult()); + EXPECT_TRUE(pool_.IsStalled()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, + CancellingPendingSocketUnstallsStalledSocket) { + for (int i = 0; i < kMaxSockets + 1; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + EXPECT_EQ(OK, request(0)->WaitForResult()); + request(1)->handle()->Reset(); + base::RunLoop().RunUntilIdle(); + EXPECT_FALSE(pool_.IsStalled()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, + LoadStateOfStalledSocketIsWaitingForAvailableSocket) { + for (int i = 0; i < kMaxSockets + 1; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + EXPECT_EQ(LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET, + pool_.GetLoadState("a", request(kMaxSockets)->handle())); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, + CancellingStalledSocketUnstallsPool) { + for (int i = 0; i < kMaxSockets + 1; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + request(kMaxSockets)->handle()->Reset(); + EXPECT_FALSE(pool_.IsStalled()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, + FlushWithErrorFlushesPendingConnections) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + pool_.FlushWithError(ERR_FAILED); + EXPECT_EQ(ERR_FAILED, request(0)->WaitForResult()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, + FlushWithErrorFlushesStalledConnections) { + for (int i = 0; i < kMaxSockets + 1; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + pool_.FlushWithError(ERR_FAILED); + EXPECT_EQ(ERR_FAILED, request(kMaxSockets)->WaitForResult()); +} + +TEST_F(WebSocketTransportClientSocketPoolTest, + AfterFlushWithErrorCanMakeNewConnections) { + for (int i = 0; i < kMaxSockets + 1; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + pool_.FlushWithError(ERR_FAILED); + host_resolver_->set_synchronous_mode(true); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); +} + +// Deleting pending connections can release the lock on the endpoint, which can +// in principle lead to other pending connections succeeding. However, when we +// call FlushWithError(), everything should fail. +TEST_F(WebSocketTransportClientSocketPoolTest, + FlushWithErrorDoesNotCauseSuccessfulConnections) { + host_resolver_->set_synchronous_mode(true); + MockTransportClientSocketFactory::ClientSocketType first_type[] = { + // First socket + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET + }; + client_socket_factory_.set_client_socket_types(first_type, + arraysize(first_type)); + // The rest of the sockets will connect synchronously. + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET); + for (int i = 0; i < kMaxSockets; ++i) { + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + // Now we have one socket in STATE_TRANSPORT_CONNECT and the rest in + // STATE_OBTAIN_LOCK. If any of the sockets in STATE_OBTAIN_LOCK is given the + // lock, they will synchronously connect. + pool_.FlushWithError(ERR_FAILED); + for (int i = 0; i < kMaxSockets; ++i) { + EXPECT_EQ(ERR_FAILED, request(i)->WaitForResult()); + } +} + +// This is a regression test for the first attempted fix for +// FlushWithErrorDoesNotCauseSuccessfulConnections. Because a ConnectJob can +// have both IPv4 and IPv6 subjobs, it can be both connecting and waiting for +// the lock at the same time. +TEST_F(WebSocketTransportClientSocketPoolTest, + FlushWithErrorDoesNotCauseSuccessfulConnectionsMultipleAddressTypes) { + host_resolver_->set_synchronous_mode(true); + // The first |kMaxSockets| sockets to connect will be IPv6. Then we will have + // one IPv4. + std::vector<MockTransportClientSocketFactory::ClientSocketType> socket_types( + kMaxSockets + 1, + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET); + client_socket_factory_.set_client_socket_types(&socket_types[0], + socket_types.size()); + // The rest of the sockets will connect synchronously. + client_socket_factory_.set_default_client_socket_type( + MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET); + for (int i = 0; i < kMaxSockets; ++i) { + host_resolver_->rules()->ClearRules(); + // Each connect job has a different IPv6 address but the same IPv4 address. + // So the IPv6 connections happen in parallel but the IPv4 ones are + // serialised. + host_resolver_->rules()->AddIPLiteralRule("*", + base::StringPrintf( + "%x:abcd::3:4:ff," + "1.1.1.1", + i + 1), + std::string()); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + } + // Now we have |kMaxSockets| IPv6 sockets stalled in connect. No IPv4 sockets + // are started yet. + RunLoopForTimePeriod(base::TimeDelta::FromMilliseconds( + TransportConnectJobHelper::kIPv6FallbackTimerInMs)); + // Now we have |kMaxSockets| IPv6 sockets and one IPv4 socket stalled in + // connect, and |kMaxSockets - 1| IPv4 sockets waiting for the endpoint lock. + pool_.FlushWithError(ERR_FAILED); + for (int i = 0; i < kMaxSockets; ++i) { + EXPECT_EQ(ERR_FAILED, request(i)->WaitForResult()); + } +} + +// Sockets that have had ownership transferred to a ClientSocketHandle should +// not be affected by FlushWithError. +TEST_F(WebSocketTransportClientSocketPoolTest, + FlushWithErrorDoesNotAffectHandedOutSockets) { + host_resolver_->set_synchronous_mode(true); + MockTransportClientSocketFactory::ClientSocketType socket_types[] = { + MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET}; + client_socket_factory_.set_client_socket_types(socket_types, + arraysize(socket_types)); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + // Socket has been "handed out". + EXPECT_TRUE(request(0)->handle()->socket()); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + // Now we have one socket handed out, and one pending. + pool_.FlushWithError(ERR_FAILED); + EXPECT_EQ(ERR_FAILED, request(1)->WaitForResult()); + // Socket owned by ClientSocketHandle is unaffected: + EXPECT_TRUE(request(0)->handle()->socket()); + // Return it to the pool (which deletes it). + request(0)->handle()->Reset(); +} + +// Sockets should not be leaked if CancelRequest() is called in between +// SetSocket() being called on the ClientSocketHandle and InvokeUserCallback(). +TEST_F(WebSocketTransportClientSocketPoolTest, CancelRequestReclaimsSockets) { + host_resolver_->set_synchronous_mode(true); + MockTransportClientSocketFactory::ClientSocketType socket_types[] = { + MockTransportClientSocketFactory::MOCK_TRIGGERABLE_CLIENT_SOCKET, + MockTransportClientSocketFactory::MOCK_CLIENT_SOCKET}; + + client_socket_factory_.set_client_socket_types(socket_types, + arraysize(socket_types)); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + + base::Closure connect_trigger = + client_socket_factory_.WaitForTriggerableSocketCreation(); + + connect_trigger.Run(); // Calls InvokeUserCallbackLater() + + request(0)->handle()->Reset(); // calls CancelRequest() + + // We should now be able to create a new connection without blocking on the + // endpoint lock. + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); +} + +// A handshake completing and then the WebSocket closing should only release one +// Endpoint, not two. +TEST_F(WebSocketTransportClientSocketPoolTest, EndpointLockIsOnlyReleasedOnce) { + host_resolver_->set_synchronous_mode(true); + EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); + // First socket completes handshake. + WebSocketTransportClientSocketPool::UnlockEndpoint(request(0)->handle()); + // First socket is closed. + request(0)->handle()->Reset(); + // Second socket should have been released. + EXPECT_EQ(OK, request(1)->WaitForResult()); + // Third socket should still be waiting for endpoint. + ASSERT_FALSE(request(2)->handle()->is_initialized()); + EXPECT_EQ(LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET, + request(2)->handle()->GetLoadState()); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/socket/websocket_transport_connect_sub_job.cc b/chromium/net/socket/websocket_transport_connect_sub_job.cc new file mode 100644 index 00000000000..fbe8bbcc82c --- /dev/null +++ b/chromium/net/socket/websocket_transport_connect_sub_job.cc @@ -0,0 +1,170 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/websocket_transport_connect_sub_job.h" + +#include "base/logging.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/net_log.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/websocket_endpoint_lock_manager.h" + +namespace net { + +WebSocketTransportConnectSubJob::WebSocketTransportConnectSubJob( + const AddressList& addresses, + WebSocketTransportConnectJob* parent_job, + SubJobType type) + : parent_job_(parent_job), + addresses_(addresses), + current_address_index_(0), + next_state_(STATE_NONE), + type_(type) {} + +WebSocketTransportConnectSubJob::~WebSocketTransportConnectSubJob() { + // We don't worry about cancelling the TCP connect, since ~StreamSocket will + // take care of it. + if (next()) { + DCHECK_EQ(STATE_OBTAIN_LOCK_COMPLETE, next_state_); + // The ~Waiter destructor will remove this object from the waiting list. + } else if (next_state_ == STATE_TRANSPORT_CONNECT_COMPLETE) { + WebSocketEndpointLockManager::GetInstance()->UnlockEndpoint( + CurrentAddress()); + } +} + +// Start connecting. +int WebSocketTransportConnectSubJob::Start() { + DCHECK_EQ(STATE_NONE, next_state_); + next_state_ = STATE_OBTAIN_LOCK; + return DoLoop(OK); +} + +// Called by WebSocketEndpointLockManager when the lock becomes available. +void WebSocketTransportConnectSubJob::GotEndpointLock() { + DCHECK_EQ(STATE_OBTAIN_LOCK_COMPLETE, next_state_); + OnIOComplete(OK); +} + +LoadState WebSocketTransportConnectSubJob::GetLoadState() const { + switch (next_state_) { + case STATE_OBTAIN_LOCK: + case STATE_OBTAIN_LOCK_COMPLETE: + // TODO(ricea): Add a WebSocket-specific LOAD_STATE ? + return LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET; + case STATE_TRANSPORT_CONNECT: + case STATE_TRANSPORT_CONNECT_COMPLETE: + case STATE_DONE: + return LOAD_STATE_CONNECTING; + case STATE_NONE: + return LOAD_STATE_IDLE; + } + NOTREACHED(); + return LOAD_STATE_IDLE; +} + +ClientSocketFactory* WebSocketTransportConnectSubJob::client_socket_factory() + const { + return parent_job_->helper_.client_socket_factory(); +} + +const BoundNetLog& WebSocketTransportConnectSubJob::net_log() const { + return parent_job_->net_log(); +} + +const IPEndPoint& WebSocketTransportConnectSubJob::CurrentAddress() const { + DCHECK_LT(current_address_index_, addresses_.size()); + return addresses_[current_address_index_]; +} + +void WebSocketTransportConnectSubJob::OnIOComplete(int result) { + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) + parent_job_->OnSubJobComplete(rv, this); // |this| deleted +} + +int WebSocketTransportConnectSubJob::DoLoop(int result) { + DCHECK_NE(next_state_, STATE_NONE); + + int rv = result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_OBTAIN_LOCK: + DCHECK_EQ(OK, rv); + rv = DoEndpointLock(); + break; + case STATE_OBTAIN_LOCK_COMPLETE: + DCHECK_EQ(OK, rv); + rv = DoEndpointLockComplete(); + break; + case STATE_TRANSPORT_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoTransportConnect(); + break; + case STATE_TRANSPORT_CONNECT_COMPLETE: + rv = DoTransportConnectComplete(rv); + break; + default: + NOTREACHED(); + rv = ERR_FAILED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE && + next_state_ != STATE_DONE); + + return rv; +} + +int WebSocketTransportConnectSubJob::DoEndpointLock() { + int rv = WebSocketEndpointLockManager::GetInstance()->LockEndpoint( + CurrentAddress(), this); + next_state_ = STATE_OBTAIN_LOCK_COMPLETE; + return rv; +} + +int WebSocketTransportConnectSubJob::DoEndpointLockComplete() { + next_state_ = STATE_TRANSPORT_CONNECT; + return OK; +} + +int WebSocketTransportConnectSubJob::DoTransportConnect() { + // TODO(ricea): Update global g_last_connect_time and report + // ConnectInterval. + next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; + AddressList one_address(CurrentAddress()); + transport_socket_ = client_socket_factory()->CreateTransportClientSocket( + one_address, net_log().net_log(), net_log().source()); + // This use of base::Unretained() is safe because transport_socket_ is + // destroyed in the destructor. + return transport_socket_->Connect(base::Bind( + &WebSocketTransportConnectSubJob::OnIOComplete, base::Unretained(this))); +} + +int WebSocketTransportConnectSubJob::DoTransportConnectComplete(int result) { + next_state_ = STATE_DONE; + WebSocketEndpointLockManager* endpoint_lock_manager = + WebSocketEndpointLockManager::GetInstance(); + if (result != OK) { + endpoint_lock_manager->UnlockEndpoint(CurrentAddress()); + + if (current_address_index_ + 1 < addresses_.size()) { + // Try falling back to the next address in the list. + next_state_ = STATE_OBTAIN_LOCK; + ++current_address_index_; + result = OK; + } + + return result; + } + + endpoint_lock_manager->RememberSocket(transport_socket_.get(), + CurrentAddress()); + + return result; +} + +} // namespace net diff --git a/chromium/net/socket/websocket_transport_connect_sub_job.h b/chromium/net/socket/websocket_transport_connect_sub_job.h new file mode 100644 index 00000000000..5709a461caf --- /dev/null +++ b/chromium/net/socket/websocket_transport_connect_sub_job.h @@ -0,0 +1,90 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_WEBSOCKET_TRANSPORT_CONNECT_SUB_JOB_H_ +#define NET_SOCKET_WEBSOCKET_TRANSPORT_CONNECT_SUB_JOB_H_ + +#include "base/compiler_specific.h" +#include "base/macros.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/address_list.h" +#include "net/base/load_states.h" +#include "net/socket/websocket_endpoint_lock_manager.h" +#include "net/socket/websocket_transport_client_socket_pool.h" + +namespace net { + +class BoundNetLog; +class ClientSocketFactory; +class IPEndPoint; +class StreamSocket; + +// Attempts to connect to a subset of the addresses required by a +// WebSocketTransportConnectJob, specifically either the IPv4 or IPv6 +// addresses. Each address is tried in turn, and parent_job->OnSubJobComplete() +// is called when the first address succeeds or the last address fails. +class WebSocketTransportConnectSubJob + : public WebSocketEndpointLockManager::Waiter { + public: + typedef WebSocketTransportConnectJob::SubJobType SubJobType; + + WebSocketTransportConnectSubJob(const AddressList& addresses, + WebSocketTransportConnectJob* parent_job, + SubJobType type); + + ~WebSocketTransportConnectSubJob() override; + + // Start connecting. + int Start(); + + bool started() { return next_state_ != STATE_NONE; } + + LoadState GetLoadState() const; + + SubJobType type() const { return type_; } + + scoped_ptr<StreamSocket> PassSocket() { return transport_socket_.Pass(); } + + // Implementation of WebSocketEndpointLockManager::EndpointWaiter. + void GotEndpointLock() override; + + private: + enum State { + STATE_NONE, + STATE_OBTAIN_LOCK, + STATE_OBTAIN_LOCK_COMPLETE, + STATE_TRANSPORT_CONNECT, + STATE_TRANSPORT_CONNECT_COMPLETE, + STATE_DONE, + }; + + ClientSocketFactory* client_socket_factory() const; + + const BoundNetLog& net_log() const; + + const IPEndPoint& CurrentAddress() const; + + void OnIOComplete(int result); + int DoLoop(int result); + int DoEndpointLock(); + int DoEndpointLockComplete(); + int DoTransportConnect(); + int DoTransportConnectComplete(int result); + + WebSocketTransportConnectJob* const parent_job_; + + const AddressList addresses_; + size_t current_address_index_; + + State next_state_; + const SubJobType type_; + + scoped_ptr<StreamSocket> transport_socket_; + + DISALLOW_COPY_AND_ASSIGN(WebSocketTransportConnectSubJob); +}; + +} // namespace net + +#endif // NET_SOCKET_WEBSOCKET_TRANSPORT_CONNECT_SUB_JOB_H_ |