diff options
author | Allan Sandfeld Jensen <allan.jensen@theqtcompany.com> | 2015-06-18 14:10:49 +0200 |
---|---|---|
committer | Oswald Buddenhagen <oswald.buddenhagen@theqtcompany.com> | 2015-06-18 13:53:24 +0000 |
commit | 813fbf95af77a531c57a8c497345ad2c61d475b3 (patch) | |
tree | 821b2c8de8365f21b6c9ba17a236fb3006a1d506 /chromium/net/socket | |
parent | af6588f8d723931a298c995fa97259bb7f7deb55 (diff) | |
download | qtwebengine-chromium-813fbf95af77a531c57a8c497345ad2c61d475b3.tar.gz |
BASELINE: Update chromium to 44.0.2403.47
Change-Id: Ie056fedba95cf5e5c76b30c4b2c80fca4764aa2f
Reviewed-by: Oswald Buddenhagen <oswald.buddenhagen@theqtcompany.com>
Diffstat (limited to 'chromium/net/socket')
93 files changed, 4824 insertions, 4952 deletions
diff --git a/chromium/net/socket/buffered_write_stream_socket.cc b/chromium/net/socket/buffered_write_stream_socket.cc deleted file mode 100644 index e69de29bb2d..00000000000 --- a/chromium/net/socket/buffered_write_stream_socket.cc +++ /dev/null diff --git a/chromium/net/socket/buffered_write_stream_socket.h b/chromium/net/socket/buffered_write_stream_socket.h deleted file mode 100644 index e69de29bb2d..00000000000 --- a/chromium/net/socket/buffered_write_stream_socket.h +++ /dev/null diff --git a/chromium/net/socket/buffered_write_stream_socket_unittest.cc b/chromium/net/socket/buffered_write_stream_socket_unittest.cc deleted file mode 100644 index e69de29bb2d..00000000000 --- a/chromium/net/socket/buffered_write_stream_socket_unittest.cc +++ /dev/null diff --git a/chromium/net/socket/client_socket_factory.cc b/chromium/net/socket/client_socket_factory.cc index 51aea715f4d..cb5d8510486 100644 --- a/chromium/net/socket/client_socket_factory.cc +++ b/chromium/net/socket/client_socket_factory.cc @@ -12,7 +12,7 @@ #include "net/socket/client_socket_handle.h" #if defined(USE_OPENSSL) #include "net/socket/ssl_client_socket_openssl.h" -#elif defined(USE_NSS) || defined(OS_MACOSX) || defined(OS_WIN) +#elif defined(USE_NSS_CERTS) || defined(OS_MACOSX) || defined(OS_WIN) #include "net/socket/ssl_client_socket_nss.h" #endif #include "net/socket/tcp_client_socket.h" @@ -107,7 +107,7 @@ class DefaultClientSocketFactory : public ClientSocketFactory, return scoped_ptr<SSLClientSocket>( new SSLClientSocketOpenSSL(transport_socket.Pass(), host_and_port, ssl_config, context)); -#elif defined(USE_NSS) || defined(OS_MACOSX) || defined(OS_WIN) +#elif defined(USE_NSS_CERTS) || defined(OS_MACOSX) || defined(OS_WIN) return scoped_ptr<SSLClientSocket>( new SSLClientSocketNSS(nss_task_runner.get(), transport_socket.Pass(), diff --git a/chromium/net/socket/client_socket_factory.h b/chromium/net/socket/client_socket_factory.h index 6cb5949f0b3..a1ad503480d 100644 --- a/chromium/net/socket/client_socket_factory.h +++ b/chromium/net/socket/client_socket_factory.h @@ -10,8 +10,8 @@ #include "base/basictypes.h" #include "base/memory/scoped_ptr.h" #include "net/base/net_export.h" -#include "net/base/net_log.h" #include "net/base/rand_callback.h" +#include "net/log/net_log.h" #include "net/udp/datagram_socket.h" namespace net { diff --git a/chromium/net/socket/client_socket_handle.cc b/chromium/net/socket/client_socket_handle.cc index 53bcd77499a..d5c17005ad4 100644 --- a/chromium/net/socket/client_socket_handle.cc +++ b/chromium/net/socket/client_socket_handle.cc @@ -11,7 +11,6 @@ #include "base/logging.h" #include "net/base/net_errors.h" #include "net/socket/client_socket_pool.h" -#include "net/socket/client_socket_pool_histograms.h" namespace net { @@ -22,7 +21,9 @@ ClientSocketHandle::ClientSocketHandle() reuse_type_(ClientSocketHandle::UNUSED), callback_(base::Bind(&ClientSocketHandle::OnIOComplete, base::Unretained(this))), - is_ssl_error_(false) {} + is_ssl_error_(false), + ssl_failure_state_(SSL_FAILURE_NONE) { +} ClientSocketHandle::~ClientSocketHandle() { Reset(); @@ -73,6 +74,7 @@ void ClientSocketHandle::ResetInternal(bool cancel) { void ClientSocketHandle::ResetErrorState() { is_ssl_error_ = false; ssl_error_response_info_ = HttpResponseInfo(); + ssl_failure_state_ = SSL_FAILURE_NONE; pending_http_proxy_connection_.reset(); } @@ -149,8 +151,6 @@ scoped_ptr<StreamSocket> ClientSocketHandle::PassSocket() { void ClientSocketHandle::HandleInitCompletion(int result) { CHECK_NE(ERR_IO_PENDING, result); - ClientSocketPoolHistograms* histograms = pool_->histograms(); - histograms->AddErrorCode(result); if (result != OK) { if (!socket_.get()) ResetInternal(false); // Nothing to cancel since the request failed. @@ -162,22 +162,6 @@ void ClientSocketHandle::HandleInitCompletion(int result) { CHECK_NE(-1, pool_id_) << "Pool should have set |pool_id_| to a valid value."; setup_time_ = base::TimeTicks::Now() - init_time_; - histograms->AddSocketType(reuse_type()); - switch (reuse_type()) { - case ClientSocketHandle::UNUSED: - histograms->AddRequestTime(setup_time()); - break; - case ClientSocketHandle::UNUSED_IDLE: - histograms->AddUnusedIdleTime(idle_time()); - break; - case ClientSocketHandle::REUSED_IDLE: - histograms->AddReusedIdleTime(idle_time()); - break; - default: - NOTREACHED(); - break; - } - // Broadcast that the socket has been acquired. // TODO(eroman): This logging is not complete, in particular set_socket() and // release() socket. It ends up working though, since those methods are being diff --git a/chromium/net/socket/client_socket_handle.h b/chromium/net/socket/client_socket_handle.h index 0899d9a9bb0..a4f7befbcac 100644 --- a/chromium/net/socket/client_socket_handle.h +++ b/chromium/net/socket/client_socket_handle.h @@ -12,15 +12,18 @@ #include "base/memory/scoped_ptr.h" #include "base/time/time.h" #include "net/base/completion_callback.h" +#include "net/base/ip_endpoint.h" #include "net/base/load_states.h" #include "net/base/load_timing_info.h" #include "net/base/net_errors.h" #include "net/base/net_export.h" -#include "net/base/net_log.h" #include "net/base/request_priority.h" #include "net/http/http_response_info.h" +#include "net/log/net_log.h" #include "net/socket/client_socket_pool.h" +#include "net/socket/connection_attempts.h" #include "net/socket/stream_socket.h" +#include "net/ssl/ssl_failure_state.h" namespace net { @@ -133,9 +136,15 @@ class NET_EXPORT ClientSocketHandle { void set_ssl_error_response_info(const HttpResponseInfo& ssl_error_state) { ssl_error_response_info_ = ssl_error_state; } + void set_ssl_failure_state(SSLFailureState ssl_failure_state) { + ssl_failure_state_ = ssl_failure_state; + } void set_pending_http_proxy_connection(ClientSocketHandle* connection) { pending_http_proxy_connection_.reset(connection); } + void set_connection_attempts(const ConnectionAttempts& attempts) { + connection_attempts_ = attempts; + } // Only valid if there is no |socket_|. bool is_ssl_error() const { @@ -148,9 +157,16 @@ class NET_EXPORT ClientSocketHandle { const HttpResponseInfo& ssl_error_response_info() const { return ssl_error_response_info_; } + SSLFailureState ssl_failure_state() const { return ssl_failure_state_; } ClientSocketHandle* release_pending_http_proxy_connection() { return pending_http_proxy_connection_.release(); } + // If the connection failed, returns the connection attempts made. (If it + // succeeded, they will be returned through the socket instead; see + // |StreamSocket::GetConnectionAttempts|.) + const ConnectionAttempts& connection_attempts() { + return connection_attempts_; + } StreamSocket* socket() { return socket_.get(); } @@ -199,7 +215,9 @@ class NET_EXPORT ClientSocketHandle { int pool_id_; // See ClientSocketPool::ReleaseSocket() for an explanation. bool is_ssl_error_; HttpResponseInfo ssl_error_response_info_; + SSLFailureState ssl_failure_state_; scoped_ptr<ClientSocketHandle> pending_http_proxy_connection_; + std::vector<ConnectionAttempt> connection_attempts_; base::TimeTicks init_time_; base::TimeDelta setup_time_; diff --git a/chromium/net/socket/client_socket_pool.h b/chromium/net/socket/client_socket_pool.h index 2a2be36c8cd..5404db6b59f 100644 --- a/chromium/net/socket/client_socket_pool.h +++ b/chromium/net/socket/client_socket_pool.h @@ -26,7 +26,6 @@ class DictionaryValue; namespace net { class ClientSocketHandle; -class ClientSocketPoolHistograms; class StreamSocket; // ClientSocketPools are layered. This defines an interface for lower level @@ -170,10 +169,6 @@ class NET_EXPORT ClientSocketPool : public LowerLayeredPool { // Returns the maximum amount of time to wait before retrying a connect. static const int kMaxConnectRetryIntervalMs = 250; - // The set of histograms specific to this pool. We can't use the standard - // UMA_HISTOGRAM_* macros because they are callsite static. - virtual ClientSocketPoolHistograms* histograms() const = 0; - static base::TimeDelta unused_idle_socket_timeout(); static void set_unused_idle_socket_timeout(base::TimeDelta timeout); diff --git a/chromium/net/socket/client_socket_pool_base.cc b/chromium/net/socket/client_socket_pool_base.cc index 9e1abf482a5..c210e69ae9a 100644 --- a/chromium/net/socket/client_socket_pool_base.cc +++ b/chromium/net/socket/client_socket_pool_base.cc @@ -4,17 +4,18 @@ #include "net/socket/client_socket_pool_base.h" +#include <algorithm> + #include "base/compiler_specific.h" #include "base/format_macros.h" #include "base/logging.h" #include "base/message_loop/message_loop.h" -#include "base/metrics/stats_counters.h" #include "base/stl_util.h" #include "base/strings/string_util.h" #include "base/time/time.h" #include "base/values.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" +#include "net/log/net_log.h" using base::TimeDelta; @@ -144,7 +145,13 @@ ClientSocketPoolBaseHelper::Request::Request( DCHECK_EQ(priority_, MAXIMUM_PRIORITY); } -ClientSocketPoolBaseHelper::Request::~Request() {} +ClientSocketPoolBaseHelper::Request::~Request() { + liveness_ = DEAD; +} + +void ClientSocketPoolBaseHelper::Request::CrashIfInvalid() const { + CHECK_EQ(liveness_, ALIVE); +} ClientSocketPoolBaseHelper::ClientSocketPoolBaseHelper( HigherLayeredPool* pool, @@ -226,7 +233,7 @@ bool ClientSocketPoolBaseHelper::IsStalled() const { // which does not count.) for (GroupMap::const_iterator it = group_map_.begin(); it != group_map_.end(); ++it) { - if (it->second->IsStalledOnPoolMaxSockets(max_sockets_per_group_)) + if (it->second->CanUseAdditionalSocketSlot(max_sockets_per_group_)) return true; } return false; @@ -278,7 +285,7 @@ int ClientSocketPoolBaseHelper::RequestSocket( // call back in to |this|, which will cause all sorts of fun and exciting // re-entrancy issues if the socket pool is doing something else at the // time. - if (group->IsStalledOnPoolMaxSockets(max_sockets_per_group_)) { + if (group->CanUseAdditionalSocketSlot(max_sockets_per_group_)) { base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind( @@ -482,6 +489,12 @@ bool ClientSocketPoolBaseHelper::AssignIdleSocketToRequest( idle_socket.socket->WasEverUsed() ? ClientSocketHandle::REUSED_IDLE : ClientSocketHandle::UNUSED_IDLE; + + // If this socket took multiple attempts to obtain, don't report those + // every time it's reused, just to the first user. + if (idle_socket.socket->WasEverUsed()) + idle_socket.socket->ClearConnectionAttempts(); + HandOutSocket( scoped_ptr<StreamSocket>(idle_socket.socket), reuse_type, @@ -562,27 +575,22 @@ LoadState ClientSocketPoolBaseHelper::GetLoadState( if (ContainsKey(pending_callback_map_, handle)) return LOAD_STATE_CONNECTING; - if (!ContainsKey(group_map_, group_name)) { - NOTREACHED() << "ClientSocketPool does not contain group: " << group_name - << " for handle: " << handle; + GroupMap::const_iterator group_it = group_map_.find(group_name); + if (group_it == group_map_.end()) { + // TODO(mmenke): This is actually reached in the wild, for unknown reasons. + // Would be great to understand why, and if it's a bug, fix it. If not, + // should have a test for that case. + NOTREACHED(); return LOAD_STATE_IDLE; } - // Can't use operator[] since it is non-const. - const Group& group = *group_map_.find(group_name)->second; - + const Group& group = *group_it->second; if (group.HasConnectJobForHandle(handle)) { - // Just return the state of the farthest along ConnectJob for the first - // group.jobs().size() pending requests. - LoadState max_state = LOAD_STATE_IDLE; - for (ConnectJobSet::const_iterator job_it = group.jobs().begin(); - job_it != group.jobs().end(); ++job_it) { - max_state = std::max(max_state, (*job_it)->GetLoadState()); - } - return max_state; + // Just return the state of the oldest ConnectJob. + return (*group.jobs().begin())->GetLoadState(); } - if (group.IsStalledOnPoolMaxSockets(max_sockets_per_group_)) + if (group.CanUseAdditionalSocketSlot(max_sockets_per_group_)) return LOAD_STATE_WAITING_FOR_STALLED_SOCKET_POOL; return LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET; } @@ -629,16 +637,15 @@ base::DictionaryValue* ClientSocketPoolBaseHelper::GetInfoAsValue( group_dict->Set("idle_sockets", idle_socket_list); base::ListValue* connect_jobs_list = new base::ListValue(); - std::set<ConnectJob*>::const_iterator job = group->jobs().begin(); + std::list<ConnectJob*>::const_iterator job = group->jobs().begin(); for (job = group->jobs().begin(); job != group->jobs().end(); job++) { int source_id = (*job)->net_log().source().id; connect_jobs_list->Append(new base::FundamentalValue(source_id)); } group_dict->Set("connect_jobs", connect_jobs_list); - group_dict->SetBoolean("is_stalled", - group->IsStalledOnPoolMaxSockets( - max_sockets_per_group_)); + group_dict->SetBoolean("is_stalled", group->CanUseAdditionalSocketSlot( + max_sockets_per_group_)); group_dict->SetBoolean("backup_job_timer_is_running", group->BackupJobTimerIsRunning()); @@ -840,7 +847,7 @@ bool ClientSocketPoolBaseHelper::FindTopStalledGroup( Group* curr_group = i->second; if (!curr_group->has_pending_requests()) continue; - if (curr_group->IsStalledOnPoolMaxSockets(max_sockets_per_group_)) { + if (curr_group->CanUseAdditionalSocketSlot(max_sockets_per_group_)) { if (!group) return true; has_stalled_group = true; @@ -959,6 +966,15 @@ void ClientSocketPoolBaseHelper::ProcessPendingRequest( const std::string& group_name, Group* group) { const Request* next_request = group->GetNextPendingRequest(); DCHECK(next_request); + + // If the group has no idle sockets, and can't make use of an additional slot, + // either because it's at the limit or because it's at the socket per group + // limit, then there's nothing to do. + if (group->idle_sockets().empty() && + !group->CanUseAdditionalSocketSlot(max_sockets_per_group_)) { + return; + } + int rv = RequestSocketInternal(group_name, *next_request); if (rv != ERR_IO_PENDING) { scoped_ptr<const Request> request = group->PopNextPendingRequest(); @@ -1185,19 +1201,16 @@ void ClientSocketPoolBaseHelper::Group::AddJob(scoped_ptr<ConnectJob> job, if (is_preconnect) ++unassigned_job_count_; - jobs_.insert(job.release()); + jobs_.push_back(job.release()); } void ClientSocketPoolBaseHelper::Group::RemoveJob(ConnectJob* job) { scoped_ptr<ConnectJob> owned_job(job); SanityCheck(); - std::set<ConnectJob*>::iterator it = jobs_.find(job); - if (it != jobs_.end()) { - jobs_.erase(it); - } else { - NOTREACHED(); - } + // Check that |job| is in the list. + DCHECK_EQ(*std::find(jobs_.begin(), jobs_.end(), job), job); + jobs_.remove(job); size_t job_count = jobs_.size(); if (job_count < unassigned_job_count_) unassigned_job_count_ = job_count; @@ -1234,7 +1247,6 @@ void ClientSocketPoolBaseHelper::Group::OnBackupJobTimerFired( pool->connect_job_factory_->NewConnectJob( group_name, *pending_requests_.FirstMax().value(), pool); backup_job->net_log().AddEvent(NetLog::TYPE_BACKUP_CONNECT_JOB_CREATED); - SIMPLE_STATS_COUNTER("socket.backup_created"); int rv = backup_job->Connect(); pool->connecting_socket_count_++; ConnectJob* raw_backup_job = backup_job.get(); @@ -1318,11 +1330,14 @@ ClientSocketPoolBaseHelper::Group::FindAndRemovePendingRequest( scoped_ptr<const ClientSocketPoolBaseHelper::Request> ClientSocketPoolBaseHelper::Group::RemovePendingRequest( const RequestQueue::Pointer& pointer) { + // TODO(eroman): Temporary for debugging http://crbug.com/467797. + CHECK(!pointer.is_null()); scoped_ptr<const Request> request(pointer.value()); pending_requests_.Erase(pointer); // If there are no more requests, kill the backup timer. if (pending_requests_.empty()) backup_job_timer_.Stop(); + request->CrashIfInvalid(); return request.Pass(); } diff --git a/chromium/net/socket/client_socket_pool_base.h b/chromium/net/socket/client_socket_pool_base.h index ec4e33cc0ed..7686715597a 100644 --- a/chromium/net/socket/client_socket_pool_base.h +++ b/chromium/net/socket/client_socket_pool_base.h @@ -42,10 +42,10 @@ #include "net/base/load_timing_info.h" #include "net/base/net_errors.h" #include "net/base/net_export.h" -#include "net/base/net_log.h" #include "net/base/network_change_notifier.h" #include "net/base/priority_queue.h" #include "net/base/request_priority.h" +#include "net/log/net_log.h" #include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_pool.h" #include "net/socket/stream_socket.h" @@ -183,7 +183,16 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper Flags flags() const { return flags_; } const BoundNetLog& net_log() const { return net_log_; } + // TODO(eroman): Temporary until crbug.com/467797 is solved. + void CrashIfInvalid() const; + private: + // TODO(eroman): Temporary until crbug.com/467797 is solved. + enum Liveness { + ALIVE = 0xCA11AB13, + DEAD = 0xDEADBEEF, + }; + ClientSocketHandle* const handle_; const CompletionCallback callback_; // TODO(akalin): Support reprioritization. @@ -192,6 +201,9 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper const Flags flags_; const BoundNetLog net_log_; + // TODO(eroman): Temporary until crbug.com/467797 is solved. + Liveness liveness_ = ALIVE; + DISALLOW_COPY_AND_ASSIGN(Request); }; @@ -383,7 +395,9 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper static_cast<int>(idle_sockets_.size()); } - bool IsStalledOnPoolMaxSockets(int max_sockets_per_group) const { + // Returns true if the group could make use of an additional socket slot, if + // it were given one. + bool CanUseAdditionalSocketSlot(int max_sockets_per_group) const { return HasAvailableSocketSlot(max_sockets_per_group) && pending_requests_.size() > jobs_.size(); } @@ -448,7 +462,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper void DecrementActiveSocketCount() { active_socket_count_--; } int unassigned_job_count() const { return unassigned_job_count_; } - const std::set<ConnectJob*>& jobs() const { return jobs_; } + const std::list<ConnectJob*>& jobs() const { return jobs_; } const std::list<IdleSocket>& idle_sockets() const { return idle_sockets_; } int active_socket_count() const { return active_socket_count_; } std::list<IdleSocket>* mutable_idle_sockets() { return &idle_sockets_; } @@ -477,7 +491,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper size_t unassigned_job_count_; std::list<IdleSocket> idle_sockets_; - std::set<ConnectJob*> jobs_; + std::list<ConnectJob*> jobs_; RequestQueue pending_requests_; int active_socket_count_; // number of active sockets used by clients // A timer for when to start the backup job. @@ -700,17 +714,17 @@ class ClientSocketPoolBase { // long to leave an unused idle socket open before closing it. // |used_idle_socket_timeout| specifies how long to leave a previously used // idle socket open before closing it. - ClientSocketPoolBase( - HigherLayeredPool* self, - int max_sockets, - int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, - base::TimeDelta unused_idle_socket_timeout, - base::TimeDelta used_idle_socket_timeout, - ConnectJobFactory* connect_job_factory) - : histograms_(histograms), - helper_(self, max_sockets, max_sockets_per_group, - unused_idle_socket_timeout, used_idle_socket_timeout, + ClientSocketPoolBase(HigherLayeredPool* self, + int max_sockets, + int max_sockets_per_group, + base::TimeDelta unused_idle_socket_timeout, + base::TimeDelta used_idle_socket_timeout, + ConnectJobFactory* connect_job_factory) + : helper_(self, + max_sockets, + max_sockets_per_group, + unused_idle_socket_timeout, + used_idle_socket_timeout, new ConnectJobFactoryAdaptor(connect_job_factory)) {} virtual ~ClientSocketPoolBase() {} @@ -822,10 +836,6 @@ class ClientSocketPoolBase { return helper_.ConnectionTimeout(); } - ClientSocketPoolHistograms* histograms() const { - return histograms_; - } - void EnableConnectBackupJobs() { helper_.EnableConnectBackupJobs(); } bool CloseOneIdleSocket() { return helper_.CloseOneIdleSocket(); } @@ -848,9 +858,9 @@ class ClientSocketPoolBase { explicit ConnectJobFactoryAdaptor(ConnectJobFactory* connect_job_factory) : connect_job_factory_(connect_job_factory) {} - virtual ~ConnectJobFactoryAdaptor() {} + ~ConnectJobFactoryAdaptor() override {} - virtual scoped_ptr<ConnectJob> NewConnectJob( + scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const internal::ClientSocketPoolBaseHelper::Request& request, ConnectJob::Delegate* delegate) const override { @@ -859,15 +869,13 @@ class ClientSocketPoolBase { group_name, casted_request, delegate); } - virtual base::TimeDelta ConnectionTimeout() const { + base::TimeDelta ConnectionTimeout() const override { return connect_job_factory_->ConnectionTimeout(); } const scoped_ptr<ConnectJobFactory> connect_job_factory_; }; - // Histograms for the pool - ClientSocketPoolHistograms* const histograms_; internal::ClientSocketPoolBaseHelper helper_; DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolBase); diff --git a/chromium/net/socket/client_socket_pool_base_unittest.cc b/chromium/net/socket/client_socket_pool_base_unittest.cc index c4a28459a1e..cfa9c6e115f 100644 --- a/chromium/net/socket/client_socket_pool_base_unittest.cc +++ b/chromium/net/socket/client_socket_pool_base_unittest.cc @@ -21,14 +21,15 @@ #include "net/base/load_timing_info.h" #include "net/base/load_timing_info_test_util.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" -#include "net/base/net_log_unittest.h" #include "net/base/request_priority.h" #include "net/base/test_completion_callback.h" #include "net/http/http_response_headers.h" +#include "net/log/net_log.h" +#include "net/log/test_net_log.h" +#include "net/log/test_net_log_entry.h" +#include "net/log/test_net_log_util.h" #include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_handle.h" -#include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket_test_util.h" #include "net/socket/ssl_client_socket.h" #include "net/socket/stream_socket.h" @@ -117,9 +118,8 @@ class MockClientSocket : public StreamSocket { explicit MockClientSocket(net::NetLog* net_log) : connected_(false), has_unread_data_(false), - net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_SOCKET)), - was_used_to_convey_data_(false) { - } + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), + was_used_to_convey_data_(false) {} // Sets whether the socket has unread data. If true, the next call to Read() // will return 1 byte and IsConnectedAndIdle() will return false. @@ -177,6 +177,11 @@ class MockClientSocket : public StreamSocket { bool WasNpnNegotiated() const override { return false; } NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } + void GetConnectionAttempts(ConnectionAttempts* out) const override { + out->clear(); + } + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} private: bool connected_; @@ -296,8 +301,8 @@ class TestConnectJob : public ConnectJob { int ConnectInternal() override { AddressList ignored; - client_socket_factory_->CreateTransportClientSocket( - ignored, NULL, net::NetLog::Source()); + client_socket_factory_->CreateTransportClientSocket(ignored, NULL, + NetLog::Source()); SetSocket( scoped_ptr<StreamSocket>(new MockClientSocket(net_log().net_log()))); switch (job_type_) { @@ -482,19 +487,21 @@ class TestClientSocketPool : public ClientSocketPool { TestClientSocketPool( int max_sockets, int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, base::TimeDelta unused_idle_socket_timeout, base::TimeDelta used_idle_socket_timeout, TestClientSocketPoolBase::ConnectJobFactory* connect_job_factory) - : base_(NULL, max_sockets, max_sockets_per_group, histograms, - unused_idle_socket_timeout, used_idle_socket_timeout, + : base_(NULL, + max_sockets, + max_sockets_per_group, + unused_idle_socket_timeout, + used_idle_socket_timeout, connect_job_factory) {} ~TestClientSocketPool() override {} int RequestSocket(const std::string& group_name, const void* params, - net::RequestPriority priority, + RequestPriority priority, ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) override { @@ -561,10 +568,6 @@ class TestClientSocketPool : public ClientSocketPool { return base_.ConnectionTimeout(); } - ClientSocketPoolHistograms* histograms() const override { - return base_.histograms(); - } - const TestClientSocketPoolBase* base() const { return &base_; } int NumUnassignedConnectJobsInGroup(const std::string& group_name) const { @@ -658,8 +661,7 @@ class TestConnectJobDelegate : public ConnectJob::Delegate { class ClientSocketPoolBaseTest : public testing::Test { protected: ClientSocketPoolBaseTest() - : params_(new TestSocketParams(false /* ignore_limits */)), - histograms_("ClientSocketPoolTest") { + : params_(new TestSocketParams(false /* ignore_limits */)) { connect_backup_jobs_enabled_ = internal::ClientSocketPoolBaseHelper::connect_backup_jobs_enabled(); internal::ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(true); @@ -691,7 +693,6 @@ class ClientSocketPoolBaseTest : public testing::Test { &net_log_); pool_.reset(new TestClientSocketPool(max_sockets, max_sockets_per_group, - &histograms_, unused_idle_socket_timeout, used_idle_socket_timeout, connect_job_factory_)); @@ -726,13 +727,12 @@ class ClientSocketPoolBaseTest : public testing::Test { ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); } size_t completion_count() const { return test_base_.completion_count(); } - CapturingNetLog net_log_; + TestNetLog net_log_; bool connect_backup_jobs_enabled_; bool cleanup_timer_enabled_; MockClientSocketFactory client_socket_factory_; TestConnectJobFactory* connect_job_factory_; scoped_refptr<TestSocketParams> params_; - ClientSocketPoolHistograms histograms_; scoped_ptr<TestClientSocketPool> pool_; ClientSocketPoolTest test_base_; }; @@ -760,7 +760,7 @@ TEST_F(ClientSocketPoolBaseTest, ConnectJob_NoTimeoutOnSynchronousCompletion) { TEST_F(ClientSocketPoolBaseTest, ConnectJob_TimedOut) { TestConnectJobDelegate delegate; ClientSocketHandle ignored; - CapturingNetLog log; + TestNetLog log; TestClientSocketPoolBase::Request request( &ignored, CompletionCallback(), DEFAULT_PRIORITY, @@ -779,7 +779,7 @@ TEST_F(ClientSocketPoolBaseTest, ConnectJob_TimedOut) { base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(1)); EXPECT_EQ(ERR_TIMED_OUT, delegate.WaitForResult()); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_EQ(6u, entries.size()); @@ -804,7 +804,7 @@ TEST_F(ClientSocketPoolBaseTest, BasicSynchronous) { TestCompletionCallback callback; ClientSocketHandle handle; - CapturingBoundNetLog log; + BoundTestNetLog log; TestLoadTimingInfoNotConnected(handle); EXPECT_EQ(OK, @@ -821,7 +821,7 @@ TEST_F(ClientSocketPoolBaseTest, BasicSynchronous) { handle.Reset(); TestLoadTimingInfoNotConnected(handle); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_EQ(4u, entries.size()); @@ -841,7 +841,7 @@ TEST_F(ClientSocketPoolBaseTest, InitConnectionFailure) { CreatePool(kDefaultMaxSockets, kDefaultMaxSocketsPerGroup); connect_job_factory_->set_job_type(TestConnectJob::kMockFailingJob); - CapturingBoundNetLog log; + BoundTestNetLog log; ClientSocketHandle handle; TestCompletionCallback callback; @@ -862,7 +862,7 @@ TEST_F(ClientSocketPoolBaseTest, InitConnectionFailure) { EXPECT_TRUE(handle.ssl_error_response_info().headers.get() == NULL); TestLoadTimingInfoNotConnected(handle); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_EQ(3u, entries.size()); @@ -1678,7 +1678,7 @@ TEST_F(ClientSocketPoolBaseTest, BasicAsynchronous) { connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); ClientSocketHandle handle; TestCompletionCallback callback; - CapturingBoundNetLog log; + BoundTestNetLog log; int rv = handle.Init("a", params_, LOWEST, @@ -1697,7 +1697,7 @@ TEST_F(ClientSocketPoolBaseTest, BasicAsynchronous) { handle.Reset(); TestLoadTimingInfoNotConnected(handle); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_EQ(4u, entries.size()); @@ -1720,7 +1720,7 @@ TEST_F(ClientSocketPoolBaseTest, connect_job_factory_->set_job_type(TestConnectJob::kMockPendingFailingJob); ClientSocketHandle handle; TestCompletionCallback callback; - CapturingBoundNetLog log; + BoundTestNetLog log; // Set the additional error state members to ensure that they get cleared. handle.set_is_ssl_error(true); HttpResponseInfo info; @@ -1737,7 +1737,7 @@ TEST_F(ClientSocketPoolBaseTest, EXPECT_FALSE(handle.is_ssl_error()); EXPECT_TRUE(handle.ssl_error_response_info().headers.get() == NULL); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_EQ(3u, entries.size()); @@ -1750,6 +1750,22 @@ TEST_F(ClientSocketPoolBaseTest, entries, 2, NetLog::TYPE_SOCKET_POOL)); } +// Check that an async ConnectJob failure does not result in creation of a new +// ConnectJob when there's another pending request also waiting on its own +// ConnectJob. See http://crbug.com/463960. +TEST_F(ClientSocketPoolBaseTest, AsyncFailureWithPendingRequestWithJob) { + CreatePool(2, 2); + connect_job_factory_->set_job_type(TestConnectJob::kMockPendingFailingJob); + + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", DEFAULT_PRIORITY)); + EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", DEFAULT_PRIORITY)); + + EXPECT_EQ(ERR_CONNECTION_FAILED, request(0)->WaitForResult()); + EXPECT_EQ(ERR_CONNECTION_FAILED, request(1)->WaitForResult()); + + EXPECT_EQ(2, client_socket_factory_.allocation_count()); +} + TEST_F(ClientSocketPoolBaseTest, TwoRequestsCancelOne) { // TODO(eroman): Add back the log expectations! Removed them because the // ordering is difficult, and some may fire during destructor. @@ -1768,7 +1784,7 @@ TEST_F(ClientSocketPoolBaseTest, TwoRequestsCancelOne) { callback.callback(), pool_.get(), BoundNetLog())); - CapturingBoundNetLog log2; + BoundTestNetLog log2; EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", params_, @@ -1945,48 +1961,63 @@ TEST_F(ClientSocketPoolBaseTest, LoadStateOneRequest) { } // Test GetLoadState in the case there are two socket requests. +// Only the first connection in the pool should affect the pool's load status. TEST_F(ClientSocketPoolBaseTest, LoadStateTwoRequests) { CreatePool(2, 2); connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog()); + int rv = handle.Init("a", params_, DEFAULT_PRIORITY, callback.callback(), + pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); + client_socket_factory_.SetJobLoadState(0, LOAD_STATE_RESOLVING_HOST); ClientSocketHandle handle2; TestCompletionCallback callback2; - rv = handle2.Init("a", - params_, - DEFAULT_PRIORITY, - callback2.callback(), - pool_.get(), - BoundNetLog()); + rv = handle2.Init("a", params_, DEFAULT_PRIORITY, callback2.callback(), + pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); + client_socket_factory_.SetJobLoadState(1, LOAD_STATE_RESOLVING_HOST); - // If the first Job is in an earlier state than the second, the state of - // the second job should be used for both handles. - client_socket_factory_.SetJobLoadState(0, LOAD_STATE_RESOLVING_HOST); + // Check that both handles report the state of the first job. + EXPECT_EQ(LOAD_STATE_RESOLVING_HOST, handle.GetLoadState()); + EXPECT_EQ(LOAD_STATE_RESOLVING_HOST, handle2.GetLoadState()); + + client_socket_factory_.SetJobLoadState(0, LOAD_STATE_CONNECTING); + + // Check that both handles change to LOAD_STATE_CONNECTING. EXPECT_EQ(LOAD_STATE_CONNECTING, handle.GetLoadState()); EXPECT_EQ(LOAD_STATE_CONNECTING, handle2.GetLoadState()); +} - // If the second Job is in an earlier state than the second, the state of - // the first job should be used for both handles. - client_socket_factory_.SetJobLoadState(0, LOAD_STATE_SSL_HANDSHAKE); - // One request is farther - EXPECT_EQ(LOAD_STATE_SSL_HANDSHAKE, handle.GetLoadState()); - EXPECT_EQ(LOAD_STATE_SSL_HANDSHAKE, handle2.GetLoadState()); +// Test that the second connection request does not affect the pool's load +// status. +TEST_F(ClientSocketPoolBaseTest, LoadStateTwoRequestsChangeSecondRequestState) { + CreatePool(2, 2); + connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); - // Farthest along job connects and the first request gets the socket. The + ClientSocketHandle handle; + TestCompletionCallback callback; + int rv = handle.Init("a", params_, DEFAULT_PRIORITY, callback.callback(), + pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + ClientSocketHandle handle2; + TestCompletionCallback callback2; + rv = handle2.Init("a", params_, DEFAULT_PRIORITY, callback2.callback(), + pool_.get(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + client_socket_factory_.SetJobLoadState(1, LOAD_STATE_RESOLVING_HOST); + + EXPECT_EQ(LOAD_STATE_CONNECTING, handle.GetLoadState()); + EXPECT_EQ(LOAD_STATE_CONNECTING, handle2.GetLoadState()); + + // First job connects and the first request gets the socket. The // second handle switches to the state of the remaining ConnectJob. client_socket_factory_.SignalJob(0); EXPECT_EQ(OK, callback.WaitForResult()); - EXPECT_EQ(LOAD_STATE_CONNECTING, handle2.GetLoadState()); + EXPECT_EQ(LOAD_STATE_RESOLVING_HOST, handle2.GetLoadState()); } // Test GetLoadState in the case the per-group limit is reached. @@ -2205,7 +2236,7 @@ TEST_F(ClientSocketPoolBaseTest, DisableCleanupTimerReuse) { // Request a new socket. This should reuse the old socket and complete // synchronously. - CapturingBoundNetLog log; + BoundTestNetLog log; rv = handle.Init("a", params_, LOWEST, @@ -2220,7 +2251,7 @@ TEST_F(ClientSocketPoolBaseTest, DisableCleanupTimerReuse) { EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a")); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_TRUE(LogContainsEntryWithType( entries, 1, NetLog::TYPE_SOCKET_POOL_REUSED_AN_EXISTING_SOCKET)); @@ -2285,7 +2316,7 @@ TEST_F(ClientSocketPoolBaseTest, DisableCleanupTimerNoReuse) { // Request a new socket. This should cleanup the unused and timed out ones. // A new socket will be created rather than reusing the idle one. - CapturingBoundNetLog log; + BoundTestNetLog log; TestCompletionCallback callback3; rv = handle.Init("a", params_, @@ -2302,7 +2333,7 @@ TEST_F(ClientSocketPoolBaseTest, DisableCleanupTimerNoReuse) { EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a")); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_FALSE(LogContainsEntryWithType( entries, 1, NetLog::TYPE_SOCKET_POOL_REUSED_AN_EXISTING_SOCKET)); @@ -2364,7 +2395,7 @@ TEST_F(ClientSocketPoolBaseTest, CleanupTimedOutIdleSockets) { // used socket. Request it to make sure that it's used. pool_->CleanupTimedOutIdleSockets(); - CapturingBoundNetLog log; + BoundTestNetLog log; rv = handle.Init("a", params_, LOWEST, @@ -2374,7 +2405,7 @@ TEST_F(ClientSocketPoolBaseTest, CleanupTimedOutIdleSockets) { EXPECT_EQ(OK, rv); EXPECT_TRUE(handle.is_reused()); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_TRUE(LogContainsEntryWithType( entries, 1, NetLog::TYPE_SOCKET_POOL_REUSED_AN_EXISTING_SOCKET)); diff --git a/chromium/net/socket/client_socket_pool_histograms.cc b/chromium/net/socket/client_socket_pool_histograms.cc deleted file mode 100644 index 9af8649c48b..00000000000 --- a/chromium/net/socket/client_socket_pool_histograms.cc +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) 2011 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/socket/client_socket_pool_histograms.h" - -#include <string> - -#include "base/metrics/field_trial.h" -#include "base/metrics/histogram.h" -#include "net/base/net_errors.h" -#include "net/socket/client_socket_handle.h" - -namespace net { - -using base::Histogram; -using base::HistogramBase; -using base::LinearHistogram; -using base::CustomHistogram; - -ClientSocketPoolHistograms::ClientSocketPoolHistograms( - const std::string& pool_name) - : is_http_proxy_connection_(false), - is_socks_connection_(false) { - // UMA_HISTOGRAM_ENUMERATION - socket_type_ = LinearHistogram::FactoryGet("Net.SocketType_" + pool_name, 1, - ClientSocketHandle::NUM_TYPES, ClientSocketHandle::NUM_TYPES + 1, - HistogramBase::kUmaTargetedHistogramFlag); - // UMA_HISTOGRAM_CUSTOM_TIMES - request_time_ = Histogram::FactoryTimeGet( - "Net.SocketRequestTime_" + pool_name, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(10), - 100, HistogramBase::kUmaTargetedHistogramFlag); - // UMA_HISTOGRAM_CUSTOM_TIMES - unused_idle_time_ = Histogram::FactoryTimeGet( - "Net.SocketIdleTimeBeforeNextUse_UnusedSocket_" + pool_name, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(6), - 100, HistogramBase::kUmaTargetedHistogramFlag); - // UMA_HISTOGRAM_CUSTOM_TIMES - reused_idle_time_ = Histogram::FactoryTimeGet( - "Net.SocketIdleTimeBeforeNextUse_ReusedSocket_" + pool_name, - base::TimeDelta::FromMilliseconds(1), - base::TimeDelta::FromMinutes(6), - 100, HistogramBase::kUmaTargetedHistogramFlag); - // UMA_HISTOGRAM_CUSTOM_ENUMERATION - error_code_ = CustomHistogram::FactoryGet( - "Net.SocketInitErrorCodes_" + pool_name, - GetAllErrorCodesForUma(), - HistogramBase::kUmaTargetedHistogramFlag); - - if (pool_name == "HTTPProxy") - is_http_proxy_connection_ = true; - else if (pool_name == "SOCK") - is_socks_connection_ = true; -} - -ClientSocketPoolHistograms::~ClientSocketPoolHistograms() { -} - -void ClientSocketPoolHistograms::AddSocketType(int type) const { - socket_type_->Add(type); -} - -void ClientSocketPoolHistograms::AddRequestTime(base::TimeDelta time) const { - request_time_->AddTime(time); -} - -void ClientSocketPoolHistograms::AddUnusedIdleTime(base::TimeDelta time) const { - unused_idle_time_->AddTime(time); -} - -void ClientSocketPoolHistograms::AddReusedIdleTime(base::TimeDelta time) const { - reused_idle_time_->AddTime(time); -} - -void ClientSocketPoolHistograms::AddErrorCode(int error_code) const { - // Error codes are positive (since histograms expect positive sample values). - error_code_->Add(-error_code); -} - -} // namespace net diff --git a/chromium/net/socket/client_socket_pool_histograms.h b/chromium/net/socket/client_socket_pool_histograms.h deleted file mode 100644 index 26a406362bd..00000000000 --- a/chromium/net/socket/client_socket_pool_histograms.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) 2011 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_SOCKET_CLIENT_SOCKET_POOL_HISTOGRAMS_H_ -#define NET_SOCKET_CLIENT_SOCKET_POOL_HISTOGRAMS_H_ - -#include <string> - -#include "base/memory/ref_counted.h" -#include "base/time/time.h" -#include "net/base/net_export.h" - -namespace base { -class HistogramBase; -} - -namespace net { - -class NET_EXPORT_PRIVATE ClientSocketPoolHistograms { - public: - ClientSocketPoolHistograms(const std::string& pool_name); - ~ClientSocketPoolHistograms(); - - void AddSocketType(int socket_reuse_type) const; - void AddRequestTime(base::TimeDelta time) const; - void AddUnusedIdleTime(base::TimeDelta time) const; - void AddReusedIdleTime(base::TimeDelta time) const; - void AddErrorCode(int error_code) const; - - private: - base::HistogramBase* socket_type_; - base::HistogramBase* request_time_; - base::HistogramBase* unused_idle_time_; - base::HistogramBase* reused_idle_time_; - base::HistogramBase* error_code_; - - bool is_http_proxy_connection_; - bool is_socks_connection_; - - DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolHistograms); -}; - -} // namespace net - -#endif // NET_SOCKET_CLIENT_SOCKET_POOL_HISTOGRAMS_H_ diff --git a/chromium/net/socket/client_socket_pool_manager.cc b/chromium/net/socket/client_socket_pool_manager.cc index b99612718e4..4af9c9b37a2 100644 --- a/chromium/net/socket/client_socket_pool_manager.cc +++ b/chromium/net/socket/client_socket_pool_manager.cc @@ -29,9 +29,9 @@ int g_max_sockets_per_pool[] = { 256 // WEBSOCKET_SOCKET_POOL }; -COMPILE_ASSERT(arraysize(g_max_sockets_per_pool) == - HttpNetworkSession::NUM_SOCKET_POOL_TYPES, - max_sockets_per_pool_length_mismatch); +static_assert(arraysize(g_max_sockets_per_pool) == + HttpNetworkSession::NUM_SOCKET_POOL_TYPES, + "max sockets per pool length mismatch"); // Default to allow up to 6 connections per host. Experiment and tuning may // try other values (greater than 0). Too large may cause many problems, such @@ -48,9 +48,9 @@ int g_max_sockets_per_group[] = { 30 // WEBSOCKET_SOCKET_POOL }; -COMPILE_ASSERT(arraysize(g_max_sockets_per_group) == - HttpNetworkSession::NUM_SOCKET_POOL_TYPES, - max_sockets_per_group_length_mismatch); +static_assert(arraysize(g_max_sockets_per_group) == + HttpNetworkSession::NUM_SOCKET_POOL_TYPES, + "max sockets per group length mismatch"); // The max number of sockets to allow per proxy server. This applies both to // http and SOCKS proxies. See http://crbug.com/12066 and @@ -60,19 +60,19 @@ int g_max_sockets_per_proxy_server[] = { kDefaultMaxSocketsPerProxyServer // WEBSOCKET_SOCKET_POOL }; -COMPILE_ASSERT(arraysize(g_max_sockets_per_proxy_server) == - HttpNetworkSession::NUM_SOCKET_POOL_TYPES, - max_sockets_per_proxy_server_length_mismatch); +static_assert(arraysize(g_max_sockets_per_proxy_server) == + HttpNetworkSession::NUM_SOCKET_POOL_TYPES, + "max sockets per proxy server length mismatch"); // The meat of the implementation for the InitSocketHandleForHttpRequest, // InitSocketHandleForRawConnect and PreconnectSocketsForHttpRequest methods. -int InitSocketPoolHelper(const GURL& request_url, +int InitSocketPoolHelper(ClientSocketPoolManager::SocketGroupType group_type, + const HostPortPair& endpoint, const HttpRequestHeaders& request_extra_headers, int request_load_flags, RequestPriority request_priority, HttpNetworkSession* session, const ProxyInfo& proxy_info, - bool force_spdy_over_ssl, bool want_spdy_over_npn, const SSLConfig& ssl_config_for_origin, const SSLConfig& ssl_config_for_proxy, @@ -88,12 +88,8 @@ int InitSocketPoolHelper(const GURL& request_url, scoped_refptr<SOCKSSocketParams> socks_params; scoped_ptr<HostPortPair> proxy_host_port; - bool using_ssl = request_url.SchemeIs("https") || - request_url.SchemeIs("wss") || force_spdy_over_ssl; - - HostPortPair origin_host_port = - HostPortPair(request_url.HostNoBrackets(), - request_url.EffectiveIntPort()); + bool using_ssl = group_type == ClientSocketPoolManager::SSL_GROUP; + HostPortPair origin_host_port = endpoint; if (!using_ssl && session->params().testing_fixed_http_port != 0) { origin_host_port.set_port(session->params().testing_fixed_http_port); @@ -114,7 +110,7 @@ int InitSocketPoolHelper(const GURL& request_url, // Determine the host and port to connect to. std::string connection_group = origin_host_port.ToString(); DCHECK(!connection_group.empty()); - if (request_url.SchemeIs("ftp")) { + if (group_type == ClientSocketPoolManager::FTP_GROUP) { // Combining FTP with forced SPDY over SSL would be a "path to madness". // Make sure we never do that. DCHECK(!using_ssl); @@ -131,7 +127,8 @@ int InitSocketPoolHelper(const GURL& request_url, // should be the same for all connections, whereas version_max may // change for version fallbacks. std::string prefix = "ssl/"; - if (ssl_config_for_origin.version_max != kDefaultSSLVersionMax) { + if (ssl_config_for_origin.version_max != + SSLClientSocket::GetMaxSupportedSSLVersion()) { switch (ssl_config_for_origin.version_max) { case SSL_PROTOCOL_VERSION_TLS1_2: prefix = "ssl(max:3.3)/"; @@ -142,14 +139,15 @@ int InitSocketPoolHelper(const GURL& request_url, case SSL_PROTOCOL_VERSION_TLS1: prefix = "ssl(max:3.1)/"; break; - case SSL_PROTOCOL_VERSION_SSL3: - prefix = "sslv3/"; - break; default: CHECK(false); break; } } + // Place sockets with and without deprecated ciphers into separate + // connection groups. + if (ssl_config_for_origin.enable_deprecated_cipher_suites) + prefix += "deprecatedciphers/"; connection_group = prefix + connection_group; } @@ -191,7 +189,6 @@ int InitSocketPoolHelper(const GURL& request_url, ssl_config_for_proxy, PRIVACY_MODE_DISABLED, load_flags, - force_spdy_over_ssl, want_spdy_over_npn); proxy_tcp_params = NULL; } @@ -199,7 +196,6 @@ int InitSocketPoolHelper(const GURL& request_url, http_proxy_params = new HttpProxySocketParams(proxy_tcp_params, ssl_params, - request_url, user_agent, origin_host_port, session->http_auth_cache(), @@ -252,7 +248,6 @@ int InitSocketPoolHelper(const GURL& request_url, ssl_config_for_origin, privacy_mode, load_flags, - force_spdy_over_ssl, want_spdy_over_npn); SSLClientSocketPool* ssl_pool = NULL; if (proxy_info.is_direct()) { @@ -392,13 +387,13 @@ void ClientSocketPoolManager::set_max_sockets_per_proxy_server( } int InitSocketHandleForHttpRequest( - const GURL& request_url, + ClientSocketPoolManager::SocketGroupType group_type, + const HostPortPair& endpoint, const HttpRequestHeaders& request_extra_headers, int request_load_flags, RequestPriority request_priority, HttpNetworkSession* session, const ProxyInfo& proxy_info, - bool force_spdy_over_ssl, bool want_spdy_over_npn, const SSLConfig& ssl_config_for_origin, const SSLConfig& ssl_config_for_proxy, @@ -409,21 +404,21 @@ int InitSocketHandleForHttpRequest( const CompletionCallback& callback) { DCHECK(socket_handle); return InitSocketPoolHelper( - request_url, request_extra_headers, request_load_flags, request_priority, - session, proxy_info, force_spdy_over_ssl, want_spdy_over_npn, - ssl_config_for_origin, ssl_config_for_proxy, false, privacy_mode, net_log, - 0, socket_handle, HttpNetworkSession::NORMAL_SOCKET_POOL, - resolution_callback, callback); + group_type, endpoint, request_extra_headers, request_load_flags, + request_priority, session, proxy_info, want_spdy_over_npn, + ssl_config_for_origin, ssl_config_for_proxy, /*force_tunnel=*/false, + privacy_mode, net_log, 0, socket_handle, + HttpNetworkSession::NORMAL_SOCKET_POOL, resolution_callback, callback); } int InitSocketHandleForWebSocketRequest( - const GURL& request_url, + ClientSocketPoolManager::SocketGroupType group_type, + const HostPortPair& endpoint, const HttpRequestHeaders& request_extra_headers, int request_load_flags, RequestPriority request_priority, HttpNetworkSession* session, const ProxyInfo& proxy_info, - bool force_spdy_over_ssl, bool want_spdy_over_npn, const SSLConfig& ssl_config_for_origin, const SSLConfig& ssl_config_for_proxy, @@ -434,11 +429,11 @@ int InitSocketHandleForWebSocketRequest( const CompletionCallback& callback) { DCHECK(socket_handle); return InitSocketPoolHelper( - request_url, request_extra_headers, request_load_flags, request_priority, - session, proxy_info, force_spdy_over_ssl, want_spdy_over_npn, - ssl_config_for_origin, ssl_config_for_proxy, true, privacy_mode, net_log, - 0, socket_handle, HttpNetworkSession::WEBSOCKET_SOCKET_POOL, - resolution_callback, callback); + group_type, endpoint, request_extra_headers, request_load_flags, + request_priority, session, proxy_info, want_spdy_over_npn, + ssl_config_for_origin, ssl_config_for_proxy, /*force_tunnel=*/true, + privacy_mode, net_log, 0, socket_handle, + HttpNetworkSession::WEBSOCKET_SOCKET_POOL, resolution_callback, callback); } int InitSocketHandleForRawConnect( @@ -452,53 +447,48 @@ int InitSocketHandleForRawConnect( ClientSocketHandle* socket_handle, const CompletionCallback& callback) { DCHECK(socket_handle); - // Synthesize an HttpRequestInfo. - GURL request_url = GURL("http://" + host_port_pair.ToString()); HttpRequestHeaders request_extra_headers; int request_load_flags = 0; RequestPriority request_priority = MEDIUM; - return InitSocketPoolHelper( - request_url, request_extra_headers, request_load_flags, request_priority, - session, proxy_info, false, false, ssl_config_for_origin, - ssl_config_for_proxy, true, privacy_mode, net_log, 0, socket_handle, + ClientSocketPoolManager::NORMAL_GROUP, host_port_pair, + request_extra_headers, request_load_flags, request_priority, session, + proxy_info, false, ssl_config_for_origin, ssl_config_for_proxy, + /*force_tunnel=*/true, privacy_mode, net_log, 0, socket_handle, HttpNetworkSession::NORMAL_SOCKET_POOL, OnHostResolutionCallback(), callback); } -int InitSocketHandleForTlsConnect( - const HostPortPair& host_port_pair, - HttpNetworkSession* session, - const ProxyInfo& proxy_info, - const SSLConfig& ssl_config_for_origin, - const SSLConfig& ssl_config_for_proxy, - PrivacyMode privacy_mode, - const BoundNetLog& net_log, - ClientSocketHandle* socket_handle, - const CompletionCallback& callback) { +int InitSocketHandleForTlsConnect(const HostPortPair& endpoint, + HttpNetworkSession* session, + const ProxyInfo& proxy_info, + const SSLConfig& ssl_config_for_origin, + const SSLConfig& ssl_config_for_proxy, + PrivacyMode privacy_mode, + const BoundNetLog& net_log, + ClientSocketHandle* socket_handle, + const CompletionCallback& callback) { DCHECK(socket_handle); - // Synthesize an HttpRequestInfo. - GURL request_url = GURL("https://" + host_port_pair.ToString()); HttpRequestHeaders request_extra_headers; int request_load_flags = 0; RequestPriority request_priority = MEDIUM; - return InitSocketPoolHelper( - request_url, request_extra_headers, request_load_flags, request_priority, - session, proxy_info, false, false, ssl_config_for_origin, - ssl_config_for_proxy, true, privacy_mode, net_log, 0, socket_handle, + ClientSocketPoolManager::SSL_GROUP, endpoint, request_extra_headers, + request_load_flags, request_priority, session, proxy_info, + /*want_spdy_over_npn=*/false, ssl_config_for_origin, ssl_config_for_proxy, + /*force_tunnel=*/true, privacy_mode, net_log, 0, socket_handle, HttpNetworkSession::NORMAL_SOCKET_POOL, OnHostResolutionCallback(), callback); } int PreconnectSocketsForHttpRequest( - const GURL& request_url, + ClientSocketPoolManager::SocketGroupType group_type, + const HostPortPair& endpoint, const HttpRequestHeaders& request_extra_headers, int request_load_flags, RequestPriority request_priority, HttpNetworkSession* session, const ProxyInfo& proxy_info, - bool force_spdy_over_ssl, bool want_spdy_over_npn, const SSLConfig& ssl_config_for_origin, const SSLConfig& ssl_config_for_proxy, @@ -506,11 +496,12 @@ int PreconnectSocketsForHttpRequest( const BoundNetLog& net_log, int num_preconnect_streams) { return InitSocketPoolHelper( - request_url, request_extra_headers, request_load_flags, request_priority, - session, proxy_info, force_spdy_over_ssl, want_spdy_over_npn, - ssl_config_for_origin, ssl_config_for_proxy, false, privacy_mode, net_log, - num_preconnect_streams, NULL, HttpNetworkSession::NORMAL_SOCKET_POOL, - OnHostResolutionCallback(), CompletionCallback()); + group_type, endpoint, request_extra_headers, request_load_flags, + request_priority, session, proxy_info, want_spdy_over_npn, + ssl_config_for_origin, ssl_config_for_proxy, /*force_tunnel=*/false, + privacy_mode, net_log, num_preconnect_streams, NULL, + HttpNetworkSession::NORMAL_SOCKET_POOL, OnHostResolutionCallback(), + CompletionCallback()); } } // namespace net diff --git a/chromium/net/socket/client_socket_pool_manager.h b/chromium/net/socket/client_socket_pool_manager.h index 12154809870..69d22f5a197 100644 --- a/chromium/net/socket/client_socket_pool_manager.h +++ b/chromium/net/socket/client_socket_pool_manager.h @@ -44,6 +44,12 @@ enum DefaultMaxValues { kDefaultMaxSocketsPerProxyServer = 32 }; class NET_EXPORT_PRIVATE ClientSocketPoolManager { public: + enum SocketGroupType { + SSL_GROUP, // For all TLS sockets. + NORMAL_GROUP, // For normal HTTP sockets. + FTP_GROUP // For FTP sockets (over an HTTP proxy). + }; + ClientSocketPoolManager(); virtual ~ClientSocketPoolManager(); @@ -89,14 +95,16 @@ class NET_EXPORT_PRIVATE ClientSocketPoolManager { // |resolution_callback| will be invoked after the the hostname is // resolved. If |resolution_callback| does not return OK, then the // connection will be aborted with that value. +// If |want_spdy_over_ssl| is true, then after the SSL handshake is complete, +// SPDY must have been negotiated or else it will be considered an error. int InitSocketHandleForHttpRequest( - const GURL& request_url, + ClientSocketPoolManager::SocketGroupType group_type, + const HostPortPair& endpoint, const HttpRequestHeaders& request_extra_headers, int request_load_flags, RequestPriority request_priority, HttpNetworkSession* session, const ProxyInfo& proxy_info, - bool force_spdy_over_ssl, bool want_spdy_over_npn, const SSLConfig& ssl_config_for_origin, const SSLConfig& ssl_config_for_proxy, @@ -116,13 +124,13 @@ int InitSocketHandleForHttpRequest( // connection will be aborted with that value. // This function uses WEBSOCKET_SOCKET_POOL socket pools. int InitSocketHandleForWebSocketRequest( - const GURL& request_url, + ClientSocketPoolManager::SocketGroupType group_type, + const HostPortPair& endpoint, const HttpRequestHeaders& request_extra_headers, int request_load_flags, RequestPriority request_priority, HttpNetworkSession* session, const ProxyInfo& proxy_info, - bool force_spdy_over_ssl, bool want_spdy_over_npn, const SSLConfig& ssl_config_for_origin, const SSLConfig& ssl_config_for_proxy, @@ -165,13 +173,13 @@ NET_EXPORT int InitSocketHandleForTlsConnect( // Similar to InitSocketHandleForHttpRequest except that it initiates the // desired number of preconnect streams from the relevant socket pool. int PreconnectSocketsForHttpRequest( - const GURL& request_url, + ClientSocketPoolManager::SocketGroupType group_type, + const HostPortPair& endpoint, const HttpRequestHeaders& request_extra_headers, int request_load_flags, RequestPriority request_priority, HttpNetworkSession* session, const ProxyInfo& proxy_info, - bool force_spdy_over_ssl, bool want_spdy_over_npn, const SSLConfig& ssl_config_for_origin, const SSLConfig& ssl_config_for_proxy, diff --git a/chromium/net/socket/client_socket_pool_manager_impl.cc b/chromium/net/socket/client_socket_pool_manager_impl.cc index 5ed31fc9a25..c43678d6eb3 100644 --- a/chromium/net/socket/client_socket_pool_manager_impl.cc +++ b/chromium/net/socket/client_socket_pool_manager_impl.cc @@ -42,11 +42,9 @@ ClientSocketPoolManagerImpl::ClientSocketPoolManagerImpl( ChannelIDService* channel_id_service, TransportSecurityState* transport_security_state, CTVerifier* cert_transparency_verifier, + CertPolicyEnforcer* cert_policy_enforcer, const std::string& ssl_session_cache_shard, - ProxyService* proxy_service, SSLConfigService* ssl_config_service, - bool enable_ssl_connect_job_waiting, - ProxyDelegate* proxy_delegate, HttpNetworkSession::SocketPoolType pool_type) : net_log_(net_log), socket_factory_(socket_factory), @@ -55,52 +53,37 @@ ClientSocketPoolManagerImpl::ClientSocketPoolManagerImpl( channel_id_service_(channel_id_service), transport_security_state_(transport_security_state), cert_transparency_verifier_(cert_transparency_verifier), + cert_policy_enforcer_(cert_policy_enforcer), ssl_session_cache_shard_(ssl_session_cache_shard), - proxy_service_(proxy_service), ssl_config_service_(ssl_config_service), - enable_ssl_connect_job_waiting_(enable_ssl_connect_job_waiting), pool_type_(pool_type), - transport_pool_histograms_("TCP"), transport_socket_pool_( pool_type == HttpNetworkSession::WEBSOCKET_SOCKET_POOL ? new WebSocketTransportClientSocketPool( max_sockets_per_pool(pool_type), max_sockets_per_group(pool_type), - &transport_pool_histograms_, host_resolver, socket_factory_, net_log) : new TransportClientSocketPool(max_sockets_per_pool(pool_type), max_sockets_per_group(pool_type), - &transport_pool_histograms_, host_resolver, socket_factory_, net_log)), - ssl_pool_histograms_("SSL2"), ssl_socket_pool_(new SSLClientSocketPool(max_sockets_per_pool(pool_type), max_sockets_per_group(pool_type), - &ssl_pool_histograms_, - host_resolver, cert_verifier, channel_id_service, transport_security_state, cert_transparency_verifier, + cert_policy_enforcer, ssl_session_cache_shard, socket_factory, transport_socket_pool_.get(), NULL /* no socks proxy */, NULL /* no http proxy */, ssl_config_service, - enable_ssl_connect_job_waiting, - net_log)), - transport_for_socks_pool_histograms_("TCPforSOCKS"), - socks_pool_histograms_("SOCK"), - transport_for_http_proxy_pool_histograms_("TCPforHTTPProxy"), - transport_for_https_proxy_pool_histograms_("TCPforHTTPSProxy"), - ssl_for_https_proxy_pool_histograms_("SSLforHTTPSProxy"), - http_proxy_pool_histograms_("HTTPProxy"), - ssl_socket_pool_for_proxies_histograms_("SSLForProxies"), - proxy_delegate_(proxy_delegate) { + net_log)) { CertDatabase::GetInstance()->AddObserver(this); } @@ -233,7 +216,6 @@ SOCKSClientSocketPool* ClientSocketPoolManagerImpl::GetSocketPoolForSOCKSProxy( new TransportClientSocketPool( max_sockets_per_proxy_server(pool_type_), max_sockets_per_group(pool_type_), - &transport_for_socks_pool_histograms_, host_resolver_, socket_factory_, net_log_))); @@ -244,7 +226,6 @@ SOCKSClientSocketPool* ClientSocketPoolManagerImpl::GetSocketPoolForSOCKSProxy( std::make_pair(socks_proxy, new SOCKSClientSocketPool( max_sockets_per_proxy_server(pool_type_), max_sockets_per_group(pool_type_), - &socks_pool_histograms_, host_resolver_, tcp_ret.first->second, net_log_))); @@ -275,7 +256,6 @@ ClientSocketPoolManagerImpl::GetSocketPoolForHTTPProxy( new TransportClientSocketPool( max_sockets_per_proxy_server(pool_type_), max_sockets_per_group(pool_type_), - &transport_for_http_proxy_pool_histograms_, host_resolver_, socket_factory_, net_log_))); @@ -288,7 +268,6 @@ ClientSocketPoolManagerImpl::GetSocketPoolForHTTPProxy( new TransportClientSocketPool( max_sockets_per_proxy_server(pool_type_), max_sockets_per_group(pool_type_), - &transport_for_https_proxy_pool_histograms_, host_resolver_, socket_factory_, net_log_))); @@ -296,23 +275,15 @@ ClientSocketPoolManagerImpl::GetSocketPoolForHTTPProxy( std::pair<SSLSocketPoolMap::iterator, bool> ssl_https_ret = ssl_socket_pools_for_https_proxies_.insert(std::make_pair( - http_proxy, - new SSLClientSocketPool(max_sockets_per_proxy_server(pool_type_), - max_sockets_per_group(pool_type_), - &ssl_for_https_proxy_pool_histograms_, - host_resolver_, - cert_verifier_, - channel_id_service_, - transport_security_state_, - cert_transparency_verifier_, - ssl_session_cache_shard_, - socket_factory_, - tcp_https_ret.first->second /* https proxy */, - NULL /* no socks proxy */, - NULL /* no http proxy */, - ssl_config_service_.get(), - enable_ssl_connect_job_waiting_, - net_log_))); + http_proxy, new SSLClientSocketPool( + max_sockets_per_proxy_server(pool_type_), + max_sockets_per_group(pool_type_), cert_verifier_, + channel_id_service_, transport_security_state_, + cert_transparency_verifier_, cert_policy_enforcer_, + ssl_session_cache_shard_, socket_factory_, + tcp_https_ret.first->second /* https proxy */, + NULL /* no socks proxy */, NULL /* no http proxy */, + ssl_config_service_.get(), net_log_))); DCHECK(tcp_https_ret.second); std::pair<HTTPProxySocketPoolMap::iterator, bool> ret = @@ -322,11 +293,8 @@ ClientSocketPoolManagerImpl::GetSocketPoolForHTTPProxy( new HttpProxyClientSocketPool( max_sockets_per_proxy_server(pool_type_), max_sockets_per_group(pool_type_), - &http_proxy_pool_histograms_, - host_resolver_, tcp_http_ret.first->second, ssl_https_ret.first->second, - proxy_delegate_, net_log_))); return ret.first->second; @@ -341,20 +309,12 @@ SSLClientSocketPool* ClientSocketPoolManagerImpl::GetSocketPoolForSSLWithProxy( SSLClientSocketPool* new_pool = new SSLClientSocketPool( max_sockets_per_proxy_server(pool_type_), - max_sockets_per_group(pool_type_), - &ssl_pool_histograms_, - host_resolver_, - cert_verifier_, - channel_id_service_, - transport_security_state_, - cert_transparency_verifier_, - ssl_session_cache_shard_, - socket_factory_, + max_sockets_per_group(pool_type_), cert_verifier_, channel_id_service_, + transport_security_state_, cert_transparency_verifier_, + cert_policy_enforcer_, ssl_session_cache_shard_, socket_factory_, NULL, /* no tcp pool, we always go through a proxy */ GetSocketPoolForSOCKSProxy(proxy_server), - GetSocketPoolForHTTPProxy(proxy_server), - ssl_config_service_.get(), - enable_ssl_connect_job_waiting_, + GetSocketPoolForHTTPProxy(proxy_server), ssl_config_service_.get(), net_log_); std::pair<SSLSocketPoolMap::iterator, bool> ret = diff --git a/chromium/net/socket/client_socket_pool_manager_impl.h b/chromium/net/socket/client_socket_pool_manager_impl.h index f9f8d3b8f15..a38cae1a476 100644 --- a/chromium/net/socket/client_socket_pool_manager_impl.h +++ b/chromium/net/socket/client_socket_pool_manager_impl.h @@ -15,7 +15,6 @@ #include "base/threading/non_thread_safe.h" #include "net/cert/cert_database.h" #include "net/http/http_network_session.h" -#include "net/socket/client_socket_pool_histograms.h" #include "net/socket/client_socket_pool_manager.h" namespace net { @@ -23,13 +22,10 @@ namespace net { class CertVerifier; class ChannelIDService; class ClientSocketFactory; -class ClientSocketPoolHistograms; class CTVerifier; class HttpProxyClientSocketPool; class HostResolver; class NetLog; -class ProxyDelegate; -class ProxyService; class SOCKSClientSocketPool; class SSLClientSocketPool; class SSLConfigService; @@ -43,8 +39,7 @@ template <typename Key, typename Value> class OwnedPoolMap : public std::map<Key, Value> { public: OwnedPoolMap() { - COMPILE_ASSERT(base::is_pointer<Value>::value, - value_must_be_a_pointer); + static_assert(base::is_pointer<Value>::value, "value must be a pointer"); } ~OwnedPoolMap() { @@ -65,11 +60,9 @@ class ClientSocketPoolManagerImpl : public base::NonThreadSafe, ChannelIDService* channel_id_service, TransportSecurityState* transport_security_state, CTVerifier* cert_transparency_verifier, + CertPolicyEnforcer* cert_policy_enforcer, const std::string& ssl_session_cache_shard, - ProxyService* proxy_service, SSLConfigService* ssl_config_service, - bool enable_ssl_connect_job_waiting, - ProxyDelegate* proxy_delegate, HttpNetworkSession::SocketPoolType pool_type); ~ClientSocketPoolManagerImpl() override; @@ -114,43 +107,23 @@ class ClientSocketPoolManagerImpl : public base::NonThreadSafe, ChannelIDService* const channel_id_service_; TransportSecurityState* const transport_security_state_; CTVerifier* const cert_transparency_verifier_; + CertPolicyEnforcer* const cert_policy_enforcer_; const std::string ssl_session_cache_shard_; - ProxyService* const proxy_service_; const scoped_refptr<SSLConfigService> ssl_config_service_; - bool enable_ssl_connect_job_waiting_; const HttpNetworkSession::SocketPoolType pool_type_; // Note: this ordering is important. - ClientSocketPoolHistograms transport_pool_histograms_; scoped_ptr<TransportClientSocketPool> transport_socket_pool_; - - ClientSocketPoolHistograms ssl_pool_histograms_; scoped_ptr<SSLClientSocketPool> ssl_socket_pool_; - - ClientSocketPoolHistograms transport_for_socks_pool_histograms_; TransportSocketPoolMap transport_socket_pools_for_socks_proxies_; - - ClientSocketPoolHistograms socks_pool_histograms_; SOCKSSocketPoolMap socks_socket_pools_; - - ClientSocketPoolHistograms transport_for_http_proxy_pool_histograms_; TransportSocketPoolMap transport_socket_pools_for_http_proxies_; - - ClientSocketPoolHistograms transport_for_https_proxy_pool_histograms_; TransportSocketPoolMap transport_socket_pools_for_https_proxies_; - - ClientSocketPoolHistograms ssl_for_https_proxy_pool_histograms_; SSLSocketPoolMap ssl_socket_pools_for_https_proxies_; - - ClientSocketPoolHistograms http_proxy_pool_histograms_; HTTPProxySocketPoolMap http_proxy_socket_pools_; - - ClientSocketPoolHistograms ssl_socket_pool_for_proxies_histograms_; SSLSocketPoolMap ssl_socket_pools_for_proxies_; - const ProxyDelegate* proxy_delegate_; - DISALLOW_COPY_AND_ASSIGN(ClientSocketPoolManagerImpl); }; diff --git a/chromium/net/socket/connection_attempts.h b/chromium/net/socket/connection_attempts.h new file mode 100644 index 00000000000..185defa221f --- /dev/null +++ b/chromium/net/socket/connection_attempts.h @@ -0,0 +1,31 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_CONNECTION_ATTEMPTS_H_ +#define NET_SOCKET_CONNECTION_ATTEMPTS_H_ + +#include "net/base/ip_endpoint.h" + +namespace net { + +// A record of an connection attempt made to connect to a host. Includes TCP +// and SSL errors, but not proxy connections. +struct ConnectionAttempt { + ConnectionAttempt(const IPEndPoint endpoint, int result) + : endpoint(endpoint), result(result) {} + + // Address and port the socket layer attempted to connect to. + IPEndPoint endpoint; + + // Net error indicating the result of that attempt. + int result; +}; + +// Multiple connection attempts, as might be tracked in an HttpTransaction or a +// URLRequest. Order is insignificant. +typedef std::vector<ConnectionAttempt> ConnectionAttempts; + +} // namespace net + +#endif // NET_SOCKET_CONNECTION_ATTEMPTS_H_ diff --git a/chromium/net/socket/deterministic_socket_data_unittest.cc b/chromium/net/socket/deterministic_socket_data_unittest.cc index bdeba2bef55..9a95e1e17b0 100644 --- a/chromium/net/socket/deterministic_socket_data_unittest.cc +++ b/chromium/net/socket/deterministic_socket_data_unittest.cc @@ -58,7 +58,6 @@ class DeterministicSocketDataTest : public PlatformTest { HostPortPair endpoint_; scoped_refptr<TransportSocketParams> tcp_params_; - ClientSocketPoolHistograms histograms_; DeterministicMockClientSocketFactory socket_factory_; MockTransportClientSocketPool socket_pool_; ClientSocketHandle connection_; @@ -72,13 +71,13 @@ DeterministicSocketDataTest::DeterministicSocketDataTest() connect_data_(SYNCHRONOUS, OK), endpoint_("www.google.com", 443), tcp_params_(new TransportSocketParams( - endpoint_, - false, - false, - OnHostResolutionCallback(), - TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), - histograms_(std::string()), - socket_pool_(10, 10, &histograms_, &socket_factory_) {} + endpoint_, + false, + false, + OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), + socket_pool_(10, 10, &socket_factory_) { +} void DeterministicSocketDataTest::TearDown() { // Empty the current queue. diff --git a/chromium/net/socket/next_proto.cc b/chromium/net/socket/next_proto.cc index 1dcfb5d58f3..cfc6578284a 100644 --- a/chromium/net/socket/next_proto.cc +++ b/chromium/net/socket/next_proto.cc @@ -2,20 +2,16 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "next_proto.h" +#include "net/socket/next_proto.h" namespace net { -NextProtoVector NextProtosHttpOnly() { - NextProtoVector next_protos; - next_protos.push_back(kProtoHTTP11); - return next_protos; -} - NextProtoVector NextProtosDefaults() { NextProtoVector next_protos; next_protos.push_back(kProtoHTTP11); next_protos.push_back(kProtoSPDY31); + next_protos.push_back(kProtoSPDY4_14); + next_protos.push_back(kProtoSPDY4); return next_protos; } @@ -27,6 +23,8 @@ NextProtoVector NextProtosWithSpdyAndQuic(bool spdy_enabled, next_protos.push_back(kProtoQUIC1SPDY3); if (spdy_enabled) { next_protos.push_back(kProtoSPDY31); + next_protos.push_back(kProtoSPDY4_14); + next_protos.push_back(kProtoSPDY4); } return next_protos; } @@ -39,13 +37,9 @@ NextProtoVector NextProtosSpdy31() { return next_protos; } -NextProtoVector NextProtosSpdy4Http2() { - NextProtoVector next_protos; - next_protos.push_back(kProtoHTTP11); - next_protos.push_back(kProtoQUIC1SPDY3); - next_protos.push_back(kProtoSPDY31); - next_protos.push_back(kProtoSPDY4); - return next_protos; +bool NextProtoIsSPDY(NextProto next_proto) { + return next_proto >= kProtoSPDYMinimumVersion && + next_proto <= kProtoSPDYMaximumVersion; } } // namespace net diff --git a/chromium/net/socket/next_proto.h b/chromium/net/socket/next_proto.h index 4df6e9b9cd5..72ee0bb9888 100644 --- a/chromium/net/socket/next_proto.h +++ b/chromium/net/socket/next_proto.h @@ -25,10 +25,17 @@ enum NextProto { kProtoDeprecatedSPDY2 = 100, kProtoSPDYMinimumVersion = kProtoDeprecatedSPDY2, + kProtoSPDYHistogramOffset = kProtoDeprecatedSPDY2, kProtoSPDY3 = 101, kProtoSPDY31 = 102, - kProtoSPDY4 = 103, // SPDY4 is HTTP/2. - kProtoSPDYMaximumVersion = kProtoSPDY4, + kProtoSPDY4_14 = 103, // HTTP/2 draft-14, designated implementation draft. + kProtoSPDY4MinimumVersion = kProtoSPDY4_14, + // kProtoSPDY4_15 = 104, // HTTP/2 draft-15 + // kProtoSPDY4_16 = 105, // HTTP/2 draft-16 + // kProtoSPDY4_17 = 106, // HTTP/2 draft-17 + kProtoSPDY4 = 107, // HTTP/2. TODO(bnc): Add RFC number when published. + kProtoSPDY4MaximumVersion = kProtoSPDY4, + kProtoSPDYMaximumVersion = kProtoSPDY4MaximumVersion, kProtoQUIC1SPDY3 = 200, @@ -40,18 +47,18 @@ typedef std::vector<NextProto> NextProtoVector; // Convenience functions to create NextProtoVector. -NET_EXPORT NextProtoVector NextProtosHttpOnly(); - -// Default values, which are subject to change over time. Currently just -// SPDY 3 and 3.1. +// Default values, which are subject to change over time. NET_EXPORT NextProtoVector NextProtosDefaults(); +// Enable SPDY/3.1 and QUIC, but not HTTP/2. +NET_EXPORT NextProtoVector NextProtosSpdy31(); + +// Control SPDY/3.1 and HTTP/2 separately. NET_EXPORT NextProtoVector NextProtosWithSpdyAndQuic(bool spdy_enabled, bool quic_enabled); -// All of these also enable QUIC. -NET_EXPORT NextProtoVector NextProtosSpdy31(); -NET_EXPORT NextProtoVector NextProtosSpdy4Http2(); +// Returns true if |next_proto| is a version of SPDY or HTTP/2. +bool NextProtoIsSPDY(NextProto next_proto); } // namespace net diff --git a/chromium/net/socket/nss_ssl_util.cc b/chromium/net/socket/nss_ssl_util.cc index a238a25d2d4..e98193b5533 100644 --- a/chromium/net/socket/nss_ssl_util.cc +++ b/chromium/net/socket/nss_ssl_util.cc @@ -22,8 +22,8 @@ #include "build/build_config.h" #include "crypto/nss_util.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" #include "net/base/nss_memio.h" +#include "net/log/net_log.h" #if defined(OS_WIN) #include "base/win/windows_version.h" @@ -81,7 +81,7 @@ size_t CiphersCopy(const uint16* in, uint16* out) { base::Value* NetLogSSLErrorCallback(int net_error, int ssl_lib_error, - NetLog::LogLevel /* log_level */) { + NetLogCaptureMode /* capture_mode */) { base::DictionaryValue* dict = new base::DictionaryValue(); dict->SetInteger("net_error", net_error); if (ssl_lib_error) @@ -108,7 +108,7 @@ class NSSSSLInitSingleton { disableECDSA = true; #endif - // Explicitly enable exactly those ciphers with keys of at least 80 bits + // Explicitly enable exactly those ciphers with keys of at least 80 bits. for (int i = 0; i < num_ciphers; i++) { SSLCipherSuiteInfo info; if (SSL_GetCipherSuiteInfo(ssl_ciphers[i], &info, @@ -130,10 +130,6 @@ class NSSSSLInitSingleton { enabled = false; } - if (ssl_ciphers[i] == TLS_DHE_DSS_WITH_AES_128_CBC_SHA) { - // Enabled to allow servers with only a DSA certificate to function. - enabled = true; - } SSL_CipherPrefSetDefault(ssl_ciphers[i], enabled); } } @@ -389,7 +385,7 @@ base::Value* NetLogSSLFailedNSSFunctionCallback( const char* function, const char* param, int ssl_lib_error, - NetLog::LogLevel /* log_level */) { + NetLogCaptureMode /* capture_mode */) { base::DictionaryValue* dict = new base::DictionaryValue(); dict->SetString("function", function); if (param[0] != '\0') diff --git a/chromium/net/socket/nss_ssl_util.h b/chromium/net/socket/nss_ssl_util.h index 7b046ffd282..5d9ec7e04e8 100644 --- a/chromium/net/socket/nss_ssl_util.h +++ b/chromium/net/socket/nss_ssl_util.h @@ -12,7 +12,7 @@ #include <prio.h> #include "net/base/net_export.h" -#include "net/base/net_log.h" +#include "net/log/net_log.h" namespace net { diff --git a/chromium/net/socket/sequenced_socket_data_unittest.cc b/chromium/net/socket/sequenced_socket_data_unittest.cc new file mode 100644 index 00000000000..e0ed3201d10 --- /dev/null +++ b/chromium/net/socket/sequenced_socket_data_unittest.cc @@ -0,0 +1,1075 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <string> + +#include "base/macros.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/run_loop.h" +#include "net/base/io_buffer.h" +#include "net/base/test_completion_callback.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/socket_test_util.h" +#include "net/socket/transport_client_socket_pool.h" +#include "testing/gtest/include/gtest/gtest-spi.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +//----------------------------------------------------------------------------- + +namespace net { + +namespace { + +const char kMsg1[] = "\0hello!\xff"; +const int kLen1 = arraysize(kMsg1); +const char kMsg2[] = "\0a2345678\0"; +const int kLen2 = arraysize(kMsg2); +const char kMsg3[] = "bye!"; +const int kLen3 = arraysize(kMsg3); +const char kMsg4[] = "supercalifragilisticexpialidocious"; +const int kLen4 = arraysize(kMsg4); + +// Helper class for starting the next operation operation reentrantly after the +// previous operation completed asynchronously. When OnIOComplete is called, +// it will first verify that the previous operation behaved as expected. This is +// specified by either SetExpectedRead or SetExpectedWrite. It will then invoke +// a read or write operation specified by SetInvokeRead or SetInvokeWrite. +class ReentrantHelper { + public: + ReentrantHelper(StreamSocket* socket) + : socket_(socket), + verify_read_(false), + first_read_data_(nullptr), + first_len_(-1), + second_read_(false), + second_write_data_(nullptr), + second_len_(-1) {} + + // Expect that the previous operation will return |first_len| and will fill + // |first_read_data_| with |first_read_data|. + void SetExpectedRead(const char* first_read_data, int first_len) { + verify_read_ = true; + first_read_buf_ = new IOBuffer(first_len); + first_read_data_ = first_read_data; + first_len_ = first_len; + } + + // Expect that the previous operation will return |first_len|. + void SetExpectedWrite(int first_len) { + verify_read_ = false; + first_len_ = first_len; + } + + // After verifying expectations, invoke a read of |read_len| bytes into + // |read_buf|, notifying |callback| when complete. + void SetInvokeRead(scoped_refptr<IOBuffer> read_buf, + int read_len, + int second_rv, + CompletionCallback callback) { + second_read_ = true; + second_read_buf_ = read_buf; + second_rv_ = second_rv; + second_callback_ = callback; + second_len_ = read_len; + } + + // After verifying expectations, invoke a write of |write_len| bytes from + // |write_data|, notifying |callback| when complete. + void SetInvokeWrite(const char* write_data, + int write_len, + int second_rv, + CompletionCallback callback) { + second_read_ = false; + second_rv_ = second_rv; + second_write_data_ = write_data; + second_callback_ = callback; + second_len_ = write_len; + } + + // Returns the OnIOComplete callback for this helper. + CompletionCallback callback() { + return base::Bind(&ReentrantHelper::OnIOComplete, base::Unretained(this)); + } + + // Retuns the buffer where data is expected to have been written, + // when checked by SetExpectRead() + scoped_refptr<IOBuffer> read_buf() { return first_read_buf_; } + + private: + void OnIOComplete(int rv) { + CHECK_NE(-1, first_len_) << "Expectation not set."; + CHECK_NE(-1, second_len_) << "Invocation not set."; + ASSERT_EQ(first_len_, rv); + if (verify_read_) { + ASSERT_EQ(std::string(first_read_data_, first_len_), + std::string(first_read_buf_->data(), rv)); + } + + if (second_read_) { + ASSERT_EQ(second_rv_, socket_->Read(second_read_buf_.get(), second_len_, + second_callback_)); + } else { + scoped_refptr<IOBuffer> write_buf = new IOBuffer(second_len_); + memcpy(write_buf->data(), second_write_data_, second_len_); + ASSERT_EQ(second_rv_, + socket_->Write(write_buf.get(), second_len_, second_callback_)); + } + } + + StreamSocket* socket_; + + bool verify_read_; + scoped_refptr<IOBuffer> first_read_buf_; + const char* first_read_data_; + int first_len_; + + CompletionCallback second_callback_; + bool second_read_; + int second_rv_; + scoped_refptr<IOBuffer> second_read_buf_; + const char* second_write_data_; + int second_len_; + + DISALLOW_COPY_AND_ASSIGN(ReentrantHelper); +}; + +class SequencedSocketDataTest : public testing::Test { + public: + SequencedSocketDataTest(); + ~SequencedSocketDataTest() override; + + // This method is used as the completion callback for an async read + // operation and when invoked, it verifies that the correct data was read, + // then reads from the socket and verifies that that it returns the correct + // value. + void ReentrantReadCallback(const char* data, + int len1, + int len2, + int expected_rv2, + int rv); + + // This method is used at the completion callback for an async operation. + // When executed, verifies that |rv| equals |expected_rv| and then + // attempts an aync read from the socket into |read_buf_| (initialized + // to |read_buf_len|) using |callback|. + void ReentrantAsyncReadCallback(int len1, int len2, int rv); + + // This method is used as the completion callback for an async write + // operation and when invoked, it verifies that the write returned correctly, + // then + // attempts to write to the socket and verifies that that it returns the + // correct value. + void ReentrantWriteCallback(int expected_rv1, + const char* data, + int len, + int expected_rv2, + int rv); + + // This method is used at the completion callback for an async operation. + // When executed, verifies that |rv| equals |expected_rv| and then + // attempts an aync write of |data| with |callback| + void ReentrantAsyncWriteCallback(const char* data, + int len, + CompletionCallback callback, + int expected_rv, + int rv); + + // Callback which adds a failure if it's ever called. + void FailingCompletionCallback(int rv); + + protected: + void Initialize(MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count); + + void AssertSyncReadEquals(const char* data, int len); + void AssertAsyncReadEquals(const char* data, int len); + void AssertReadReturns(int len, int rv); + void AssertReadBufferEquals(const char* data, int len); + + void AssertSyncWriteEquals(const char* data, int len); + void AssertAsyncWriteEquals(const char* data, int len); + void AssertWriteReturns(const char* data, int len, int rv); + + // When a given test completes, data_.at_eof() is expected to + // match the value specified here. Most test should consume all + // reads and writes, but some tests verify error handling behavior + // do not consume all data. + void set_expect_eof(bool expect_eof) { expect_eof_ = expect_eof; } + + TestCompletionCallback read_callback_; + scoped_refptr<IOBuffer> read_buf_; + TestCompletionCallback write_callback_; + CompletionCallback failing_callback_; + StreamSocket* sock_; + + private: + MockConnect connect_data_; + scoped_ptr<SequencedSocketData> data_; + + const HostPortPair endpoint_; + scoped_refptr<TransportSocketParams> tcp_params_; + MockClientSocketFactory socket_factory_; + MockTransportClientSocketPool socket_pool_; + ClientSocketHandle connection_; + bool expect_eof_; + + DISALLOW_COPY_AND_ASSIGN(SequencedSocketDataTest); +}; + +SequencedSocketDataTest::SequencedSocketDataTest() + : failing_callback_( + base::Bind(&SequencedSocketDataTest::FailingCompletionCallback, + base::Unretained(this))), + sock_(nullptr), + connect_data_(SYNCHRONOUS, OK), + endpoint_("www.google.com", 443), + tcp_params_(new TransportSocketParams( + endpoint_, + false, + false, + OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), + socket_pool_(10, 10, &socket_factory_), + expect_eof_(true) { +} + +SequencedSocketDataTest::~SequencedSocketDataTest() { + // Make sure no unexpected pending tasks will cause a failure. + base::RunLoop().RunUntilIdle(); + if (expect_eof_) { + EXPECT_EQ(expect_eof_, data_->at_read_eof()); + EXPECT_EQ(expect_eof_, data_->at_write_eof()); + } +} + +void SequencedSocketDataTest::Initialize(MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count) { + data_.reset( + new SequencedSocketData(reads, reads_count, writes, writes_count)); + data_->set_connect_data(connect_data_); + socket_factory_.AddSocketDataProvider(data_.get()); + + EXPECT_EQ(OK, + connection_.Init( + endpoint_.ToString(), tcp_params_, LOWEST, CompletionCallback(), + reinterpret_cast<TransportClientSocketPool*>(&socket_pool_), + BoundNetLog())); + sock_ = connection_.socket(); +} + +void SequencedSocketDataTest::AssertSyncReadEquals(const char* data, int len) { + // Issue the read, which will complete immediately. + AssertReadReturns(len, len); + AssertReadBufferEquals(data, len); +} + +void SequencedSocketDataTest::AssertAsyncReadEquals(const char* data, int len) { + // Issue the read, which will be completed asynchronously. + AssertReadReturns(len, ERR_IO_PENDING); + + EXPECT_TRUE(sock_->IsConnected()); + + // Now the read should complete. + ASSERT_EQ(len, read_callback_.WaitForResult()); + AssertReadBufferEquals(data, len); +} + +void SequencedSocketDataTest::AssertReadReturns(int len, int rv) { + read_buf_ = new IOBuffer(len); + if (rv == ERR_IO_PENDING) { + ASSERT_EQ(rv, sock_->Read(read_buf_.get(), len, read_callback_.callback())); + ASSERT_FALSE(read_callback_.have_result()); + } else { + ASSERT_EQ(rv, sock_->Read(read_buf_.get(), len, failing_callback_)); + } +} + +void SequencedSocketDataTest::AssertReadBufferEquals(const char* data, + int len) { + ASSERT_EQ(std::string(data, len), std::string(read_buf_->data(), len)); +} + +void SequencedSocketDataTest::AssertSyncWriteEquals(const char* data, int len) { + // Issue the write, which should be complete immediately. + AssertWriteReturns(data, len, len); + ASSERT_FALSE(write_callback_.have_result()); +} + +void SequencedSocketDataTest::AssertAsyncWriteEquals(const char* data, + int len) { + // Issue the read, which should be completed asynchronously. + AssertWriteReturns(data, len, ERR_IO_PENDING); + + EXPECT_FALSE(read_callback_.have_result()); + EXPECT_TRUE(sock_->IsConnected()); + + ASSERT_EQ(len, write_callback_.WaitForResult()); +} + +void SequencedSocketDataTest::AssertWriteReturns(const char* data, + int len, + int rv) { + scoped_refptr<IOBuffer> buf(new IOBuffer(len)); + memcpy(buf->data(), data, len); + + if (rv == ERR_IO_PENDING) { + ASSERT_EQ(rv, sock_->Write(buf.get(), len, write_callback_.callback())); + ASSERT_FALSE(write_callback_.have_result()); + } else { + ASSERT_EQ(rv, sock_->Write(buf.get(), len, failing_callback_)); + } +} + +void SequencedSocketDataTest::ReentrantReadCallback(const char* data, + int len1, + int len2, + int expected_rv2, + int rv) { + ASSERT_EQ(len1, rv); + AssertReadBufferEquals(data, len1); + + AssertReadReturns(len2, expected_rv2); +} + +void SequencedSocketDataTest::ReentrantAsyncReadCallback(int expected_rv, + int len, + int rv) { + ASSERT_EQ(expected_rv, rv); + + AssertReadReturns(len, ERR_IO_PENDING); +} + +void SequencedSocketDataTest::ReentrantWriteCallback(int expected_rv1, + const char* data, + int len, + int expected_rv2, + int rv) { + ASSERT_EQ(expected_rv1, rv); + + AssertWriteReturns(data, len, expected_rv2); +} + +void SequencedSocketDataTest::ReentrantAsyncWriteCallback( + const char* data, + int len, + CompletionCallback callback, + int expected_rv, + int rv) { + EXPECT_EQ(expected_rv, rv); + scoped_refptr<IOBuffer> write_buf(new IOBuffer(len)); + memcpy(write_buf->data(), data, len); + EXPECT_EQ(ERR_IO_PENDING, sock_->Write(write_buf.get(), len, callback)); +} + +void SequencedSocketDataTest::FailingCompletionCallback(int rv) { + ADD_FAILURE() << "Callback should not have been invoked"; +} + +// ----------- Read + +TEST_F(SequencedSocketDataTest, SingleSyncRead) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), + }; + + Initialize(reads, arraysize(reads), nullptr, 0); + AssertSyncReadEquals(kMsg1, kLen1); +} + +TEST_F(SequencedSocketDataTest, MultipleSyncReads) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), + MockRead(SYNCHRONOUS, kMsg2, kLen2, 1), + MockRead(SYNCHRONOUS, kMsg3, kLen3, 2), + MockRead(SYNCHRONOUS, kMsg3, kLen3, 3), + MockRead(SYNCHRONOUS, kMsg2, kLen2, 4), + MockRead(SYNCHRONOUS, kMsg3, kLen3, 5), + MockRead(SYNCHRONOUS, kMsg1, kLen1, 6), + }; + + Initialize(reads, arraysize(reads), nullptr, 0); + + AssertSyncReadEquals(kMsg1, kLen1); + AssertSyncReadEquals(kMsg2, kLen2); + AssertSyncReadEquals(kMsg3, kLen3); + AssertSyncReadEquals(kMsg3, kLen3); + AssertSyncReadEquals(kMsg2, kLen2); + AssertSyncReadEquals(kMsg3, kLen3); + AssertSyncReadEquals(kMsg1, kLen1); +} + +TEST_F(SequencedSocketDataTest, SingleAsyncRead) { + MockRead reads[] = { + MockRead(ASYNC, kMsg1, kLen1, 0), + }; + + Initialize(reads, arraysize(reads), nullptr, 0); + + AssertAsyncReadEquals(kMsg1, kLen1); +} + +TEST_F(SequencedSocketDataTest, MultipleAsyncReads) { + MockRead reads[] = { + MockRead(ASYNC, kMsg1, kLen1, 0), + MockRead(ASYNC, kMsg2, kLen2, 1), + MockRead(ASYNC, kMsg3, kLen3, 2), + MockRead(ASYNC, kMsg3, kLen3, 3), + MockRead(ASYNC, kMsg2, kLen2, 4), + MockRead(ASYNC, kMsg3, kLen3, 5), + MockRead(ASYNC, kMsg1, kLen1, 6), + }; + + Initialize(reads, arraysize(reads), nullptr, 0); + + AssertAsyncReadEquals(kMsg1, kLen1); + AssertAsyncReadEquals(kMsg2, kLen2); + AssertAsyncReadEquals(kMsg3, kLen3); + AssertAsyncReadEquals(kMsg3, kLen3); + AssertAsyncReadEquals(kMsg2, kLen2); + AssertAsyncReadEquals(kMsg3, kLen3); + AssertAsyncReadEquals(kMsg1, kLen1); +} + +TEST_F(SequencedSocketDataTest, MixedReads) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), + MockRead(ASYNC, kMsg2, kLen2, 1), + MockRead(SYNCHRONOUS, kMsg3, kLen3, 2), + MockRead(ASYNC, kMsg3, kLen3, 3), + MockRead(SYNCHRONOUS, kMsg2, kLen2, 4), + MockRead(ASYNC, kMsg3, kLen3, 5), + MockRead(SYNCHRONOUS, kMsg1, kLen1, 6), + }; + + Initialize(reads, arraysize(reads), nullptr, 0); + + AssertSyncReadEquals(kMsg1, kLen1); + AssertAsyncReadEquals(kMsg2, kLen2); + AssertSyncReadEquals(kMsg3, kLen3); + AssertAsyncReadEquals(kMsg3, kLen3); + AssertSyncReadEquals(kMsg2, kLen2); + AssertAsyncReadEquals(kMsg3, kLen3); + AssertSyncReadEquals(kMsg1, kLen1); +} + +TEST_F(SequencedSocketDataTest, SyncReadFromCompletionCallback) { + MockRead reads[] = { + MockRead(ASYNC, kMsg1, kLen1, 0), MockRead(SYNCHRONOUS, kMsg2, kLen2, 1), + }; + + Initialize(reads, arraysize(reads), nullptr, 0); + + read_buf_ = new IOBuffer(kLen1); + ASSERT_EQ( + ERR_IO_PENDING, + sock_->Read( + read_buf_.get(), kLen1, + base::Bind(&SequencedSocketDataTest::ReentrantReadCallback, + base::Unretained(this), kMsg1, kLen1, kLen2, kLen2))); + + base::MessageLoop::current()->RunUntilIdle(); + AssertReadBufferEquals(kMsg2, kLen2); +} + +TEST_F(SequencedSocketDataTest, ManyReentrantReads) { + MockRead reads[] = { + MockRead(ASYNC, kMsg1, kLen1, 0), + MockRead(ASYNC, kMsg2, kLen2, 1), + MockRead(ASYNC, kMsg3, kLen3, 2), + MockRead(ASYNC, kMsg4, kLen4, 3), + }; + + Initialize(reads, arraysize(reads), nullptr, 0); + + read_buf_ = new IOBuffer(kLen4); + + ReentrantHelper helper3(sock_); + helper3.SetExpectedRead(kMsg3, kLen3); + helper3.SetInvokeRead(read_buf_, kLen4, ERR_IO_PENDING, + read_callback_.callback()); + + ReentrantHelper helper2(sock_); + helper2.SetExpectedRead(kMsg2, kLen2); + helper2.SetInvokeRead(helper3.read_buf(), kLen3, ERR_IO_PENDING, + helper3.callback()); + + ReentrantHelper helper(sock_); + helper.SetExpectedRead(kMsg1, kLen1); + helper.SetInvokeRead(helper2.read_buf(), kLen2, ERR_IO_PENDING, + helper2.callback()); + + sock_->Read(helper.read_buf().get(), kLen1, helper.callback()); + + ASSERT_EQ(kLen4, read_callback_.WaitForResult()); + AssertReadBufferEquals(kMsg4, kLen4); +} + +TEST_F(SequencedSocketDataTest, AsyncReadFromCompletionCallback) { + MockRead reads[] = { + MockRead(ASYNC, kMsg1, kLen1, 0), MockRead(ASYNC, kMsg2, kLen2, 1), + }; + + Initialize(reads, arraysize(reads), nullptr, 0); + + read_buf_ = new IOBuffer(kLen1); + ASSERT_EQ( + ERR_IO_PENDING, + sock_->Read(read_buf_.get(), kLen1, + base::Bind(&SequencedSocketDataTest::ReentrantReadCallback, + base::Unretained(this), kMsg1, kLen1, kLen2, + ERR_IO_PENDING))); + + ASSERT_FALSE(read_callback_.have_result()); + ASSERT_EQ(kLen2, read_callback_.WaitForResult()); + AssertReadBufferEquals(kMsg2, kLen2); +} + +TEST_F(SequencedSocketDataTest, SingleSyncReadTooEarly) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMsg1, kLen1, 1), + }; + + MockWrite writes[] = {MockWrite(SYNCHRONOUS, 0, 0)}; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + EXPECT_NONFATAL_FAILURE(AssertReadReturns(kLen1, ERR_UNEXPECTED), + "Unable to perform synchronous IO while stopped"); + set_expect_eof(false); +} + +TEST_F(SequencedSocketDataTest, SingleSyncReadSmallBuffer) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), + }; + + Initialize(reads, arraysize(reads), nullptr, 0); + + // Read the first chunk. + AssertReadReturns(kLen1 - 1, kLen1 - 1); + AssertReadBufferEquals(kMsg1, kLen1 - 1); + // Then read the second chunk. + AssertReadReturns(1, 1); + AssertReadBufferEquals(kMsg1 + kLen1 - 1, 1); +} + +TEST_F(SequencedSocketDataTest, SingleSyncReadLargeBuffer) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), + }; + + Initialize(reads, arraysize(reads), nullptr, 0); + scoped_refptr<IOBuffer> read_buf(new IOBuffer(2 * kLen1)); + ASSERT_EQ(kLen1, sock_->Read(read_buf.get(), 2 * kLen1, failing_callback_)); + ASSERT_EQ(std::string(kMsg1, kLen1), std::string(read_buf->data(), kLen1)); +} + +TEST_F(SequencedSocketDataTest, SingleAsyncReadLargeBuffer) { + MockRead reads[] = { + MockRead(ASYNC, kMsg1, kLen1, 0), + }; + + Initialize(reads, arraysize(reads), nullptr, 0); + + scoped_refptr<IOBuffer> read_buf(new IOBuffer(2 * kLen1)); + ASSERT_EQ(ERR_IO_PENDING, + sock_->Read(read_buf.get(), 2 * kLen1, read_callback_.callback())); + ASSERT_EQ(kLen1, read_callback_.WaitForResult()); + ASSERT_EQ(std::string(kMsg1, kLen1), std::string(read_buf->data(), kLen1)); +} + +TEST_F(SequencedSocketDataTest, HangingRead) { + MockRead reads[] = { + MockRead(ASYNC, ERR_IO_PENDING, 0), + }; + + Initialize(reads, arraysize(reads), nullptr, 0); + + scoped_refptr<IOBuffer> read_buf(new IOBuffer(1)); + ASSERT_EQ(ERR_IO_PENDING, + sock_->Read(read_buf.get(), 1, read_callback_.callback())); + ASSERT_FALSE(read_callback_.have_result()); + + // Even though the read is scheduled to complete at sequence number 0, + // verify that the read callback in never called. + base::MessageLoop::current()->RunUntilIdle(); + ASSERT_FALSE(read_callback_.have_result()); +} + +// ----------- Write + +TEST_F(SequencedSocketDataTest, SingleSyncWriteTooEarly) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, kMsg1, kLen1, 1), + }; + + MockRead reads[] = {MockRead(SYNCHRONOUS, 0, 0)}; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + EXPECT_NONFATAL_FAILURE(AssertWriteReturns(kMsg1, kLen1, ERR_UNEXPECTED), + "Unable to perform synchronous IO while stopped"); + + set_expect_eof(false); +} + +TEST_F(SequencedSocketDataTest, DISABLED_SingleSyncWriteTooSmall) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, kMsg1, kLen1, 0), + }; + + Initialize(nullptr, 0, writes, arraysize(writes)); + + // Attempt to write all of the message, but only some will be written. + EXPECT_NONFATAL_FAILURE(AssertSyncWriteEquals(kMsg1, kLen1 - 1), ""); +} + +TEST_F(SequencedSocketDataTest, SingleSyncPartialWrite) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, kMsg1, kLen1 - 1, 0), + MockWrite(SYNCHRONOUS, kMsg1 + kLen1 - 1, 1, 1), + }; + + Initialize(nullptr, 0, writes, arraysize(writes)); + + // Attempt to write all of the message, but only some will be written. + AssertSyncWriteEquals(kMsg1, kLen1 - 1); + // Write the rest of the message. + AssertSyncWriteEquals(kMsg1 + kLen1 - 1, 1); +} + +TEST_F(SequencedSocketDataTest, SingleSyncWrite) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, kMsg1, kLen1, 0), + }; + + Initialize(nullptr, 0, writes, arraysize(writes)); + + AssertSyncWriteEquals(kMsg1, kLen1); +} + +TEST_F(SequencedSocketDataTest, MultipleSyncWrites) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, kMsg1, kLen1, 0), + MockWrite(SYNCHRONOUS, kMsg2, kLen2, 1), + MockWrite(SYNCHRONOUS, kMsg3, kLen3, 2), + MockWrite(SYNCHRONOUS, kMsg3, kLen3, 3), + MockWrite(SYNCHRONOUS, kMsg2, kLen2, 4), + MockWrite(SYNCHRONOUS, kMsg3, kLen3, 5), + MockWrite(SYNCHRONOUS, kMsg1, kLen1, 6), + }; + + Initialize(nullptr, 0, writes, arraysize(writes)); + + AssertSyncWriteEquals(kMsg1, kLen1); + AssertSyncWriteEquals(kMsg2, kLen2); + AssertSyncWriteEquals(kMsg3, kLen3); + AssertSyncWriteEquals(kMsg3, kLen3); + AssertSyncWriteEquals(kMsg2, kLen2); + AssertSyncWriteEquals(kMsg3, kLen3); + AssertSyncWriteEquals(kMsg1, kLen1); +} + +TEST_F(SequencedSocketDataTest, SingleAsyncWrite) { + MockWrite writes[] = { + MockWrite(ASYNC, kMsg1, kLen1, 0), + }; + + Initialize(nullptr, 0, writes, arraysize(writes)); + + AssertAsyncWriteEquals(kMsg1, kLen1); +} + +TEST_F(SequencedSocketDataTest, MultipleAsyncWrites) { + MockWrite writes[] = { + MockWrite(ASYNC, kMsg1, kLen1, 0), + MockWrite(ASYNC, kMsg2, kLen2, 1), + MockWrite(ASYNC, kMsg3, kLen3, 2), + MockWrite(ASYNC, kMsg3, kLen3, 3), + MockWrite(ASYNC, kMsg2, kLen2, 4), + MockWrite(ASYNC, kMsg3, kLen3, 5), + MockWrite(ASYNC, kMsg1, kLen1, 6), + }; + + Initialize(nullptr, 0, writes, arraysize(writes)); + + AssertAsyncWriteEquals(kMsg1, kLen1); + AssertAsyncWriteEquals(kMsg2, kLen2); + AssertAsyncWriteEquals(kMsg3, kLen3); + AssertAsyncWriteEquals(kMsg3, kLen3); + AssertAsyncWriteEquals(kMsg2, kLen2); + AssertAsyncWriteEquals(kMsg3, kLen3); + AssertAsyncWriteEquals(kMsg1, kLen1); +} + +TEST_F(SequencedSocketDataTest, MixedWrites) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, kMsg1, kLen1, 0), + MockWrite(ASYNC, kMsg2, kLen2, 1), + MockWrite(SYNCHRONOUS, kMsg3, kLen3, 2), + MockWrite(ASYNC, kMsg3, kLen3, 3), + MockWrite(SYNCHRONOUS, kMsg2, kLen2, 4), + MockWrite(ASYNC, kMsg3, kLen3, 5), + MockWrite(SYNCHRONOUS, kMsg1, kLen1, 6), + }; + + Initialize(nullptr, 0, writes, arraysize(writes)); + + AssertSyncWriteEquals(kMsg1, kLen1); + AssertAsyncWriteEquals(kMsg2, kLen2); + AssertSyncWriteEquals(kMsg3, kLen3); + AssertAsyncWriteEquals(kMsg3, kLen3); + AssertSyncWriteEquals(kMsg2, kLen2); + AssertAsyncWriteEquals(kMsg3, kLen3); + AssertSyncWriteEquals(kMsg1, kLen1); +} + +TEST_F(SequencedSocketDataTest, SyncWriteFromCompletionCallback) { + MockWrite writes[] = { + MockWrite(ASYNC, kMsg1, kLen1, 0), + MockWrite(SYNCHRONOUS, kMsg2, kLen2, 1), + }; + + Initialize(nullptr, 0, writes, arraysize(writes)); + + scoped_refptr<IOBuffer> write_buf(new IOBuffer(kLen1)); + memcpy(write_buf->data(), kMsg1, kLen1); + ASSERT_EQ( + ERR_IO_PENDING, + sock_->Write( + write_buf.get(), kLen1, + base::Bind(&SequencedSocketDataTest::ReentrantWriteCallback, + base::Unretained(this), kLen1, kMsg2, kLen2, kLen2))); + + base::MessageLoop::current()->RunUntilIdle(); +} + +TEST_F(SequencedSocketDataTest, AsyncWriteFromCompletionCallback) { + MockWrite writes[] = { + MockWrite(ASYNC, kMsg1, kLen1, 0), MockWrite(ASYNC, kMsg2, kLen2, 1), + }; + + Initialize(nullptr, 0, writes, arraysize(writes)); + + scoped_refptr<IOBuffer> write_buf(new IOBuffer(kLen1)); + memcpy(write_buf->data(), kMsg1, kLen1); + ASSERT_EQ( + ERR_IO_PENDING, + sock_->Write(write_buf.get(), kLen1, + base::Bind(&SequencedSocketDataTest::ReentrantWriteCallback, + base::Unretained(this), kLen1, kMsg2, kLen2, + ERR_IO_PENDING))); + + ASSERT_FALSE(write_callback_.have_result()); + ASSERT_EQ(kLen2, write_callback_.WaitForResult()); +} + +TEST_F(SequencedSocketDataTest, ManyReentrantWrites) { + MockWrite writes[] = { + MockWrite(ASYNC, kMsg1, kLen1, 0), + MockWrite(ASYNC, kMsg2, kLen2, 1), + MockWrite(ASYNC, kMsg3, kLen3, 2), + MockWrite(ASYNC, kMsg4, kLen4, 3), + }; + + Initialize(nullptr, 0, writes, arraysize(writes)); + + ReentrantHelper helper3(sock_); + helper3.SetExpectedWrite(kLen3); + helper3.SetInvokeWrite(kMsg4, kLen4, ERR_IO_PENDING, + write_callback_.callback()); + + ReentrantHelper helper2(sock_); + helper2.SetExpectedWrite(kLen2); + helper2.SetInvokeWrite(kMsg3, kLen3, ERR_IO_PENDING, helper3.callback()); + + ReentrantHelper helper(sock_); + helper.SetExpectedWrite(kLen1); + helper.SetInvokeWrite(kMsg2, kLen2, ERR_IO_PENDING, helper2.callback()); + + scoped_refptr<IOBuffer> write_buf(new IOBuffer(kLen1)); + memcpy(write_buf->data(), kMsg1, kLen1); + sock_->Write(write_buf.get(), kLen1, helper.callback()); + + ASSERT_EQ(kLen4, write_callback_.WaitForResult()); +} + +// ----------- Mixed Reads and Writes + +TEST_F(SequencedSocketDataTest, MixedSyncOperations) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), + MockRead(SYNCHRONOUS, kMsg2, kLen2, 3), + }; + + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, kMsg2, kLen2, 1), + MockWrite(SYNCHRONOUS, kMsg3, kLen3, 2), + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + AssertSyncReadEquals(kMsg1, kLen1); + AssertSyncWriteEquals(kMsg2, kLen2); + AssertSyncWriteEquals(kMsg3, kLen3); + AssertSyncReadEquals(kMsg2, kLen2); +} + +TEST_F(SequencedSocketDataTest, MixedAsyncOperations) { + MockRead reads[] = { + MockRead(ASYNC, kMsg1, kLen1, 0), MockRead(ASYNC, kMsg2, kLen2, 3), + }; + + MockWrite writes[] = { + MockWrite(ASYNC, kMsg2, kLen2, 1), MockWrite(ASYNC, kMsg3, kLen3, 2), + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + AssertAsyncReadEquals(kMsg1, kLen1); + AssertAsyncWriteEquals(kMsg2, kLen2); + AssertAsyncWriteEquals(kMsg3, kLen3); + AssertAsyncReadEquals(kMsg2, kLen2); +} + +TEST_F(SequencedSocketDataTest, InterleavedAsyncOperations) { + // Order of completion is read, write, write, read. + MockRead reads[] = { + MockRead(ASYNC, kMsg1, kLen1, 0), MockRead(ASYNC, kMsg2, kLen2, 3), + }; + + MockWrite writes[] = { + MockWrite(ASYNC, kMsg2, kLen2, 1), MockWrite(ASYNC, kMsg3, kLen3, 2), + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + // Issue the write, which will block until the read completes. + AssertWriteReturns(kMsg2, kLen2, ERR_IO_PENDING); + + // Issue the read which will return first. + AssertReadReturns(kLen1, ERR_IO_PENDING); + + ASSERT_EQ(kLen1, read_callback_.WaitForResult()); + AssertReadBufferEquals(kMsg1, kLen1); + + ASSERT_TRUE(write_callback_.have_result()); + ASSERT_EQ(kLen2, write_callback_.WaitForResult()); + + // Issue the read, which will block until the write completes. + AssertReadReturns(kLen2, ERR_IO_PENDING); + + // Issue the writes which will return first. + AssertWriteReturns(kMsg3, kLen3, ERR_IO_PENDING); + ASSERT_EQ(kLen3, write_callback_.WaitForResult()); + + ASSERT_EQ(kLen2, read_callback_.WaitForResult()); + AssertReadBufferEquals(kMsg2, kLen2); +} + +TEST_F(SequencedSocketDataTest, InterleavedMixedOperations) { + // Order of completion is read, write, write, read. + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMsg1, kLen1, 0), + MockRead(ASYNC, kMsg2, kLen2, 3), + MockRead(ASYNC, kMsg3, kLen3, 5), + }; + + MockWrite writes[] = { + MockWrite(ASYNC, kMsg2, kLen2, 1), + MockWrite(SYNCHRONOUS, kMsg3, kLen3, 2), + MockWrite(SYNCHRONOUS, kMsg1, kLen1, 4), + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + // Issue the write, which will block until the read completes. + AssertWriteReturns(kMsg2, kLen2, ERR_IO_PENDING); + + // Issue the writes which will complete immediately. + AssertSyncReadEquals(kMsg1, kLen1); + + ASSERT_FALSE(write_callback_.have_result()); + ASSERT_EQ(kLen2, write_callback_.WaitForResult()); + + // Issue the read, which will block until the write completes. + AssertReadReturns(kLen2, ERR_IO_PENDING); + + // Issue the writes which will complete immediately. + AssertSyncWriteEquals(kMsg3, kLen3); + + ASSERT_FALSE(read_callback_.have_result()); + ASSERT_EQ(kLen2, read_callback_.WaitForResult()); + AssertReadBufferEquals(kMsg2, kLen2); + + // Issue the read, which will block until the write completes. + AssertReadReturns(kLen2, ERR_IO_PENDING); + + // Issue the writes which will complete immediately. + AssertSyncWriteEquals(kMsg1, kLen1); + + ASSERT_FALSE(read_callback_.have_result()); + ASSERT_EQ(kLen3, read_callback_.WaitForResult()); + AssertReadBufferEquals(kMsg3, kLen3); +} + +TEST_F(SequencedSocketDataTest, AsyncReadFromWriteCompletionCallback) { + MockWrite writes[] = { + MockWrite(ASYNC, kMsg1, kLen1, 0), + }; + + MockRead reads[] = { + MockRead(ASYNC, kMsg2, kLen2, 1), + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + scoped_refptr<IOBuffer> write_buf(new IOBuffer(kLen1)); + memcpy(write_buf->data(), kMsg1, kLen1); + ASSERT_EQ(ERR_IO_PENDING, + sock_->Write( + write_buf.get(), kLen1, + base::Bind(&SequencedSocketDataTest::ReentrantAsyncReadCallback, + base::Unretained(this), kLen1, kLen2))); + + ASSERT_FALSE(read_callback_.have_result()); + ASSERT_EQ(kLen2, read_callback_.WaitForResult()); + AssertReadBufferEquals(kMsg2, kLen2); +} + +TEST_F(SequencedSocketDataTest, AsyncWriteFromReadCompletionCallback) { + MockWrite writes[] = { + MockWrite(ASYNC, kMsg2, kLen2, 1), + }; + + MockRead reads[] = { + MockRead(ASYNC, kMsg1, kLen1, 0), + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + scoped_refptr<IOBuffer> read_buf(new IOBuffer(kLen1)); + ASSERT_EQ( + ERR_IO_PENDING, + sock_->Read( + read_buf.get(), kLen1, + base::Bind(&SequencedSocketDataTest::ReentrantAsyncWriteCallback, + base::Unretained(this), kMsg2, kLen2, + write_callback_.callback(), kLen1))); + + ASSERT_FALSE(write_callback_.have_result()); + ASSERT_EQ(kLen2, write_callback_.WaitForResult()); +} + +TEST_F(SequencedSocketDataTest, MixedReentrantOperations) { + MockWrite writes[] = { + MockWrite(ASYNC, kMsg1, kLen1, 0), MockWrite(ASYNC, kMsg3, kLen3, 2), + }; + + MockRead reads[] = { + MockRead(ASYNC, kMsg2, kLen2, 1), MockRead(ASYNC, kMsg4, kLen4, 3), + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + read_buf_ = new IOBuffer(kLen4); + + ReentrantHelper helper3(sock_); + helper3.SetExpectedWrite(kLen3); + helper3.SetInvokeRead(read_buf_, kLen4, ERR_IO_PENDING, + read_callback_.callback()); + + ReentrantHelper helper2(sock_); + helper2.SetExpectedRead(kMsg2, kLen2); + helper2.SetInvokeWrite(kMsg3, kLen3, ERR_IO_PENDING, helper3.callback()); + + ReentrantHelper helper(sock_); + helper.SetExpectedWrite(kLen1); + helper.SetInvokeRead(helper2.read_buf(), kLen2, ERR_IO_PENDING, + helper2.callback()); + + scoped_refptr<IOBuffer> write_buf(new IOBuffer(kLen1)); + memcpy(write_buf->data(), kMsg1, kLen1); + sock_->Write(write_buf.get(), kLen1, helper.callback()); + + ASSERT_EQ(kLen4, read_callback_.WaitForResult()); +} + +TEST_F(SequencedSocketDataTest, MixedReentrantOperationsThenSynchronousRead) { + MockWrite writes[] = { + MockWrite(ASYNC, kMsg1, kLen1, 0), MockWrite(ASYNC, kMsg3, kLen3, 2), + }; + + MockRead reads[] = { + MockRead(ASYNC, kMsg2, kLen2, 1), MockRead(SYNCHRONOUS, kMsg4, kLen4, 3), + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + read_buf_ = new IOBuffer(kLen4); + + ReentrantHelper helper3(sock_); + helper3.SetExpectedWrite(kLen3); + helper3.SetInvokeRead(read_buf_, kLen4, kLen4, failing_callback_); + + ReentrantHelper helper2(sock_); + helper2.SetExpectedRead(kMsg2, kLen2); + helper2.SetInvokeWrite(kMsg3, kLen3, ERR_IO_PENDING, helper3.callback()); + + ReentrantHelper helper(sock_); + helper.SetExpectedWrite(kLen1); + helper.SetInvokeRead(helper2.read_buf(), kLen2, ERR_IO_PENDING, + helper2.callback()); + + scoped_refptr<IOBuffer> write_buf(new IOBuffer(kLen1)); + memcpy(write_buf->data(), kMsg1, kLen1); + ASSERT_EQ(ERR_IO_PENDING, + sock_->Write(write_buf.get(), kLen1, helper.callback())); + + base::MessageLoop::current()->RunUntilIdle(); + AssertReadBufferEquals(kMsg4, kLen4); +} + +TEST_F(SequencedSocketDataTest, MixedReentrantOperationsThenSynchronousWrite) { + MockWrite writes[] = { + MockWrite(ASYNC, kMsg2, kLen2, 1), + MockWrite(SYNCHRONOUS, kMsg4, kLen4, 3), + }; + + MockRead reads[] = { + MockRead(ASYNC, kMsg1, kLen1, 0), MockRead(ASYNC, kMsg3, kLen3, 2), + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes)); + + read_buf_ = new IOBuffer(kLen4); + + ReentrantHelper helper3(sock_); + helper3.SetExpectedRead(kMsg3, kLen3); + helper3.SetInvokeWrite(kMsg4, kLen4, kLen4, failing_callback_); + + ReentrantHelper helper2(sock_); + helper2.SetExpectedWrite(kLen2); + helper2.SetInvokeRead(helper3.read_buf(), kLen3, ERR_IO_PENDING, + helper3.callback()); + + ReentrantHelper helper(sock_); + helper.SetExpectedRead(kMsg1, kLen1); + helper.SetInvokeWrite(kMsg2, kLen2, ERR_IO_PENDING, helper2.callback()); + + ASSERT_EQ(ERR_IO_PENDING, + sock_->Read(helper.read_buf().get(), kLen1, helper.callback())); + + base::MessageLoop::current()->RunUntilIdle(); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/socket/server_socket.cc b/chromium/net/socket/server_socket.cc index da89b4645f8..7cf6c64133c 100644 --- a/chromium/net/socket/server_socket.cc +++ b/chromium/net/socket/server_socket.cc @@ -17,7 +17,7 @@ ServerSocket::~ServerSocket() { } int ServerSocket::ListenWithAddressAndPort(const std::string& address_string, - int port, + uint16 port, int backlog) { IPAddressNumber address_number; if (!ParseIPLiteralToNumber(address_string, &address_number)) { diff --git a/chromium/net/socket/server_socket.h b/chromium/net/socket/server_socket.h index 4b9ca8e39cf..828b399c7a8 100644 --- a/chromium/net/socket/server_socket.h +++ b/chromium/net/socket/server_socket.h @@ -7,6 +7,7 @@ #include <string> +#include "base/basictypes.h" #include "base/memory/scoped_ptr.h" #include "net/base/completion_callback.h" #include "net/base/net_export.h" @@ -30,7 +31,7 @@ class NET_EXPORT ServerSocket { // Subclasses may override this function if |address_string| is in a different // format, for example, unix domain socket path. virtual int ListenWithAddressAndPort(const std::string& address_string, - int port, + uint16 port, int backlog); // Gets current address the socket is bound to. diff --git a/chromium/net/socket/socket_net_log_params.cc b/chromium/net/socket/socket_net_log_params.cc index bcc12c86bcf..3dd6595525d 100644 --- a/chromium/net/socket/socket_net_log_params.cc +++ b/chromium/net/socket/socket_net_log_params.cc @@ -16,7 +16,7 @@ namespace { base::Value* NetLogSocketErrorCallback(int net_error, int os_error, - NetLog::LogLevel /* log_level */) { + NetLogCaptureMode /* capture_mode */) { base::DictionaryValue* dict = new base::DictionaryValue(); dict->SetInteger("net_error", net_error); dict->SetInteger("os_error", os_error); @@ -24,14 +24,14 @@ base::Value* NetLogSocketErrorCallback(int net_error, } base::Value* NetLogHostPortPairCallback(const HostPortPair* host_and_port, - NetLog::LogLevel /* log_level */) { + NetLogCaptureMode /* capture_mode */) { base::DictionaryValue* dict = new base::DictionaryValue(); dict->SetString("host_and_port", host_and_port->ToString()); return dict; } base::Value* NetLogIPEndPointCallback(const IPEndPoint* address, - NetLog::LogLevel /* log_level */) { + NetLogCaptureMode /* capture_mode */) { base::DictionaryValue* dict = new base::DictionaryValue(); dict->SetString("address", address->ToString()); return dict; @@ -39,7 +39,7 @@ base::Value* NetLogIPEndPointCallback(const IPEndPoint* address, base::Value* NetLogSourceAddressCallback(const struct sockaddr* net_address, socklen_t address_len, - NetLog::LogLevel /* log_level */) { + NetLogCaptureMode /* capture_mode */) { base::DictionaryValue* dict = new base::DictionaryValue(); dict->SetString("source_address", NetAddressToStringWithPort(net_address, address_len)); diff --git a/chromium/net/socket/socket_net_log_params.h b/chromium/net/socket/socket_net_log_params.h index f5fe652d125..b432667d4a6 100644 --- a/chromium/net/socket/socket_net_log_params.h +++ b/chromium/net/socket/socket_net_log_params.h @@ -5,8 +5,8 @@ #ifndef NET_SOCKET_SOCKET_NET_LOG_PARAMS_H_ #define NET_SOCKET_SOCKET_NET_LOG_PARAMS_H_ -#include "net/base/net_log.h" #include "net/base/sys_addrinfo.h" +#include "net/log/net_log.h" namespace net { diff --git a/chromium/net/socket/socket_test_util.cc b/chromium/net/socket/socket_test_util.cc index 5ae4eeeedb8..920c5263eba 100644 --- a/chromium/net/socket/socket_test_util.cc +++ b/chromium/net/socket/socket_test_util.cc @@ -10,8 +10,8 @@ #include "base/basictypes.h" #include "base/bind.h" #include "base/bind_helpers.h" -#include "base/callback_helpers.h" #include "base/compiler_specific.h" +#include "base/logging.h" #include "base/message_loop/message_loop.h" #include "base/run_loop.h" #include "base/time/time.h" @@ -22,20 +22,15 @@ #include "net/http/http_network_session.h" #include "net/http/http_request_headers.h" #include "net/http/http_response_headers.h" -#include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket.h" +#include "net/socket/websocket_endpoint_lock_manager.h" #include "net/ssl/ssl_cert_request_info.h" #include "net/ssl/ssl_connection_status_flags.h" +#include "net/ssl/ssl_failure_state.h" #include "net/ssl/ssl_info.h" #include "testing/gtest/include/gtest/gtest.h" -// Socket events are easier to debug if you log individual reads and writes. -// Enable these if locally debugging, but they are too noisy for the waterfall. -#if 0 -#define NET_TRACE(level, s) DLOG(level) << s << __FUNCTION__ << "() " -#else -#define NET_TRACE(level, s) EAT_STREAM_PARAMETERS -#endif +#define NET_TRACE(level, s) VLOG(level) << s << __FUNCTION__ << "() " namespace net { @@ -121,8 +116,7 @@ void DumpMockReadWrite(const MockReadWrite<type>& r) { << "\nResult: " << r.result; DumpData(r.data, r.data_len); const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : ""; - DVLOG(1) << "Stage: " << (r.sequence_number & ~MockRead::STOPLOOP) << stop - << "\nTime: " << r.time_stamp.ToInternalValue(); + DVLOG(1) << "Stage: " << (r.sequence_number & ~MockRead::STOPLOOP) << stop; } } // namespace @@ -147,19 +141,10 @@ MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr) : MockConnect::~MockConnect() {} -StaticSocketDataProvider::StaticSocketDataProvider() - : reads_(NULL), - read_index_(0), - read_count_(0), - writes_(NULL), - write_index_(0), - write_count_(0) { -} - -StaticSocketDataProvider::StaticSocketDataProvider(MockRead* reads, - size_t reads_count, - MockWrite* writes, - size_t writes_count) +StaticSocketDataHelper::StaticSocketDataHelper(MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count) : reads_(reads), read_index_(0), read_count_(reads_count), @@ -168,41 +153,80 @@ StaticSocketDataProvider::StaticSocketDataProvider(MockRead* reads, write_count_(writes_count) { } -StaticSocketDataProvider::~StaticSocketDataProvider() {} +StaticSocketDataHelper::~StaticSocketDataHelper() { +} -const MockRead& StaticSocketDataProvider::PeekRead() const { +const MockRead& StaticSocketDataHelper::PeekRead() const { CHECK(!at_read_eof()); return reads_[read_index_]; } -const MockWrite& StaticSocketDataProvider::PeekWrite() const { +const MockWrite& StaticSocketDataHelper::PeekWrite() const { CHECK(!at_write_eof()); return writes_[write_index_]; } -const MockRead& StaticSocketDataProvider::PeekRead(size_t index) const { - CHECK_LT(index, read_count_); - return reads_[index]; +const MockRead& StaticSocketDataHelper::AdvanceRead() { + CHECK(!at_read_eof()); + return reads_[read_index_++]; } -const MockWrite& StaticSocketDataProvider::PeekWrite(size_t index) const { - CHECK_LT(index, write_count_); - return writes_[index]; +const MockWrite& StaticSocketDataHelper::AdvanceWrite() { + CHECK(!at_write_eof()); + return writes_[write_index_++]; } -MockRead StaticSocketDataProvider::GetNextRead() { - CHECK(!at_read_eof()); - reads_[read_index_].time_stamp = base::Time::Now(); - return reads_[read_index_++]; +bool StaticSocketDataHelper::VerifyWriteData(const std::string& data) { + CHECK(!at_write_eof()); + // Check that what the actual data matches the expectations. + const MockWrite& next_write = PeekWrite(); + if (!next_write.data) + return true; + + // Note: Partial writes are supported here. If the expected data + // is a match, but shorter than the write actually written, that is legal. + // Example: + // Application writes "foobarbaz" (9 bytes) + // Expected write was "foo" (3 bytes) + // This is a success, and the function returns true. + std::string expected_data(next_write.data, next_write.data_len); + std::string actual_data(data.substr(0, next_write.data_len)); + EXPECT_GE(data.length(), expected_data.length()); + EXPECT_EQ(expected_data, actual_data); + return expected_data == actual_data; +} + +void StaticSocketDataHelper::Reset() { + read_index_ = 0; + write_index_ = 0; +} + +StaticSocketDataProvider::StaticSocketDataProvider() + : StaticSocketDataProvider(nullptr, 0, nullptr, 0) { +} + +StaticSocketDataProvider::StaticSocketDataProvider(MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count) + : helper_(reads, reads_count, writes, writes_count) { +} + +StaticSocketDataProvider::~StaticSocketDataProvider() { +} + +MockRead StaticSocketDataProvider::OnRead() { + CHECK(!helper_.at_read_eof()); + return helper_.AdvanceRead(); } MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) { - if (!writes_) { + if (helper_.write_count() == 0) { // Not using mock writes; succeed synchronously. return MockWriteResult(SYNCHRONOUS, data.length()); } - EXPECT_FALSE(at_write_eof()); - if (at_write_eof()) { + EXPECT_FALSE(helper_.at_write_eof()); + if (helper_.at_write_eof()) { // Show what the extra write actually consists of. EXPECT_EQ("<unexpected write>", data); return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED); @@ -210,31 +234,27 @@ MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) { // Check that what we are writing matches the expectation. // Then give the mocked return value. - MockWrite* w = &writes_[write_index_++]; - w->time_stamp = base::Time::Now(); - int result = w->result; - if (w->data) { - // Note - we can simulate a partial write here. If the expected data - // is a match, but shorter than the write actually written, that is legal. - // Example: - // Application writes "foobarbaz" (9 bytes) - // Expected write was "foo" (3 bytes) - // This is a success, and we return 3 to the application. - std::string expected_data(w->data, w->data_len); - EXPECT_GE(data.length(), expected_data.length()); - std::string actual_data(data.substr(0, w->data_len)); - EXPECT_EQ(expected_data, actual_data); - if (expected_data != actual_data) - return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED); - if (result == OK) - result = w->data_len; - } - return MockWriteResult(w->mode, result); + if (!helper_.VerifyWriteData(data)) + return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED); + + const MockWrite& next_write = helper_.AdvanceWrite(); + // In the case that the write was successful, return the number of bytes + // written. Otherwise return the error code. + int result = + next_write.result == OK ? next_write.data_len : next_write.result; + return MockWriteResult(next_write.mode, result); } void StaticSocketDataProvider::Reset() { - read_index_ = 0; - write_index_ = 0; + helper_.Reset(); +} + +bool StaticSocketDataProvider::AllReadDataConsumed() const { + return helper_.at_read_eof(); +} + +bool StaticSocketDataProvider::AllWriteDataConsumed() const { + return helper_.at_write_eof(); } DynamicSocketDataProvider::DynamicSocketDataProvider() @@ -244,7 +264,7 @@ DynamicSocketDataProvider::DynamicSocketDataProvider() DynamicSocketDataProvider::~DynamicSocketDataProvider() {} -MockRead DynamicSocketDataProvider::GetNextRead() { +MockRead DynamicSocketDataProvider::OnRead() { if (reads_.empty()) return MockRead(SYNCHRONOUS, ERR_UNEXPECTED); MockRead result = reads_.front(); @@ -273,14 +293,10 @@ void DynamicSocketDataProvider::SimulateRead(const char* data, SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result) : connect(mode, result), next_proto_status(SSLClientSocket::kNextProtoUnsupported), - was_npn_negotiated(false), - protocol_negotiated(kProtoUnknown), client_cert_sent(false), cert_request_info(NULL), channel_id_sent(false), - connection_status(0), - should_pause_on_connect(false), - is_in_session_cache(false) { + connection_status(0) { SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_2, &connection_status); // Set to TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 @@ -291,9 +307,7 @@ SSLSocketDataProvider::~SSLSocketDataProvider() { } void SSLSocketDataProvider::SetNextProto(NextProto proto) { - was_npn_negotiated = true; next_proto_status = SSLClientSocket::kNextProtoNegotiated; - protocol_negotiated = proto; next_proto = SSLClientSocket::NextProtoToString(proto); } @@ -327,10 +341,10 @@ void DelayedSocketData::ForceNextRead() { CompleteRead(); } -MockRead DelayedSocketData::GetNextRead() { +MockRead DelayedSocketData::OnRead() { MockRead out = MockRead(ASYNC, ERR_IO_PENDING); if (write_delay_ <= 0) - out = StaticSocketDataProvider::GetNextRead(); + out = StaticSocketDataProvider::OnRead(); read_in_progress_ = (out.result == ERR_IO_PENDING); return out; } @@ -356,7 +370,7 @@ void DelayedSocketData::Reset() { void DelayedSocketData::CompleteRead() { if (socket() && read_in_progress_) - socket()->OnReadComplete(GetNextRead()); + socket()->OnReadComplete(OnRead()); } OrderedSocketData::OrderedSocketData( @@ -379,40 +393,39 @@ OrderedSocketData::OrderedSocketData( void OrderedSocketData::EndLoop() { // If we've already stopped the loop, don't do it again until we've advanced // to the next sequence_number. - NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ << ": EndLoop()"; + NET_TRACE(1, " *** ") << "Stage " << sequence_number_ << ": EndLoop()"; if (loop_stop_stage_ > 0) { - const MockRead& next_read = StaticSocketDataProvider::PeekRead(); + const MockRead& next_read = helper()->PeekRead(); if ((next_read.sequence_number & ~MockRead::STOPLOOP) > loop_stop_stage_) { - NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - << ": Clearing stop index"; + NET_TRACE(1, " *** ") << "Stage " << sequence_number_ + << ": Clearing stop index"; loop_stop_stage_ = 0; } else { return; } } // Record the sequence_number at which we stopped the loop. - NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - << ": Posting Quit at read " << read_index(); + NET_TRACE(1, " *** ") << "Stage " << sequence_number_ + << ": Posting Quit at read " << read_index(); loop_stop_stage_ = sequence_number_; } -MockRead OrderedSocketData::GetNextRead() { +MockRead OrderedSocketData::OnRead() { weak_factory_.InvalidateWeakPtrs(); blocked_ = false; - const MockRead& next_read = StaticSocketDataProvider::PeekRead(); + const MockRead& next_read = helper()->PeekRead(); if (next_read.sequence_number & MockRead::STOPLOOP) EndLoop(); if ((next_read.sequence_number & ~MockRead::STOPLOOP) <= sequence_number_++) { - NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1 - << ": Read " << read_index(); + NET_TRACE(1, " *** ") << "Stage " << sequence_number_ - 1 << ": Read " + << read_index(); DumpMockReadWrite(next_read); blocked_ = (next_read.result == ERR_IO_PENDING); - return StaticSocketDataProvider::GetNextRead(); + return StaticSocketDataProvider::OnRead(); } - NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1 - << ": I/O Pending"; + NET_TRACE(1, " *** ") << "Stage " << sequence_number_ - 1 << ": I/O Pending"; MockRead result = MockRead(ASYNC, ERR_IO_PENDING); DumpMockReadWrite(result); blocked_ = true; @@ -420,9 +433,9 @@ MockRead OrderedSocketData::GetNextRead() { } MockWriteResult OrderedSocketData::OnWrite(const std::string& data) { - NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - << ": Write " << write_index(); - DumpMockReadWrite(PeekWrite()); + NET_TRACE(1, " *** ") << "Stage " << sequence_number_ << ": Write " + << write_index(); + DumpMockReadWrite(helper()->PeekWrite()); ++sequence_number_; if (blocked_) { // TODO(willchan): This 100ms delay seems to work around some weirdness. We @@ -440,8 +453,7 @@ MockWriteResult OrderedSocketData::OnWrite(const std::string& data) { } void OrderedSocketData::Reset() { - NET_TRACE(INFO, " *** ") << "Stage " - << sequence_number_ << ": Reset()"; + NET_TRACE(1, " *** ") << "Stage " << sequence_number_ << ": Reset()"; sequence_number_ = 0; loop_stop_stage_ = 0; set_socket(NULL); @@ -451,13 +463,259 @@ void OrderedSocketData::Reset() { void OrderedSocketData::CompleteRead() { if (socket() && blocked_) { - NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_; - socket()->OnReadComplete(GetNextRead()); + NET_TRACE(1, " *** ") << "Stage " << sequence_number_; + socket()->OnReadComplete(OnRead()); } } OrderedSocketData::~OrderedSocketData() {} +SequencedSocketData::SequencedSocketData(MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count) + : helper_(reads, reads_count, writes, writes_count), + sequence_number_(0), + read_state_(IDLE), + write_state_(IDLE), + weak_factory_(this) { + // Check that reads and writes have a contiguous set of sequence numbers + // starting from 0 and working their way up, with no repeats and skipping + // no values. + size_t next_read = 0; + size_t next_write = 0; + int next_sequence_number = 0; + while (next_read < reads_count || next_write < writes_count) { + if (next_read < reads_count && + reads[next_read].sequence_number == next_sequence_number) { + ++next_read; + ++next_sequence_number; + continue; + } + if (next_write < writes_count && + writes[next_write].sequence_number == next_sequence_number) { + ++next_write; + ++next_sequence_number; + continue; + } + CHECK(false) << "Sequence number not found where expected: " + << next_sequence_number; + return; + } + CHECK_EQ(next_read, reads_count); + CHECK_EQ(next_write, writes_count); +} + +SequencedSocketData::SequencedSocketData(const MockConnect& connect, + MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count) + : SequencedSocketData(reads, reads_count, writes, writes_count) { + set_connect_data(connect); +} + +MockRead SequencedSocketData::OnRead() { + CHECK_EQ(IDLE, read_state_); + CHECK(!helper_.at_read_eof()); + + NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_; + const MockRead& next_read = helper_.PeekRead(); + NET_TRACE(1, " *** ") << "next_read: " << next_read.sequence_number; + CHECK_GE(next_read.sequence_number, sequence_number_); + + // Special case handling for hanging reads. + if (next_read.mode == ASYNC && next_read.result == ERR_IO_PENDING) { + NET_TRACE(1, " *** ") << "Hanging read"; + helper_.AdvanceRead(); + ++sequence_number_; + CHECK(helper_.at_read_eof()); + return MockRead(SYNCHRONOUS, ERR_IO_PENDING); + } + + if (next_read.sequence_number <= sequence_number_) { + if (next_read.mode == SYNCHRONOUS) { + NET_TRACE(1, " *** ") << "Returning synchronously"; + DumpMockReadWrite(next_read); + helper_.AdvanceRead(); + ++sequence_number_; + MaybePostWriteCompleteTask(); + return next_read; + } + + base::MessageLoop::current()->PostTask( + FROM_HERE, base::Bind(&SequencedSocketData::OnReadComplete, + weak_factory_.GetWeakPtr())); + CHECK_NE(COMPLETING, write_state_); + read_state_ = COMPLETING; + } else if (next_read.mode == SYNCHRONOUS) { + ADD_FAILURE() << "Unable to perform synchronous IO while stopped"; + return MockRead(SYNCHRONOUS, ERR_UNEXPECTED); + } else { + NET_TRACE(1, " *** ") << "Waiting for write to trigger read"; + read_state_ = PENDING; + } + + return MockRead(SYNCHRONOUS, ERR_IO_PENDING); +} + +MockWriteResult SequencedSocketData::OnWrite(const std::string& data) { + CHECK_EQ(IDLE, write_state_); + CHECK(!helper_.at_write_eof()); + + NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_; + const MockWrite& next_write = helper_.PeekWrite(); + NET_TRACE(1, " *** ") << "next_write: " << next_write.sequence_number; + CHECK_GE(next_write.sequence_number, sequence_number_); + + if (!helper_.VerifyWriteData(data)) + return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED); + + if (next_write.sequence_number <= sequence_number_) { + if (next_write.mode == SYNCHRONOUS) { + helper_.AdvanceWrite(); + ++sequence_number_; + MaybePostReadCompleteTask(); + // In the case that the write was successful, return the number of bytes + // written. Otherwise return the error code. + int rv = + next_write.result != OK ? next_write.result : next_write.data_len; + NET_TRACE(1, " *** ") << "Returning synchronously"; + return MockWriteResult(SYNCHRONOUS, rv); + } + + NET_TRACE(1, " *** ") << "Posting task to complete write"; + base::MessageLoop::current()->PostTask( + FROM_HERE, base::Bind(&SequencedSocketData::OnWriteComplete, + weak_factory_.GetWeakPtr())); + CHECK_NE(COMPLETING, read_state_); + write_state_ = COMPLETING; + } else if (next_write.mode == SYNCHRONOUS) { + ADD_FAILURE() << "Unable to perform synchronous IO while stopped"; + return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED); + } else { + NET_TRACE(1, " *** ") << "Waiting for read to trigger write"; + write_state_ = PENDING; + } + + return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING); +} + +void SequencedSocketData::Reset() { + helper_.Reset(); + sequence_number_ = 0; + read_state_ = IDLE; + write_state_ = IDLE; + weak_factory_.InvalidateWeakPtrs(); +} + +bool SequencedSocketData::AllReadDataConsumed() const { + return helper_.at_read_eof(); +} + +bool SequencedSocketData::AllWriteDataConsumed() const { + return helper_.at_write_eof(); +} + +bool SequencedSocketData::at_read_eof() const { + return helper_.at_read_eof(); +} + +bool SequencedSocketData::at_write_eof() const { + return helper_.at_read_eof(); +} + +void SequencedSocketData::MaybePostReadCompleteTask() { + NET_TRACE(1, " ****** ") << " current: " << sequence_number_; + // Only trigger the next read to complete if there is already a read pending + // which should complete at the current sequence number. + if (read_state_ != PENDING || + helper_.PeekRead().sequence_number != sequence_number_) { + return; + } + + NET_TRACE(1, " ****** ") << "Posting task to complete read: " + << sequence_number_; + base::MessageLoop::current()->PostTask( + FROM_HERE, base::Bind(&SequencedSocketData::OnReadComplete, + weak_factory_.GetWeakPtr())); + CHECK_NE(COMPLETING, write_state_); + read_state_ = COMPLETING; +} + +void SequencedSocketData::MaybePostWriteCompleteTask() { + NET_TRACE(1, " ****** ") << " current: " << sequence_number_; + // Only trigger the next write to complete if there is already a write pending + // which should complete at the current sequence number. + if (write_state_ != PENDING || + helper_.PeekWrite().sequence_number != sequence_number_) { + return; + } + + NET_TRACE(1, " ****** ") << "Posting task to complete write: " + << sequence_number_; + base::MessageLoop::current()->PostTask( + FROM_HERE, base::Bind(&SequencedSocketData::OnWriteComplete, + weak_factory_.GetWeakPtr())); + CHECK_NE(COMPLETING, read_state_); + write_state_ = COMPLETING; +} + +void SequencedSocketData::OnReadComplete() { + CHECK_EQ(COMPLETING, read_state_); + NET_TRACE(1, " *** ") << "Completing read for: " << sequence_number_; + if (!socket()) { + NET_TRACE(1, " *** ") << "No socket available to complete read"; + return; + } + + MockRead data = helper_.AdvanceRead(); + DCHECK_EQ(sequence_number_, data.sequence_number); + sequence_number_++; + read_state_ = IDLE; + + // The result of this read completing might trigger the completion + // of a pending write. If so, post a task to complete the write later. + // Since the socket may call back into the SequencedSocketData + // from socket()->OnReadComplete(), trigger the write task to be posted + // before calling that. + MaybePostWriteCompleteTask(); + + NET_TRACE(1, " *** ") << "Completing socket read for: " << sequence_number_; + DumpMockReadWrite(data); + socket()->OnReadComplete(data); + NET_TRACE(1, " *** ") << "Done"; +} + +void SequencedSocketData::OnWriteComplete() { + CHECK_EQ(COMPLETING, write_state_); + NET_TRACE(1, " *** ") << " Completing write for: " << sequence_number_; + if (!socket()) { + NET_TRACE(1, " *** ") << "No socket available to complete write."; + return; + } + + const MockWrite& data = helper_.AdvanceWrite(); + DCHECK_EQ(sequence_number_, data.sequence_number); + sequence_number_++; + write_state_ = IDLE; + int rv = data.result == OK ? data.data_len : data.result; + + // The result of this write completing might trigger the completion + // of a pending read. If so, post a task to complete the read later. + // Since the socket may call back into the SequencedSocketData + // from socket()->OnWriteComplete(), trigger the write task to be posted + // before calling that. + MaybePostReadCompleteTask(); + + NET_TRACE(1, " *** ") << " Completing socket write for: " << sequence_number_; + socket()->OnWriteComplete(rv); + NET_TRACE(1, " *** ") << "Done"; +} + +SequencedSocketData::~SequencedSocketData() { +} + DeterministicSocketData::DeterministicSocketData(MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count) : StaticSocketDataProvider(reads, reads_count, writes, writes_count), @@ -519,8 +777,8 @@ void DeterministicSocketData::StopAfter(int seq) { SetStop(sequence_number_ + seq); } -MockRead DeterministicSocketData::GetNextRead() { - current_read_ = StaticSocketDataProvider::PeekRead(); +MockRead DeterministicSocketData::OnRead() { + current_read_ = helper()->PeekRead(); // Synchronous read while stopped is an error if (stopped() && current_read_.mode == SYNCHRONOUS) { @@ -530,8 +788,7 @@ MockRead DeterministicSocketData::GetNextRead() { // Async read which will be called back in a future step. if (sequence_number_ < current_read_.sequence_number) { - NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - << ": I/O Pending"; + NET_TRACE(1, " *** ") << "Stage " << sequence_number_ << ": I/O Pending"; MockRead result = MockRead(SYNCHRONOUS, ERR_IO_PENDING); if (current_read_.mode == SYNCHRONOUS) { LOG(ERROR) << "Unable to perform synchronous read: " @@ -544,8 +801,8 @@ MockRead DeterministicSocketData::GetNextRead() { return result; } - NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - << ": Read " << read_index(); + NET_TRACE(1, " *** ") << "Stage " << sequence_number_ << ": Read " + << read_index(); if (print_debug_) DumpMockReadWrite(current_read_); @@ -554,13 +811,13 @@ MockRead DeterministicSocketData::GetNextRead() { NextStep(); DCHECK_NE(ERR_IO_PENDING, current_read_.result); - StaticSocketDataProvider::GetNextRead(); + StaticSocketDataProvider::OnRead(); return current_read_; } MockWriteResult DeterministicSocketData::OnWrite(const std::string& data) { - const MockWrite& next_write = StaticSocketDataProvider::PeekWrite(); + const MockWrite& next_write = helper()->PeekWrite(); current_write_ = next_write; // Synchronous write while stopped is an error @@ -571,16 +828,15 @@ MockWriteResult DeterministicSocketData::OnWrite(const std::string& data) { // Async write which will be called back in a future step. if (sequence_number_ < next_write.sequence_number) { - NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - << ": I/O Pending"; + NET_TRACE(1, " *** ") << "Stage " << sequence_number_ << ": I/O Pending"; if (next_write.mode == SYNCHRONOUS) { LOG(ERROR) << "Unable to perform synchronous write: " << next_write.sequence_number << " at stage: " << sequence_number_; return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED); } } else { - NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - << ": Write " << write_index(); + NET_TRACE(1, " *** ") << "Stage " << sequence_number_ << ": Write " + << write_index(); } if (print_debug_) @@ -596,8 +852,7 @@ MockWriteResult DeterministicSocketData::OnWrite(const std::string& data) { } void DeterministicSocketData::Reset() { - NET_TRACE(INFO, " *** ") << "Stage " - << sequence_number_ << ": Reset()"; + NET_TRACE(1, " *** ") << "Stage " << sequence_number_ << ": Reset()"; sequence_number_ = 0; StaticSocketDataProvider::Reset(); NOTREACHED(); @@ -675,21 +930,21 @@ scoped_ptr<DatagramClientSocket> MockClientSocketFactory::CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, - net::NetLog* net_log, - const net::NetLog::Source& source) { + NetLog* net_log, + const NetLog::Source& source) { SocketDataProvider* data_provider = mock_data_.GetNext(); scoped_ptr<MockUDPClientSocket> socket( new MockUDPClientSocket(data_provider, net_log)); data_provider->set_socket(socket.get()); if (bind_type == DatagramSocket::RANDOM_BIND) - socket->set_source_port(rand_int_cb.Run(1025, 65535)); + socket->set_source_port(static_cast<uint16>(rand_int_cb.Run(1025, 65535))); return socket.Pass(); } scoped_ptr<StreamSocket> MockClientSocketFactory::CreateTransportClientSocket( const AddressList& addresses, - net::NetLog* net_log, - const net::NetLog::Source& source) { + NetLog* net_log, + const NetLog::Source& source) { SocketDataProvider* data_provider = mock_data_.GetNext(); scoped_ptr<MockTCPClientSocket> socket( new MockTCPClientSocket(addresses, net_log, data_provider)); @@ -702,13 +957,17 @@ scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket( const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) { - scoped_ptr<MockSSLClientSocket> socket( - new MockSSLClientSocket(transport_socket.Pass(), - host_and_port, - ssl_config, - mock_ssl_data_.GetNext())); - ssl_client_sockets_.push_back(socket.get()); - return socket.Pass(); + SSLSocketDataProvider* next_ssl_data = mock_ssl_data_.GetNext(); + if (!next_ssl_data->next_protos_expected_in_ssl_config.empty()) { + EXPECT_EQ(next_ssl_data->next_protos_expected_in_ssl_config.size(), + ssl_config.next_protos.size()); + EXPECT_TRUE( + std::equal(next_ssl_data->next_protos_expected_in_ssl_config.begin(), + next_ssl_data->next_protos_expected_in_ssl_config.end(), + ssl_config.next_protos.begin())); + } + return scoped_ptr<SSLClientSocket>(new MockSSLClientSocket( + transport_socket.Pass(), host_and_port, ssl_config, next_ssl_data)); } void MockClientSocketFactory::ClearSSLSessionCache() { @@ -764,18 +1023,8 @@ const BoundNetLog& MockClientSocket::NetLog() const { return net_log_; } -std::string MockClientSocket::GetSessionCacheKey() const { - NOTIMPLEMENTED(); - return std::string(); -} - -bool MockClientSocket::InSessionCache() const { - NOTIMPLEMENTED(); - return false; -} - -void MockClientSocket::SetHandshakeCompletionCallback(const base::Closure& cb) { - NOTIMPLEMENTED(); +void MockClientSocket::GetConnectionAttempts(ConnectionAttempts* out) const { + out->clear(); } void MockClientSocket::GetSSLCertRequestInfo( @@ -801,8 +1050,12 @@ ChannelIDService* MockClientSocket::GetChannelIDService() const { return NULL; } -SSLClientSocket::NextProtoStatus -MockClientSocket::GetNextProto(std::string* proto) { +SSLFailureState MockClientSocket::GetSSLFailureState() const { + return IsConnected() ? SSL_FAILURE_NONE : SSL_FAILURE_UNKNOWN; +} + +SSLClientSocket::NextProtoStatus MockClientSocket::GetNextProto( + std::string* proto) const { proto->clear(); return SSLClientSocket::kNextProtoUnsupported; } @@ -825,7 +1078,7 @@ void MockClientSocket::RunCallbackAsync(const CompletionCallback& callback, result)); } -void MockClientSocket::RunCallback(const net::CompletionCallback& callback, +void MockClientSocket::RunCallback(const CompletionCallback& callback, int result) { if (!callback.is_null()) callback.Run(result); @@ -834,15 +1087,15 @@ void MockClientSocket::RunCallback(const net::CompletionCallback& callback, MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses, net::NetLog* net_log, SocketDataProvider* data) - : MockClientSocket(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)), + : MockClientSocket(BoundNetLog::Make(net_log, NetLog::SOURCE_NONE)), addresses_(addresses), data_(data), read_offset_(0), read_data_(SYNCHRONOUS, ERR_UNEXPECTED), need_read_data_(true), peer_closed_connection_(false), - pending_buf_(NULL), - pending_buf_len_(0), + pending_read_buf_(NULL), + pending_read_buf_len_(0), was_used_to_convey_data_(false) { DCHECK(data_); peer_addr_ = data->connect_data().peer_addr; @@ -857,15 +1110,15 @@ int MockTCPClientSocket::Read(IOBuffer* buf, int buf_len, return ERR_UNEXPECTED; // If the buffer is already in use, a read is already in progress! - DCHECK(pending_buf_.get() == NULL); + DCHECK(pending_read_buf_.get() == NULL); // Store our async IO data. - pending_buf_ = buf; - pending_buf_len_ = buf_len; - pending_callback_ = callback; + pending_read_buf_ = buf; + pending_read_buf_len_ = buf_len; + pending_read_callback_ = callback; if (need_read_data_) { - read_data_ = data_->GetNextRead(); + read_data_ = data_->OnRead(); if (read_data_.result == ERR_CONNECTION_CLOSED) { // This MockRead is just a marker to instruct us to set // peer_closed_connection_. @@ -874,7 +1127,7 @@ int MockTCPClientSocket::Read(IOBuffer* buf, int buf_len, if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) { // This MockRead is just a marker to instruct us to set // peer_closed_connection_. Skip it and get the next one. - read_data_ = data_->GetNextRead(); + read_data_ = data_->OnRead(); peer_closed_connection_ = true; } // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility @@ -903,6 +1156,15 @@ int MockTCPClientSocket::Write(IOBuffer* buf, int buf_len, was_used_to_convey_data_ = true; + // ERR_IO_PENDING is a signal that the socket data will call back + // asynchronously later. + if (write_result.result == ERR_IO_PENDING) { + pending_write_callback_ = callback; + return ERR_IO_PENDING; + } + + // TODO(rch): remove this once OrderedSocketData and DelayedSocketData + // have been removed. if (write_result.mode == ASYNC) { RunCallbackAsync(callback, write_result.result); return ERR_IO_PENDING; @@ -911,6 +1173,22 @@ int MockTCPClientSocket::Write(IOBuffer* buf, int buf_len, return write_result.result; } +void MockTCPClientSocket::GetConnectionAttempts(ConnectionAttempts* out) const { + int connect_result = data_->connect_data().result; + + out->clear(); + if (connected_ && connect_result != OK) + out->push_back(ConnectionAttempt(addresses_[0], connect_result)); +} + +void MockTCPClientSocket::ClearConnectionAttempts() { + NOTIMPLEMENTED(); +} + +void MockTCPClientSocket::AddConnectionAttempts(const ConnectionAttempts& in) { + NOTIMPLEMENTED(); +} + int MockTCPClientSocket::Connect(const CompletionCallback& callback) { if (connected_) return OK; @@ -918,7 +1196,7 @@ int MockTCPClientSocket::Connect(const CompletionCallback& callback) { peer_closed_connection_ = false; if (data_->connect_data().mode == ASYNC) { if (data_->connect_data().result == ERR_IO_PENDING) - pending_callback_ = callback; + pending_read_callback_ = callback; else RunCallbackAsync(callback, data_->connect_data().result); return ERR_IO_PENDING; @@ -928,7 +1206,7 @@ int MockTCPClientSocket::Connect(const CompletionCallback& callback) { void MockTCPClientSocket::Disconnect() { MockClientSocket::Disconnect(); - pending_callback_.Reset(); + pending_read_callback_.Reset(); } bool MockTCPClientSocket::IsConnected() const { @@ -965,7 +1243,7 @@ bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { void MockTCPClientSocket::OnReadComplete(const MockRead& data) { // There must be a read pending. - DCHECK(pending_buf_.get()); + DCHECK(pending_read_buf_.get()); // You can't complete a read with another ERR_IO_PENDING status code. DCHECK_NE(ERR_IO_PENDING, data.result); // Since we've been waiting for data, need_read_data_ should be true. @@ -978,29 +1256,36 @@ void MockTCPClientSocket::OnReadComplete(const MockRead& data) { // let CompleteRead() schedule a callback. read_data_.mode = SYNCHRONOUS; - CompletionCallback callback = pending_callback_; + CompletionCallback callback = pending_read_callback_; int rv = CompleteRead(); RunCallback(callback, rv); } +void MockTCPClientSocket::OnWriteComplete(int rv) { + // There must be a read pending. + DCHECK(!pending_write_callback_.is_null()); + CompletionCallback callback = pending_write_callback_; + RunCallback(callback, rv); +} + void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) { - CompletionCallback callback = pending_callback_; + CompletionCallback callback = pending_read_callback_; RunCallback(callback, data.result); } int MockTCPClientSocket::CompleteRead() { - DCHECK(pending_buf_.get()); - DCHECK(pending_buf_len_ > 0); + DCHECK(pending_read_buf_.get()); + DCHECK(pending_read_buf_len_ > 0); was_used_to_convey_data_ = true; // Save the pending async IO data and reset our |pending_| state. - scoped_refptr<IOBuffer> buf = pending_buf_; - int buf_len = pending_buf_len_; - CompletionCallback callback = pending_callback_; - pending_buf_ = NULL; - pending_buf_len_ = 0; - pending_callback_.Reset(); + scoped_refptr<IOBuffer> buf = pending_read_buf_; + int buf_len = pending_read_buf_len_; + CompletionCallback callback = pending_read_callback_; + pending_read_buf_ = NULL; + pending_read_buf_len_ = 0; + pending_read_callback_.Reset(); int result = read_data_.result; DCHECK(result != ERR_IO_PENDING); @@ -1028,7 +1313,7 @@ int MockTCPClientSocket::CompleteRead() { } DeterministicSocketHelper::DeterministicSocketHelper( - net::NetLog* net_log, + NetLog* net_log, DeterministicSocketData* data) : write_pending_(false), write_result_(0), @@ -1039,7 +1324,7 @@ DeterministicSocketHelper::DeterministicSocketHelper( data_(data), was_used_to_convey_data_(false), peer_closed_connection_(false), - net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)) { + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_NONE)) { } DeterministicSocketHelper::~DeterministicSocketHelper() {} @@ -1058,7 +1343,7 @@ int DeterministicSocketHelper::CompleteRead() { was_used_to_convey_data_ = true; if (read_data_.result == ERR_IO_PENDING) - read_data_ = data_->GetNextRead(); + read_data_ = data_->OnRead(); DCHECK_NE(ERR_IO_PENDING, read_data_.result); // If read_data_.mode is ASYNC, we do not need to wait, since this is already // the callback. Therefore we don't even bother to check it. @@ -1101,8 +1386,7 @@ int DeterministicSocketHelper::Write( int DeterministicSocketHelper::Read( IOBuffer* buf, int buf_len, const CompletionCallback& callback) { - - read_data_ = data_->GetNextRead(); + read_data_ = data_->OnRead(); // The buffer should always be big enough to contain all the MockRead data. To // use small buffers, split the data into multiple MockReads. DCHECK_LE(read_data_.data_len, buf_len); @@ -1115,7 +1399,7 @@ int DeterministicSocketHelper::Read( if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) { // This MockRead is just a marker to instruct us to set // peer_closed_connection_. Skip it and get the next one. - read_data_ = data_->GetNextRead(); + read_data_ = data_->OnRead(); peer_closed_connection_ = true; } @@ -1220,6 +1504,9 @@ const BoundNetLog& DeterministicMockUDPClientSocket::NetLog() const { void DeterministicMockUDPClientSocket::OnReadComplete(const MockRead& data) {} +void DeterministicMockUDPClientSocket::OnWriteComplete(int rv) { +} + void DeterministicMockUDPClientSocket::OnConnectComplete( const MockConnect& data) { NOTIMPLEMENTED(); @@ -1228,7 +1515,7 @@ void DeterministicMockUDPClientSocket::OnConnectComplete( DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket( net::NetLog* net_log, DeterministicSocketData* data) - : MockClientSocket(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)), + : MockClientSocket(BoundNetLog::Make(net_log, NetLog::SOURCE_NONE)), helper_(net_log, data) { peer_addr_ = data->connect_data().peer_addr; } @@ -1314,9 +1601,22 @@ bool DeterministicMockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {} +void DeterministicMockTCPClientSocket::OnWriteComplete(int rv) { +} + void DeterministicMockTCPClientSocket::OnConnectComplete( const MockConnect& data) {} +// static +void MockSSLClientSocket::ConnectCallback( + MockSSLClientSocket* ssl_client_socket, + const CompletionCallback& callback, + int rv) { + if (rv == OK) + ssl_client_socket->connected_ = true; + callback.Run(rv); +} + MockSSLClientSocket::MockSSLClientSocket( scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_port_pair, @@ -1327,15 +1627,7 @@ MockSSLClientSocket::MockSSLClientSocket( // tests. transport_socket->socket()->NetLog()), transport_(transport_socket.Pass()), - host_port_pair_(host_port_pair), - data_(data), - is_npn_state_set_(false), - new_npn_value_(false), - is_protocol_negotiated_set_(false), - protocol_negotiated_(kProtoUnknown), - next_connect_state_(STATE_NONE), - reached_connect_(false), - weak_factory_(this) { + data_(data) { DCHECK(data_); peer_addr_ = data->connect.peer_addr; } @@ -1355,23 +1647,28 @@ int MockSSLClientSocket::Write(IOBuffer* buf, int buf_len, } int MockSSLClientSocket::Connect(const CompletionCallback& callback) { - next_connect_state_ = STATE_SSL_CONNECT; - reached_connect_ = true; - int rv = DoConnectLoop(OK); - if (rv == ERR_IO_PENDING) - connect_callback_ = callback; + int rv = transport_->socket()->Connect( + base::Bind(&ConnectCallback, base::Unretained(this), callback)); + if (rv == OK) { + if (data_->connect.result == OK) + connected_ = true; + if (data_->connect.mode == ASYNC) { + RunCallbackAsync(callback, data_->connect.result); + return ERR_IO_PENDING; + } + return data_->connect.result; + } return rv; } void MockSSLClientSocket::Disconnect() { - weak_factory_.InvalidateWeakPtrs(); MockClientSocket::Disconnect(); if (transport_->socket() != NULL) transport_->socket()->Disconnect(); } bool MockSSLClientSocket::IsConnected() const { - return transport_->socket()->IsConnected() && connected_; + return transport_->socket()->IsConnected(); } bool MockSSLClientSocket::WasEverUsed() const { @@ -1395,21 +1692,6 @@ bool MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) { return true; } -std::string MockSSLClientSocket::GetSessionCacheKey() const { - // For the purposes of these tests, |host_and_port| will serve as the - // cache key. - return host_port_pair_.ToString(); -} - -bool MockSSLClientSocket::InSessionCache() const { - return data_->is_in_session_cache; -} - -void MockSSLClientSocket::SetHandshakeCompletionCallback( - const base::Closure& cb) { - handshake_completion_callback_ = cb; -} - void MockSSLClientSocket::GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) { DCHECK(cert_request_info); @@ -1423,42 +1705,11 @@ void MockSSLClientSocket::GetSSLCertRequestInfo( } SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto( - std::string* proto) { + std::string* proto) const { *proto = data_->next_proto; return data_->next_proto_status; } -bool MockSSLClientSocket::set_was_npn_negotiated(bool negotiated) { - is_npn_state_set_ = true; - return new_npn_value_ = negotiated; -} - -bool MockSSLClientSocket::WasNpnNegotiated() const { - if (is_npn_state_set_) - return new_npn_value_; - return data_->was_npn_negotiated; -} - -NextProto MockSSLClientSocket::GetNegotiatedProtocol() const { - if (is_protocol_negotiated_set_) - return protocol_negotiated_; - return data_->protocol_negotiated; -} - -void MockSSLClientSocket::set_protocol_negotiated( - NextProto protocol_negotiated) { - is_protocol_negotiated_set_ = true; - protocol_negotiated_ = protocol_negotiated; -} - -bool MockSSLClientSocket::WasChannelIDSent() const { - return data_->channel_id_sent; -} - -void MockSSLClientSocket::set_channel_id_sent(bool channel_id_sent) { - data_->channel_id_sent = channel_id_sent; -} - ChannelIDService* MockSSLClientSocket::GetChannelIDService() const { return data_->channel_id_service; } @@ -1467,71 +1718,12 @@ void MockSSLClientSocket::OnReadComplete(const MockRead& data) { NOTIMPLEMENTED(); } -void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) { +void MockSSLClientSocket::OnWriteComplete(int rv) { NOTIMPLEMENTED(); } -void MockSSLClientSocket::RestartPausedConnect() { - DCHECK(data_->should_pause_on_connect); - DCHECK_EQ(next_connect_state_, STATE_SSL_CONNECT_COMPLETE); - OnIOComplete(data_->connect.result); -} - -void MockSSLClientSocket::OnIOComplete(int result) { - int rv = DoConnectLoop(result); - if (rv != ERR_IO_PENDING) - base::ResetAndReturn(&connect_callback_).Run(rv); -} - -int MockSSLClientSocket::DoConnectLoop(int result) { - DCHECK_NE(next_connect_state_, STATE_NONE); - - int rv = result; - do { - ConnectState state = next_connect_state_; - next_connect_state_ = STATE_NONE; - switch (state) { - case STATE_SSL_CONNECT: - rv = DoSSLConnect(); - break; - case STATE_SSL_CONNECT_COMPLETE: - rv = DoSSLConnectComplete(rv); - break; - default: - NOTREACHED() << "bad state"; - rv = ERR_UNEXPECTED; - break; - } - } while (rv != ERR_IO_PENDING && next_connect_state_ != STATE_NONE); - - return rv; -} - -int MockSSLClientSocket::DoSSLConnect() { - next_connect_state_ = STATE_SSL_CONNECT_COMPLETE; - - if (data_->should_pause_on_connect) - return ERR_IO_PENDING; - - if (data_->connect.mode == ASYNC) { - base::MessageLoop::current()->PostTask( - FROM_HERE, - base::Bind(&MockSSLClientSocket::OnIOComplete, - weak_factory_.GetWeakPtr(), - data_->connect.result)); - return ERR_IO_PENDING; - } - - return data_->connect.result; -} - -int MockSSLClientSocket::DoSSLConnectComplete(int result) { - if (result == OK) - connected_ = true; - - if (!handshake_completion_callback_.is_null()) - base::ResetAndReturn(&handshake_completion_callback_).Run(); - return result; +void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) { + NOTIMPLEMENTED(); } MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data, @@ -1542,9 +1734,9 @@ MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data, read_data_(SYNCHRONOUS, ERR_UNEXPECTED), need_read_data_(true), source_port_(123), - pending_buf_(NULL), - pending_buf_len_(0), - net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)), + pending_read_buf_(NULL), + pending_read_buf_len_(0), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_NONE)), weak_factory_(this) { DCHECK(data_); data_->Reset(); @@ -1560,15 +1752,15 @@ int MockUDPClientSocket::Read(IOBuffer* buf, return ERR_UNEXPECTED; // If the buffer is already in use, a read is already in progress! - DCHECK(pending_buf_.get() == NULL); + DCHECK(pending_read_buf_.get() == NULL); // Store our async IO data. - pending_buf_ = buf; - pending_buf_len_ = buf_len; - pending_callback_ = callback; + pending_read_buf_ = buf; + pending_read_buf_len_ = buf_len; + pending_read_callback_ = callback; if (need_read_data_) { - read_data_ = data_->GetNextRead(); + read_data_ = data_->OnRead(); // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility // to complete the async IO manually later (via OnReadComplete). if (read_data_.result == ERR_IO_PENDING) { @@ -1593,6 +1785,12 @@ int MockUDPClientSocket::Write(IOBuffer* buf, int buf_len, std::string data(buf->data(), buf_len); MockWriteResult write_result = data_->OnWrite(data); + // ERR_IO_PENDING is a signal that the socket data will call back + // asynchronously. + if (write_result.result == ERR_IO_PENDING) { + pending_write_callback_ = callback; + return ERR_IO_PENDING; + } if (write_result.mode == ASYNC) { RunCallbackAsync(callback, write_result.result); return ERR_IO_PENDING; @@ -1637,7 +1835,7 @@ int MockUDPClientSocket::Connect(const IPEndPoint& address) { void MockUDPClientSocket::OnReadComplete(const MockRead& data) { // There must be a read pending. - DCHECK(pending_buf_.get()); + DCHECK(pending_read_buf_.get()); // You can't complete a read with another ERR_IO_PENDING status code. DCHECK_NE(ERR_IO_PENDING, data.result); // Since we've been waiting for data, need_read_data_ should be true. @@ -1650,26 +1848,33 @@ void MockUDPClientSocket::OnReadComplete(const MockRead& data) { // let CompleteRead() schedule a callback. read_data_.mode = SYNCHRONOUS; - net::CompletionCallback callback = pending_callback_; + CompletionCallback callback = pending_read_callback_; int rv = CompleteRead(); RunCallback(callback, rv); } +void MockUDPClientSocket::OnWriteComplete(int rv) { + // There must be a read pending. + DCHECK(!pending_write_callback_.is_null()); + CompletionCallback callback = pending_write_callback_; + RunCallback(callback, rv); +} + void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) { NOTIMPLEMENTED(); } int MockUDPClientSocket::CompleteRead() { - DCHECK(pending_buf_.get()); - DCHECK(pending_buf_len_ > 0); + DCHECK(pending_read_buf_.get()); + DCHECK(pending_read_buf_len_ > 0); // Save the pending async IO data and reset our |pending_| state. - scoped_refptr<IOBuffer> buf = pending_buf_; - int buf_len = pending_buf_len_; - CompletionCallback callback = pending_callback_; - pending_buf_ = NULL; - pending_buf_len_ = 0; - pending_callback_.Reset(); + scoped_refptr<IOBuffer> buf = pending_read_buf_; + int buf_len = pending_read_buf_len_; + CompletionCallback callback = pending_read_callback_; + pending_read_buf_ = NULL; + pending_read_buf_len_ = 0; + pending_read_callback_.Reset(); int result = read_data_.result; DCHECK(result != ERR_IO_PENDING); @@ -1787,9 +1992,9 @@ MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() {} int MockTransportClientSocketPool::MockConnectJob::Connect() { int rv = socket_->Connect(base::Bind(&MockConnectJob::OnConnect, base::Unretained(this))); - if (rv == OK) { + if (rv != ERR_IO_PENDING) { user_callback_.Reset(); - OnConnect(OK); + OnConnect(rv); } return rv; } @@ -1821,6 +2026,11 @@ void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) { handle_->set_connect_timing(connect_timing); } else { socket_.reset(); + + // Needed to test copying of ConnectionAttempts in SSL ConnectJob. + ConnectionAttempts attempts; + attempts.push_back(ConnectionAttempt(IPEndPoint(), rv)); + handle_->set_connection_attempts(attempts); } handle_ = NULL; @@ -1835,10 +2045,12 @@ void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) { MockTransportClientSocketPool::MockTransportClientSocketPool( int max_sockets, int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, ClientSocketFactory* socket_factory) - : TransportClientSocketPool(max_sockets, max_sockets_per_group, histograms, - NULL, NULL, NULL), + : TransportClientSocketPool(max_sockets, + max_sockets_per_group, + NULL, + NULL, + NULL), client_socket_factory_(socket_factory), last_request_priority_(DEFAULT_PRIORITY), release_count_(0), @@ -1854,7 +2066,7 @@ int MockTransportClientSocketPool::RequestSocket( last_request_priority_ = priority; scoped_ptr<StreamSocket> socket = client_socket_factory_->CreateTransportClientSocket( - AddressList(), net_log.net_log(), net::NetLog::Source()); + AddressList(), net_log.net_log(), NetLog::Source()); MockConnectJob* job = new MockConnectJob(socket.Pass(), handle, callback); job_list_.push_back(job); handle->set_pool_id(1); @@ -1909,7 +2121,7 @@ scoped_ptr<DatagramClientSocket> DeterministicMockClientSocketFactory::CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, - net::NetLog* net_log, + NetLog* net_log, const NetLog::Source& source) { DeterministicSocketData* data_provider = mock_data().GetNext(); scoped_ptr<DeterministicMockUDPClientSocket> socket( @@ -1917,15 +2129,15 @@ DeterministicMockClientSocketFactory::CreateDatagramClientSocket( data_provider->set_delegate(socket->AsWeakPtr()); udp_client_sockets().push_back(socket.get()); if (bind_type == DatagramSocket::RANDOM_BIND) - socket->set_source_port(rand_int_cb.Run(1025, 65535)); + socket->set_source_port(static_cast<uint16>(rand_int_cb.Run(1025, 65535))); return socket.Pass(); } scoped_ptr<StreamSocket> DeterministicMockClientSocketFactory::CreateTransportClientSocket( const AddressList& addresses, - net::NetLog* net_log, - const net::NetLog::Source& source) { + NetLog* net_log, + const NetLog::Source& source) { DeterministicSocketData* data_provider = mock_data().GetNext(); scoped_ptr<DeterministicMockTCPClientSocket> socket( new DeterministicMockTCPClientSocket(net_log, data_provider)); @@ -1954,10 +2166,12 @@ void DeterministicMockClientSocketFactory::ClearSSLSessionCache() { MockSOCKSClientSocketPool::MockSOCKSClientSocketPool( int max_sockets, int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, TransportClientSocketPool* transport_pool) - : SOCKSClientSocketPool(max_sockets, max_sockets_per_group, histograms, - NULL, transport_pool, NULL), + : SOCKSClientSocketPool(max_sockets, + max_sockets_per_group, + NULL, + transport_pool, + NULL), transport_pool_(transport_pool) { } @@ -1983,6 +2197,21 @@ void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, return transport_pool_->ReleaseSocket(group_name, socket.Pass(), id); } +ScopedWebSocketEndpointZeroUnlockDelay:: + ScopedWebSocketEndpointZeroUnlockDelay() { + old_delay_ = + WebSocketEndpointLockManager::GetInstance()->SetUnlockDelayForTesting( + base::TimeDelta()); +} + +ScopedWebSocketEndpointZeroUnlockDelay:: + ~ScopedWebSocketEndpointZeroUnlockDelay() { + base::TimeDelta active_delay = + WebSocketEndpointLockManager::GetInstance()->SetUnlockDelayForTesting( + old_delay_); + EXPECT_EQ(active_delay, base::TimeDelta()); +} + const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 }; const int kSOCKS5GreetRequestLength = arraysize(kSOCKS5GreetRequest); diff --git a/chromium/net/socket/socket_test_util.h b/chromium/net/socket/socket_test_util.h index 7bccdaed727..1048c964751 100644 --- a/chromium/net/socket/socket_test_util.h +++ b/chromium/net/socket/socket_test_util.h @@ -18,13 +18,14 @@ #include "base/memory/scoped_vector.h" #include "base/memory/weak_ptr.h" #include "base/strings/string16.h" +#include "base/time/time.h" #include "net/base/address_list.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" #include "net/base/test_completion_callback.h" #include "net/http/http_auth_controller.h" #include "net/http/http_proxy_client_socket_pool.h" +#include "net/log/net_log.h" #include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_handle.h" #include "net/socket/socks_client_socket_pool.h" @@ -102,8 +103,7 @@ struct MockReadWrite { result(0), data(NULL), data_len(0), - sequence_number(0), - time_stamp(base::Time::Now()) {} + sequence_number(0) {} // Read/write failure (no data). MockReadWrite(IoMode io_mode, int result) @@ -111,8 +111,7 @@ struct MockReadWrite { result(result), data(NULL), data_len(0), - sequence_number(0), - time_stamp(base::Time::Now()) {} + sequence_number(0) {} // Read/write failure (no data), with sequence information. MockReadWrite(IoMode io_mode, int result, int seq) @@ -120,8 +119,7 @@ struct MockReadWrite { result(result), data(NULL), data_len(0), - sequence_number(seq), - time_stamp(base::Time::Now()) {} + sequence_number(seq) {} // Asynchronous read/write success (inferred data length). explicit MockReadWrite(const char* data) @@ -129,8 +127,7 @@ struct MockReadWrite { result(0), data(data), data_len(strlen(data)), - sequence_number(0), - time_stamp(base::Time::Now()) {} + sequence_number(0) {} // Read/write success (inferred data length). MockReadWrite(IoMode io_mode, const char* data) @@ -138,8 +135,7 @@ struct MockReadWrite { result(0), data(data), data_len(strlen(data)), - sequence_number(0), - time_stamp(base::Time::Now()) {} + sequence_number(0) {} // Read/write success. MockReadWrite(IoMode io_mode, const char* data, int data_len) @@ -147,8 +143,7 @@ struct MockReadWrite { result(0), data(data), data_len(data_len), - sequence_number(0), - time_stamp(base::Time::Now()) {} + sequence_number(0) {} // Read/write success (inferred data length) with sequence information. MockReadWrite(IoMode io_mode, int seq, const char* data) @@ -156,8 +151,7 @@ struct MockReadWrite { result(0), data(data), data_len(strlen(data)), - sequence_number(seq), - time_stamp(base::Time::Now()) {} + sequence_number(seq) {} // Read/write success with sequence information. MockReadWrite(IoMode io_mode, const char* data, int data_len, int seq) @@ -165,8 +159,7 @@ struct MockReadWrite { result(0), data(data), data_len(data_len), - sequence_number(seq), - time_stamp(base::Time::Now()) {} + sequence_number(seq) {} IoMode mode; int result; @@ -178,7 +171,6 @@ struct MockReadWrite { // an ERR_IO_PENDING is returned. int sequence_number; // The sequence number at which a read is allowed // to occur. - base::Time time_stamp; // The time stamp at which the operation occurred. }; typedef MockReadWrite<MOCK_READ> MockRead; @@ -203,9 +195,11 @@ class SocketDataProvider { // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller // that it will be called via the AsyncSocket::OnReadComplete() // function at a later time. - virtual MockRead GetNextRead() = 0; + virtual MockRead OnRead() = 0; virtual MockWriteResult OnWrite(const std::string& data) = 0; virtual void Reset() = 0; + virtual bool AllReadDataConsumed() const = 0; + virtual bool AllWriteDataConsumed() const = 0; // Accessor for the socket which is using the SocketDataProvider. AsyncSocket* socket() { return socket_; } @@ -230,27 +224,42 @@ class AsyncSocket { // is called to complete the asynchronous read operation. // data.async is ignored, and this read is completed synchronously as // part of this call. + // TODO(rch): this should take a StringPiece since most of the fields + // are ignored. virtual void OnReadComplete(const MockRead& data) = 0; + // If an async IO is pending because the SocketDataProvider returned + // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete + // is called to complete the asynchronous read operation. + virtual void OnWriteComplete(int rv) = 0; virtual void OnConnectComplete(const MockConnect& data) = 0; }; -// SocketDataProvider which responds based on static tables of mock reads and -// writes. -class StaticSocketDataProvider : public SocketDataProvider { +// StaticSocketDataHelper manages a list of reads and writes. +class StaticSocketDataHelper { public: - StaticSocketDataProvider(); - StaticSocketDataProvider(MockRead* reads, - size_t reads_count, - MockWrite* writes, - size_t writes_count); - ~StaticSocketDataProvider() override; - - // These functions get access to the next available read and write data. + StaticSocketDataHelper(MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count); + ~StaticSocketDataHelper(); + + // These functions get access to the next available read and write data, + // or null if there is no more data available. const MockRead& PeekRead() const; const MockWrite& PeekWrite() const; - // These functions get random access to the read and write data, for timing. - const MockRead& PeekRead(size_t index) const; - const MockWrite& PeekWrite(size_t index) const; + + // Returns the current read or write , and then advances to the next one. + const MockRead& AdvanceRead(); + const MockWrite& AdvanceWrite(); + + // Resets the read and write indexes to 0. + void Reset(); + + // Returns true if |data| is valid data for the next write. In order + // to support short writes, the next write may be longer than |data| + // in which case this method will still return true. + bool VerifyWriteData(const std::string& data); + size_t read_index() const { return read_index_; } size_t write_index() const { return write_index_; } size_t read_count() const { return read_count_; } @@ -259,13 +268,6 @@ class StaticSocketDataProvider : public SocketDataProvider { bool at_read_eof() const { return read_index_ >= read_count_; } bool at_write_eof() const { return write_index_ >= write_count_; } - virtual void CompleteRead() {} - - // SocketDataProvider implementation. - MockRead GetNextRead() override; - MockWriteResult OnWrite(const std::string& data) override; - void Reset() override; - private: MockRead* reads_; size_t read_index_; @@ -274,6 +276,43 @@ class StaticSocketDataProvider : public SocketDataProvider { size_t write_index_; size_t write_count_; + DISALLOW_COPY_AND_ASSIGN(StaticSocketDataHelper); +}; + +// SocketDataProvider which responds based on static tables of mock reads and +// writes. +class StaticSocketDataProvider : public SocketDataProvider { + public: + StaticSocketDataProvider(); + StaticSocketDataProvider(MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count); + ~StaticSocketDataProvider() override; + + virtual void CompleteRead() {} + + // SocketDataProvider implementation. + MockRead OnRead() override; + MockWriteResult OnWrite(const std::string& data) override; + void Reset() override; + bool AllReadDataConsumed() const override; + bool AllWriteDataConsumed() const override; + + size_t read_index() const { return helper_.read_index(); } + size_t write_index() const { return helper_.write_index(); } + size_t read_count() const { return helper_.read_count(); } + size_t write_count() const { return helper_.write_count(); } + + bool at_read_eof() const { return helper_.at_read_eof(); } + bool at_write_eof() const { return helper_.at_write_eof(); } + + protected: + StaticSocketDataHelper* helper() { return &helper_; } + + private: + StaticSocketDataHelper helper_; + DISALLOW_COPY_AND_ASSIGN(StaticSocketDataProvider); }; @@ -292,8 +331,8 @@ class DynamicSocketDataProvider : public SocketDataProvider { void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; } // SocketDataProvider implementation. - MockRead GetNextRead() override; - virtual MockWriteResult OnWrite(const std::string& data) = 0; + MockRead OnRead() override; + MockWriteResult OnWrite(const std::string& data) override = 0; void Reset() override; protected: @@ -326,20 +365,13 @@ struct SSLSocketDataProvider { MockConnect connect; SSLClientSocket::NextProtoStatus next_proto_status; std::string next_proto; - bool was_npn_negotiated; - NextProto protocol_negotiated; + NextProtoVector next_protos_expected_in_ssl_config; bool client_cert_sent; SSLCertRequestInfo* cert_request_info; scoped_refptr<X509Certificate> cert; bool channel_id_sent; ChannelIDService* channel_id_service; int connection_status; - // Indicates that the socket should pause in the Connect method. - bool should_pause_on_connect; - // Whether or not the Socket should behave like there is a pre-existing - // session to resume. Whether or not such a session is reported as - // resumed is controlled by |connection_status|. - bool is_in_session_cache; }; // A DataProvider where the client must write a request before the reads (e.g. @@ -376,7 +408,7 @@ class DelayedSocketData : public StaticSocketDataProvider { void ForceNextRead(); // StaticSocketDataProvider: - MockRead GetNextRead() override; + MockRead OnRead() override; MockWriteResult OnWrite(const std::string& data) override; void Reset() override; void CompleteRead() override; @@ -430,7 +462,7 @@ class OrderedSocketData : public StaticSocketDataProvider { void EndLoop(); // StaticSocketDataProvider: - MockRead GetNextRead() override; + MockRead OnRead() override; MockWriteResult OnWrite(const std::string& data) override; void Reset() override; void CompleteRead() override; @@ -445,6 +477,65 @@ class OrderedSocketData : public StaticSocketDataProvider { DISALLOW_COPY_AND_ASSIGN(OrderedSocketData); }; +// Uses the sequence_number field in the mock reads and writes to +// complete the operations in a specified order. +class SequencedSocketData : public SocketDataProvider { + public: + // |reads| is the list of MockRead completions. + // |writes| is the list of MockWrite completions. + SequencedSocketData(MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count); + + // |connect| is the result for the connect phase. + // |reads| is the list of MockRead completions. + // |writes| is the list of MockWrite completions. + SequencedSocketData(const MockConnect& connect, + MockRead* reads, + size_t reads_count, + MockWrite* writes, + size_t writes_count); + + ~SequencedSocketData() override; + + // SocketDataProviderBase implementation. + MockRead OnRead() override; + MockWriteResult OnWrite(const std::string& data) override; + void Reset() override; + bool AllReadDataConsumed() const override; + bool AllWriteDataConsumed() const override; + + // Returns true if all data has been read. + bool at_read_eof() const; + + // Returns true if all data has been written. + bool at_write_eof() const; + + private: + // Defines the state for the read or write path. + enum IoState { + IDLE, // No async operation is in progress. + PENDING, // An async operation in waiting for another opteration to + // complete. + COMPLETING, // A task has been posted to complet an async operation. + }; + void OnReadComplete(); + void OnWriteComplete(); + + void MaybePostReadCompleteTask(); + void MaybePostWriteCompleteTask(); + + StaticSocketDataHelper helper_; + int sequence_number_; + IoState read_state_; + IoState write_state_; + + base::WeakPtrFactory<SequencedSocketData> weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(SequencedSocketData); +}; + class DeterministicMockTCPClientSocket; // This class gives the user full control over the network activity, @@ -558,9 +649,9 @@ class DeterministicSocketData : public StaticSocketDataProvider { // StaticSocketDataProvider: - // When the socket calls Read(), that calls GetNextRead(), and expects either + // When the socket calls Read(), that calls OnRead(), and expects either // ERR_IO_PENDING or data. - MockRead GetNextRead() override; + MockRead OnRead() override; // When the socket calls Write(), it always completes synchronously. OnWrite() // checks to make sure the written data matches the expected data. The @@ -643,12 +734,6 @@ class MockClientSocketFactory : public ClientSocketFactory { return mock_data_; } - // Note: this method is unsafe; the elements of the returned vector - // are not necessarily valid. - const std::vector<MockSSLClientSocket*>& ssl_client_sockets() const { - return ssl_client_sockets_; - } - // ClientSocketFactory scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, @@ -669,7 +754,6 @@ class MockClientSocketFactory : public ClientSocketFactory { private: SocketDataProviderArray<SocketDataProvider> mock_data_; SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_; - std::vector<MockSSLClientSocket*> ssl_client_sockets_; }; class MockClientSocket : public SSLClientSocket { @@ -682,17 +766,17 @@ class MockClientSocket : public SSLClientSocket { explicit MockClientSocket(const BoundNetLog& net_log); // Socket implementation. - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) = 0; - virtual int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) = 0; + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override = 0; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override = 0; int SetReceiveBufferSize(int32 size) override; int SetSendBufferSize(int32 size) override; // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) = 0; + int Connect(const CompletionCallback& callback) override = 0; void Disconnect() override; bool IsConnected() const override; bool IsConnectedAndIdle() const override; @@ -701,11 +785,11 @@ class MockClientSocket : public SSLClientSocket { const BoundNetLog& NetLog() const override; void SetSubresourceSpeculation() override {} void SetOmniboxSpeculation() override {} + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} // SSLClientSocket implementation. - std::string GetSessionCacheKey() const override; - bool InSessionCache() const override; - void SetHandshakeCompletionCallback(const base::Closure& cb) override; void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override; int ExportKeyingMaterial(const base::StringPiece& label, bool has_context, @@ -713,8 +797,9 @@ class MockClientSocket : public SSLClientSocket { unsigned char* out, unsigned int outlen) override; int GetTLSUniqueChannelBinding(std::string* out) override; - NextProtoStatus GetNextProto(std::string* proto) override; + NextProtoStatus GetNextProto(std::string* proto) const override; ChannelIDService* GetChannelIDService() const override; + SSLFailureState GetSSLFailureState() const override; protected: ~MockClientSocket() override; @@ -766,9 +851,13 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { bool UsingTCPFastOpen() const override; bool WasNpnNegotiated() const override; bool GetSSLInfo(SSLInfo* ssl_info) override; + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override; + void AddConnectionAttempts(const ConnectionAttempts& attempts) override; // AsyncSocket: void OnReadComplete(const MockRead& data) override; + void OnWriteComplete(int rv) override; void OnConnectComplete(const MockConnect& data) override; private: @@ -786,25 +875,25 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { // TCPClientSocket. bool peer_closed_connection_; - // While an asynchronous IO is pending, we save our user-buffer state. - scoped_refptr<IOBuffer> pending_buf_; - int pending_buf_len_; - CompletionCallback pending_callback_; + // While an asynchronous read is pending, we save our user-buffer state. + scoped_refptr<IOBuffer> pending_read_buf_; + int pending_read_buf_len_; + CompletionCallback pending_read_callback_; + CompletionCallback pending_write_callback_; bool was_used_to_convey_data_; DISALLOW_COPY_AND_ASSIGN(MockTCPClientSocket); }; // DeterministicSocketHelper is a helper class that can be used -// to simulate net::Socket::Read() and net::Socket::Write() +// to simulate Socket::Read() and Socket::Write() // using deterministic |data|. // Note: This is provided as a common helper class because // of the inheritance hierarchy of DeterministicMock[UDP,TCP]ClientSocket and a // desire not to introduce an additional common base class. class DeterministicSocketHelper { public: - DeterministicSocketHelper(net::NetLog* net_log, - DeterministicSocketData* data); + DeterministicSocketHelper(NetLog* net_log, DeterministicSocketData* data); virtual ~DeterministicSocketHelper(); bool write_pending() const { return write_pending_; } @@ -879,15 +968,16 @@ class DeterministicMockUDPClientSocket // AsyncSocket implementation. void OnReadComplete(const MockRead& data) override; + void OnWriteComplete(int rv) override; void OnConnectComplete(const MockConnect& data) override; - void set_source_port(int port) { source_port_ = port; } + void set_source_port(uint16 port) { source_port_ = port; } private: bool connected_; IPEndPoint peer_address_; DeterministicSocketHelper helper_; - int source_port_; // Ephemeral source port. + uint16 source_port_; // Ephemeral source port. DISALLOW_COPY_AND_ASSIGN(DeterministicMockUDPClientSocket); }; @@ -929,6 +1019,7 @@ class DeterministicMockTCPClientSocket // AsyncSocket: void OnReadComplete(const MockRead& data) override; + void OnWriteComplete(int rv) override; void OnConnectComplete(const MockConnect& data) override; private: @@ -960,65 +1051,26 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { bool WasEverUsed() const override; bool UsingTCPFastOpen() const override; int GetPeerAddress(IPEndPoint* address) const override; - bool WasNpnNegotiated() const override; bool GetSSLInfo(SSLInfo* ssl_info) override; // SSLClientSocket implementation. - std::string GetSessionCacheKey() const override; - bool InSessionCache() const override; - void SetHandshakeCompletionCallback(const base::Closure& cb) override; void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override; - NextProtoStatus GetNextProto(std::string* proto) override; - bool set_was_npn_negotiated(bool negotiated) override; - void set_protocol_negotiated(NextProto protocol_negotiated) override; - NextProto GetNegotiatedProtocol() const override; + NextProtoStatus GetNextProto(std::string* proto) const override; // This MockSocket does not implement the manual async IO feature. void OnReadComplete(const MockRead& data) override; + void OnWriteComplete(int rv) override; void OnConnectComplete(const MockConnect& data) override; - bool WasChannelIDSent() const override; - void set_channel_id_sent(bool channel_id_sent) override; ChannelIDService* GetChannelIDService() const override; - bool reached_connect() const { return reached_connect_; } - - // Resumes the connection of a socket that was paused for testing. - // |connect_callback_| should be set before invoking this method. - void RestartPausedConnect(); - private: - enum ConnectState { - STATE_NONE, - STATE_SSL_CONNECT, - STATE_SSL_CONNECT_COMPLETE, - }; - - void OnIOComplete(int result); - - // Runs the state transistion loop. - int DoConnectLoop(int result); - - int DoSSLConnect(); - int DoSSLConnectComplete(int result); + static void ConnectCallback(MockSSLClientSocket* ssl_client_socket, + const CompletionCallback& callback, + int rv); scoped_ptr<ClientSocketHandle> transport_; - HostPortPair host_port_pair_; SSLSocketDataProvider* data_; - bool is_npn_state_set_; - bool new_npn_value_; - bool is_protocol_negotiated_set_; - NextProto protocol_negotiated_; - - CompletionCallback connect_callback_; - // Indicates what state of Connect the socket should enter. - ConnectState next_connect_state_; - // True if the Connect method has been called on the socket. - bool reached_connect_; - - base::Closure handshake_completion_callback_; - - base::WeakPtrFactory<MockSSLClientSocket> weak_factory_; DISALLOW_COPY_AND_ASSIGN(MockSSLClientSocket); }; @@ -1049,9 +1101,10 @@ class MockUDPClientSocket : public DatagramClientSocket, public AsyncSocket { // AsyncSocket implementation. void OnReadComplete(const MockRead& data) override; + void OnWriteComplete(int rv) override; void OnConnectComplete(const MockConnect& data) override; - void set_source_port(int port) { source_port_ = port;} + void set_source_port(uint16 port) { source_port_ = port;} private: int CompleteRead(); @@ -1064,15 +1117,16 @@ class MockUDPClientSocket : public DatagramClientSocket, public AsyncSocket { int read_offset_; MockRead read_data_; bool need_read_data_; - int source_port_; // Ephemeral source port. + uint16 source_port_; // Ephemeral source port. // Address of the "remote" peer we're connected to. IPEndPoint peer_addr_; // While an asynchronous IO is pending, we save our user-buffer state. - scoped_refptr<IOBuffer> pending_buf_; - int pending_buf_len_; - CompletionCallback pending_callback_; + scoped_refptr<IOBuffer> pending_read_buf_; + int pending_read_buf_len_; + CompletionCallback pending_read_callback_; + CompletionCallback pending_write_callback_; BoundNetLog net_log_; @@ -1089,7 +1143,7 @@ class TestSocketRequest : public TestCompletionCallbackBase { ClientSocketHandle* handle() { return &handle_; } - const net::CompletionCallback& callback() const { return callback_; } + const CompletionCallback& callback() const { return callback_; } private: void OnComplete(int result); @@ -1202,7 +1256,6 @@ class MockTransportClientSocketPool : public TransportClientSocketPool { MockTransportClientSocketPool(int max_sockets, int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, ClientSocketFactory* socket_factory); ~MockTransportClientSocketPool() override; @@ -1293,7 +1346,6 @@ class MockSOCKSClientSocketPool : public SOCKSClientSocketPool { public: MockSOCKSClientSocketPool(int max_sockets, int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, TransportClientSocketPool* transport_pool); ~MockSOCKSClientSocketPool() override; @@ -1318,6 +1370,18 @@ class MockSOCKSClientSocketPool : public SOCKSClientSocketPool { DISALLOW_COPY_AND_ASSIGN(MockSOCKSClientSocketPool); }; +// Convenience class to temporarily set the WebSocketEndpointLockManager unlock +// delay to zero for testing purposes. Automatically restores the original value +// when destroyed. +class ScopedWebSocketEndpointZeroUnlockDelay { + public: + ScopedWebSocketEndpointZeroUnlockDelay(); + ~ScopedWebSocketEndpointZeroUnlockDelay(); + + private: + base::TimeDelta old_delay_; +}; + // Constants for a successful SOCKS v5 handshake. extern const char kSOCKS5GreetRequest[]; extern const int kSOCKS5GreetRequestLength; diff --git a/chromium/net/socket/socks5_client_socket.cc b/chromium/net/socket/socks5_client_socket.cc index 681f73f26e9..4ac9ca59656 100644 --- a/chromium/net/socket/socks5_client_socket.cc +++ b/chromium/net/socket/socks5_client_socket.cc @@ -7,13 +7,13 @@ #include "base/basictypes.h" #include "base/callback_helpers.h" #include "base/compiler_specific.h" -#include "base/debug/trace_event.h" #include "base/format_macros.h" #include "base/strings/string_util.h" #include "base/sys_byteorder.h" +#include "base/trace_event/trace_event.h" #include "net/base/io_buffer.h" -#include "net/base/net_log.h" #include "net/base/net_util.h" +#include "net/log/net_log.h" #include "net/socket/client_socket_handle.h" namespace net { @@ -25,8 +25,8 @@ const uint8 SOCKS5ClientSocket::kSOCKS5Version = 0x05; const uint8 SOCKS5ClientSocket::kTunnelCommand = 0x01; const uint8 SOCKS5ClientSocket::kNullByte = 0x00; -COMPILE_ASSERT(sizeof(struct in_addr) == 4, incorrect_system_size_of_IPv4); -COMPILE_ASSERT(sizeof(struct in6_addr) == 16, incorrect_system_size_of_IPv6); +static_assert(sizeof(struct in_addr) == 4, "incorrect system size of IPv4"); +static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6"); SOCKS5ClientSocket::SOCKS5ClientSocket( scoped_ptr<ClientSocketHandle> transport_socket, @@ -144,7 +144,10 @@ bool SOCKS5ClientSocket::GetSSLInfo(SSLInfo* ssl_info) { } NOTREACHED(); return false; +} +void SOCKS5ClientSocket::GetConnectionAttempts(ConnectionAttempts* out) const { + out->clear(); } // Read is called by the transport layer above to read. This can only be done diff --git a/chromium/net/socket/socks5_client_socket.h b/chromium/net/socket/socks5_client_socket.h index a405212b56b..d54e790b0fd 100644 --- a/chromium/net/socket/socks5_client_socket.h +++ b/chromium/net/socket/socks5_client_socket.h @@ -14,8 +14,8 @@ #include "net/base/address_list.h" #include "net/base/completion_callback.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" #include "net/dns/host_resolver.h" +#include "net/log/net_log.h" #include "net/socket/stream_socket.h" #include "url/gurl.h" @@ -55,6 +55,9 @@ class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { bool WasNpnNegotiated() const override; NextProto GetNegotiatedProtocol() const override; bool GetSSLInfo(SSLInfo* ssl_info) override; + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} // Socket implementation. int Read(IOBuffer* buf, diff --git a/chromium/net/socket/socks5_client_socket_unittest.cc b/chromium/net/socket/socks5_client_socket_unittest.cc index c474a0b4198..76146c7ab83 100644 --- a/chromium/net/socket/socks5_client_socket_unittest.cc +++ b/chromium/net/socket/socks5_client_socket_unittest.cc @@ -10,11 +10,13 @@ #include "base/sys_byteorder.h" #include "net/base/address_list.h" -#include "net/base/net_log.h" -#include "net/base/net_log_unittest.h" #include "net/base/test_completion_callback.h" #include "net/base/winsock_init.h" #include "net/dns/mock_host_resolver.h" +#include "net/log/net_log.h" +#include "net/log/test_net_log.h" +#include "net/log/test_net_log_entry.h" +#include "net/log/test_net_log_util.h" #include "net/socket/client_socket_factory.h" #include "net/socket/socket_test_util.h" #include "net/socket/tcp_client_socket.h" @@ -44,7 +46,7 @@ class SOCKS5ClientSocketTest : public PlatformTest { protected: const uint16 kNwPort; - CapturingNetLog net_log_; + TestNetLog net_log_; scoped_ptr<SOCKS5ClientSocket> user_sock_; AddressList address_list_; // Filled in by BuildMockSocket() and owned by its return value @@ -146,7 +148,7 @@ TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) { EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(user_sock_->IsConnected()); - CapturingNetLog::CapturedEntryList net_log_entries; + TestNetLogEntry::List net_log_entries; net_log_.GetEntries(&net_log_entries); EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, NetLog::TYPE_SOCKS5_CONNECT)); @@ -258,7 +260,7 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { const char partial1[] = { 0x05, 0x01 }; const char partial2[] = { 0x00 }; MockWrite data_writes[] = { - MockWrite(ASYNC, arraysize(partial1)), + MockWrite(ASYNC, partial1, arraysize(partial1)), MockWrite(ASYNC, partial2, arraysize(partial2)), MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) }; MockRead data_reads[] = { @@ -270,7 +272,7 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); - CapturingNetLog::CapturedEntryList net_log_entries; + TestNetLogEntry::List net_log_entries; net_log_.GetEntries(&net_log_entries); EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, NetLog::TYPE_SOCKS5_CONNECT)); @@ -301,7 +303,7 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); - CapturingNetLog::CapturedEntryList net_log_entries; + TestNetLogEntry::List net_log_entries; net_log_.GetEntries(&net_log_entries); EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, NetLog::TYPE_SOCKS5_CONNECT)); @@ -330,7 +332,7 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { hostname, 80, &net_log_); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); - CapturingNetLog::CapturedEntryList net_log_entries; + TestNetLogEntry::List net_log_entries; net_log_.GetEntries(&net_log_entries); EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, NetLog::TYPE_SOCKS5_CONNECT)); @@ -361,7 +363,7 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { hostname, 80, &net_log_); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); - CapturingNetLog::CapturedEntryList net_log_entries; + TestNetLogEntry::List net_log_entries; net_log_.GetEntries(&net_log_entries); EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, NetLog::TYPE_SOCKS5_CONNECT)); diff --git a/chromium/net/socket/socks_client_socket.cc b/chromium/net/socket/socks_client_socket.cc index f7c69f28fd6..dbdc0251768 100644 --- a/chromium/net/socket/socks_client_socket.cc +++ b/chromium/net/socket/socks_client_socket.cc @@ -10,8 +10,8 @@ #include "base/compiler_specific.h" #include "base/sys_byteorder.h" #include "net/base/io_buffer.h" -#include "net/base/net_log.h" #include "net/base/net_util.h" +#include "net/log/net_log.h" #include "net/socket/client_socket_handle.h" namespace net { @@ -43,8 +43,8 @@ struct SOCKS4ServerRequest { uint16 nw_port; uint8 ip[4]; }; -COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize, - socks4_server_request_struct_wrong_size); +static_assert(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize, + "socks4 server request struct has incorrect size"); // A struct holding details of the SOCKS4 Server Response. struct SOCKS4ServerResponse { @@ -53,8 +53,8 @@ struct SOCKS4ServerResponse { uint16 port; uint8 ip[4]; }; -COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize, - socks4_server_response_struct_wrong_size); +static_assert(sizeof(SOCKS4ServerResponse) == kReadHeaderSize, + "socks4 server response struct has incorrect size"); SOCKSClientSocket::SOCKSClientSocket( scoped_ptr<ClientSocketHandle> transport_socket, @@ -172,7 +172,10 @@ bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) { } NOTREACHED(); return false; +} +void SOCKSClientSocket::GetConnectionAttempts(ConnectionAttempts* out) const { + out->clear(); } // Read is called by the transport layer above to read. This can only be done diff --git a/chromium/net/socket/socks_client_socket.h b/chromium/net/socket/socks_client_socket.h index e792881cc7f..ee2918a6094 100644 --- a/chromium/net/socket/socks_client_socket.h +++ b/chromium/net/socket/socks_client_socket.h @@ -14,9 +14,9 @@ #include "net/base/address_list.h" #include "net/base/completion_callback.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" #include "net/dns/host_resolver.h" #include "net/dns/single_request_host_resolver.h" +#include "net/log/net_log.h" #include "net/socket/stream_socket.h" namespace net { @@ -52,6 +52,9 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { bool WasNpnNegotiated() const override; NextProto GetNegotiatedProtocol() const override; bool GetSSLInfo(SSLInfo* ssl_info) override; + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} // Socket implementation. int Read(IOBuffer* buf, diff --git a/chromium/net/socket/socks_client_socket_pool.cc b/chromium/net/socket/socks_client_socket_pool.cc index e11b7a48db5..b57a04ba1fe 100644 --- a/chromium/net/socket/socks_client_socket_pool.cc +++ b/chromium/net/socket/socks_client_socket_pool.cc @@ -192,17 +192,17 @@ SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const { SOCKSClientSocketPool::SOCKSClientSocketPool( int max_sockets, int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, HostResolver* host_resolver, TransportClientSocketPool* transport_pool, NetLog* net_log) : transport_pool_(transport_pool), - base_(this, max_sockets, max_sockets_per_group, histograms, - ClientSocketPool::unused_idle_socket_timeout(), - ClientSocketPool::used_idle_socket_timeout(), - new SOCKSConnectJobFactory(transport_pool, - host_resolver, - net_log)) { + base_( + this, + max_sockets, + max_sockets_per_group, + ClientSocketPool::unused_idle_socket_timeout(), + ClientSocketPool::used_idle_socket_timeout(), + new SOCKSConnectJobFactory(transport_pool, host_resolver, net_log)) { // We should always have a |transport_pool_| except in unit tests. if (transport_pool_) base_.AddLowerLayeredPool(transport_pool_); @@ -272,11 +272,11 @@ base::DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue( bool include_nested_pools) const { base::DictionaryValue* dict = base_.GetInfoAsValue(name, type); if (include_nested_pools) { - base::ListValue* list = new base::ListValue(); + scoped_ptr<base::ListValue> list(new base::ListValue()); list->Append(transport_pool_->GetInfoAsValue("transport_socket_pool", "transport_socket_pool", false)); - dict->Set("nested_pools", list); + dict->Set("nested_pools", list.Pass()); } return dict; } @@ -285,10 +285,6 @@ base::TimeDelta SOCKSClientSocketPool::ConnectionTimeout() const { return base_.ConnectionTimeout(); } -ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const { - return base_.histograms(); -}; - bool SOCKSClientSocketPool::IsStalled() const { return base_.IsStalled(); } diff --git a/chromium/net/socket/socks_client_socket_pool.h b/chromium/net/socket/socks_client_socket_pool.h index 35f7146f967..69bcf00f076 100644 --- a/chromium/net/socket/socks_client_socket_pool.h +++ b/chromium/net/socket/socks_client_socket_pool.h @@ -16,7 +16,6 @@ #include "net/dns/host_resolver.h" #include "net/socket/client_socket_pool.h" #include "net/socket/client_socket_pool_base.h" -#include "net/socket/client_socket_pool_histograms.h" namespace net { @@ -112,7 +111,6 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool SOCKSClientSocketPool( int max_sockets, int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, HostResolver* host_resolver, TransportClientSocketPool* transport_pool, NetLog* net_log); @@ -157,8 +155,6 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool base::TimeDelta ConnectionTimeout() const override; - ClientSocketPoolHistograms* histograms() const override; - // LowerLayeredPool implementation. bool IsStalled() const override; diff --git a/chromium/net/socket/socks_client_socket_pool_unittest.cc b/chromium/net/socket/socks_client_socket_pool_unittest.cc index 391d31beddb..e841c83635d 100644 --- a/chromium/net/socket/socks_client_socket_pool_unittest.cc +++ b/chromium/net/socket/socks_client_socket_pool_unittest.cc @@ -14,7 +14,6 @@ #include "net/dns/mock_host_resolver.h" #include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_handle.h" -#include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket_test_util.h" #include "testing/gtest/include/gtest/gtest.h" @@ -90,18 +89,14 @@ class SOCKSClientSocketPoolTest : public testing::Test { }; SOCKSClientSocketPoolTest() - : transport_histograms_("MockTCP"), - transport_socket_pool_( - kMaxSockets, kMaxSocketsPerGroup, - &transport_histograms_, - &transport_client_socket_factory_), - socks_histograms_("SOCKSUnitTest"), - pool_(kMaxSockets, kMaxSocketsPerGroup, - &socks_histograms_, + : transport_socket_pool_(kMaxSockets, + kMaxSocketsPerGroup, + &transport_client_socket_factory_), + pool_(kMaxSockets, + kMaxSocketsPerGroup, &host_resolver_, &transport_socket_pool_, - NULL) { - } + NULL) {} ~SOCKSClientSocketPoolTest() override {} @@ -116,11 +111,9 @@ class SOCKSClientSocketPoolTest : public testing::Test { ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); } - ClientSocketPoolHistograms transport_histograms_; MockClientSocketFactory transport_client_socket_factory_; MockTransportClientSocketPool transport_socket_pool_; - ClientSocketPoolHistograms socks_histograms_; MockHostResolver host_resolver_; SOCKSClientSocketPool pool_; ClientSocketPoolTest test_base_; diff --git a/chromium/net/socket/socks_client_socket_unittest.cc b/chromium/net/socket/socks_client_socket_unittest.cc index fbb84f8f50a..27b4c70aa95 100644 --- a/chromium/net/socket/socks_client_socket_unittest.cc +++ b/chromium/net/socket/socks_client_socket_unittest.cc @@ -6,12 +6,14 @@ #include "base/memory/scoped_ptr.h" #include "net/base/address_list.h" -#include "net/base/net_log.h" -#include "net/base/net_log_unittest.h" #include "net/base/test_completion_callback.h" #include "net/base/winsock_init.h" #include "net/dns/host_resolver.h" #include "net/dns/mock_host_resolver.h" +#include "net/log/net_log.h" +#include "net/log/test_net_log.h" +#include "net/log/test_net_log_entry.h" +#include "net/log/test_net_log_util.h" #include "net/socket/client_socket_factory.h" #include "net/socket/socket_test_util.h" #include "net/socket/tcp_client_socket.h" @@ -144,7 +146,7 @@ TEST_F(SOCKSClientSocketTest, CompleteHandshake) { MockRead data_reads[] = { MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)), MockRead(ASYNC, payload_read.data(), payload_read.size()) }; - CapturingNetLog log; + TestNetLog log; user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), data_writes, arraysize(data_writes), @@ -159,7 +161,7 @@ TEST_F(SOCKSClientSocketTest, CompleteHandshake) { int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_TRUE( LogContainsBeginEvent(entries, 0, NetLog::TYPE_SOCKS_CONNECT)); @@ -220,7 +222,7 @@ TEST_F(SOCKSClientSocketTest, HandshakeFailures) { MockRead data_reads[] = { MockRead(SYNCHRONOUS, tests[i].fail_reply, arraysize(tests[i].fail_reply)) }; - CapturingNetLog log; + TestNetLog log; user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), data_writes, arraysize(data_writes), @@ -231,7 +233,7 @@ TEST_F(SOCKSClientSocketTest, HandshakeFailures) { int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_TRUE(LogContainsBeginEvent( entries, 0, NetLog::TYPE_SOCKS_CONNECT)); @@ -257,7 +259,7 @@ TEST_F(SOCKSClientSocketTest, PartialServerReads) { MockRead data_reads[] = { MockRead(ASYNC, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)), MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) }; - CapturingNetLog log; + TestNetLog log; user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), data_writes, arraysize(data_writes), @@ -267,7 +269,7 @@ TEST_F(SOCKSClientSocketTest, PartialServerReads) { int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_TRUE(LogContainsBeginEvent( entries, 0, NetLog::TYPE_SOCKS_CONNECT)); @@ -287,15 +289,15 @@ TEST_F(SOCKSClientSocketTest, PartialClientWrites) { const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 }; MockWrite data_writes[] = { - MockWrite(ASYNC, arraysize(kSOCKSPartialRequest1)), + MockWrite(ASYNC, kSOCKSPartialRequest1, arraysize(kSOCKSPartialRequest1)), // simulate some empty writes MockWrite(ASYNC, 0), MockWrite(ASYNC, 0), - MockWrite(ASYNC, kSOCKSPartialRequest2, - arraysize(kSOCKSPartialRequest2)) }; + MockWrite(ASYNC, kSOCKSPartialRequest2, arraysize(kSOCKSPartialRequest2)), + }; MockRead data_reads[] = { MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; - CapturingNetLog log; + TestNetLog log; user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), data_writes, arraysize(data_writes), @@ -305,7 +307,7 @@ TEST_F(SOCKSClientSocketTest, PartialClientWrites) { int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_TRUE(LogContainsBeginEvent( entries, 0, NetLog::TYPE_SOCKS_CONNECT)); @@ -327,7 +329,7 @@ TEST_F(SOCKSClientSocketTest, FailedSocketRead) { MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2), // close connection unexpectedly MockRead(SYNCHRONOUS, 0) }; - CapturingNetLog log; + TestNetLog log; user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), data_writes, arraysize(data_writes), @@ -337,7 +339,7 @@ TEST_F(SOCKSClientSocketTest, FailedSocketRead) { int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_TRUE(LogContainsBeginEvent( entries, 0, NetLog::TYPE_SOCKS_CONNECT)); @@ -357,7 +359,7 @@ TEST_F(SOCKSClientSocketTest, FailedDNS) { host_resolver_->rules()->AddSimulatedFailure(hostname); - CapturingNetLog log; + TestNetLog log; user_sock_ = BuildMockSocket(NULL, 0, NULL, 0, @@ -367,7 +369,7 @@ TEST_F(SOCKSClientSocketTest, FailedDNS) { int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_TRUE(LogContainsBeginEvent( entries, 0, NetLog::TYPE_SOCKS_CONNECT)); diff --git a/chromium/net/socket/ssl_client_socket.cc b/chromium/net/socket/ssl_client_socket.cc index 3184e04e3f7..4dd6a4e926d 100644 --- a/chromium/net/socket/ssl_client_socket.cc +++ b/chromium/net/socket/ssl_client_socket.cc @@ -9,19 +9,16 @@ #include "base/strings/string_util.h" #include "crypto/ec_private_key.h" #include "net/base/connection_type_histograms.h" -#include "net/base/host_port_pair.h" +#include "net/base/net_errors.h" #include "net/ssl/channel_id_service.h" +#include "net/ssl/ssl_cipher_suite_names.h" #include "net/ssl/ssl_config_service.h" #include "net/ssl/ssl_connection_status_flags.h" namespace net { SSLClientSocket::SSLClientSocket() - : was_npn_negotiated_(false), - was_spdy_negotiated_(false), - protocol_negotiated_(kProtoUnknown), - channel_id_sent_(false), - signed_cert_timestamps_received_(false), + : signed_cert_timestamps_received_(false), stapled_ocsp_response_received_(false), negotiation_extension_(kExtensionUnknown) { } @@ -38,8 +35,10 @@ NextProto SSLClientSocket::NextProtoFromString( } else if (proto_string == "spdy/3.1") { return kProtoSPDY31; } else if (proto_string == "h2-14") { - // This is the HTTP/2 draft 14 identifier. For internal - // consistency, HTTP/2 is named SPDY4 within Chromium. + // For internal consistency, HTTP/2 is named SPDY4 within Chromium. + // This is the HTTP/2 draft-14 identifier. + return kProtoSPDY4_14; + } else if (proto_string == "h2") { return kProtoSPDY4; } else if (proto_string == "quic/1+spdy/3") { return kProtoQUIC1SPDY3; @@ -59,10 +58,12 @@ const char* SSLClientSocket::NextProtoToString(NextProto next_proto) { return "spdy/3"; case kProtoSPDY31: return "spdy/3.1"; - case kProtoSPDY4: - // This is the HTTP/2 draft 14 identifier. For internal - // consistency, HTTP/2 is named SPDY4 within Chromium. + case kProtoSPDY4_14: + // For internal consistency, HTTP/2 is named SPDY4 within Chromium. + // This is the HTTP/2 draft-14 identifier. return "h2-14"; + case kProtoSPDY4: + return "h2"; case kProtoQUIC1SPDY3: return "quic/1+spdy/3"; case kProtoUnknown: @@ -86,69 +87,48 @@ const char* SSLClientSocket::NextProtoStatusToString( } bool SSLClientSocket::WasNpnNegotiated() const { - return was_npn_negotiated_; + std::string unused_proto; + return GetNextProto(&unused_proto) == kNextProtoNegotiated; } NextProto SSLClientSocket::GetNegotiatedProtocol() const { - return protocol_negotiated_; + std::string proto; + if (GetNextProto(&proto) != kNextProtoNegotiated) + return kProtoUnknown; + return NextProtoFromString(proto); } bool SSLClientSocket::IgnoreCertError(int error, int load_flags) { - if (error == OK || load_flags & LOAD_IGNORE_ALL_CERT_ERRORS) + if (error == OK) return true; - - if (error == ERR_CERT_COMMON_NAME_INVALID && - (load_flags & LOAD_IGNORE_CERT_COMMON_NAME_INVALID)) - return true; - - if (error == ERR_CERT_DATE_INVALID && - (load_flags & LOAD_IGNORE_CERT_DATE_INVALID)) - return true; - - if (error == ERR_CERT_AUTHORITY_INVALID && - (load_flags & LOAD_IGNORE_CERT_AUTHORITY_INVALID)) - return true; - - return false; -} - -bool SSLClientSocket::set_was_npn_negotiated(bool negotiated) { - return was_npn_negotiated_ = negotiated; -} - -bool SSLClientSocket::was_spdy_negotiated() const { - return was_spdy_negotiated_; -} - -bool SSLClientSocket::set_was_spdy_negotiated(bool negotiated) { - return was_spdy_negotiated_ = negotiated; -} - -void SSLClientSocket::set_protocol_negotiated(NextProto protocol_negotiated) { - protocol_negotiated_ = protocol_negotiated; -} - -void SSLClientSocket::set_negotiation_extension( - SSLNegotiationExtension negotiation_extension) { - negotiation_extension_ = negotiation_extension; -} - -bool SSLClientSocket::WasChannelIDSent() const { - return channel_id_sent_; -} - -void SSLClientSocket::set_channel_id_sent(bool channel_id_sent) { - channel_id_sent_ = channel_id_sent; + return (load_flags & LOAD_IGNORE_ALL_CERT_ERRORS) && + IsCertificateError(error); } -void SSLClientSocket::set_signed_cert_timestamps_received( - bool signed_cert_timestamps_received) { - signed_cert_timestamps_received_ = signed_cert_timestamps_received; -} - -void SSLClientSocket::set_stapled_ocsp_response_received( - bool stapled_ocsp_response_received) { - stapled_ocsp_response_received_ = stapled_ocsp_response_received; +void SSLClientSocket::RecordNegotiationExtension() { + if (negotiation_extension_ == kExtensionUnknown) + return; + std::string proto; + SSLClientSocket::NextProtoStatus status = GetNextProto(&proto); + if (status == kNextProtoUnsupported) + return; + // Convert protocol into numerical value for histogram. + NextProto protocol_negotiated = SSLClientSocket::NextProtoFromString(proto); + base::HistogramBase::Sample sample = + static_cast<base::HistogramBase::Sample>(protocol_negotiated); + // In addition to the protocol negotiated, we want to record which TLS + // extension was used, and in case of NPN, whether there was overlap between + // server and client list of supported protocols. + if (negotiation_extension_ == kExtensionNPN) { + if (status == kNextProtoNoOverlap) { + sample += 1000; + } else { + sample += 500; + } + } else { + DCHECK_EQ(kExtensionALPN, negotiation_extension_); + } + UMA_HISTOGRAM_SPARSE_SLOWLY("Net.SSLProtocolNegotiation", sample); } // static @@ -184,28 +164,6 @@ void SSLClientSocket::RecordChannelIDSupport( } // static -void SSLClientSocket::RecordConnectionTypeMetrics(int ssl_version) { - UpdateConnectionTypeHistograms(CONNECTION_SSL); - switch (ssl_version) { - case SSL_CONNECTION_VERSION_SSL2: - UpdateConnectionTypeHistograms(CONNECTION_SSL_SSL2); - break; - case SSL_CONNECTION_VERSION_SSL3: - UpdateConnectionTypeHistograms(CONNECTION_SSL_SSL3); - break; - case SSL_CONNECTION_VERSION_TLS1: - UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1); - break; - case SSL_CONNECTION_VERSION_TLS1_1: - UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1_1); - break; - case SSL_CONNECTION_VERSION_TLS1_2: - UpdateConnectionTypeHistograms(CONNECTION_SSL_TLS1_2); - break; - } -} - -// static bool SSLClientSocket::IsChannelIDEnabled( const SSLConfig& ssl_config, ChannelIDService* channel_id_service) { @@ -228,65 +186,47 @@ bool SSLClientSocket::IsChannelIDEnabled( } // static +bool SSLClientSocket::HasCipherAdequateForHTTP2( + const std::vector<uint16>& cipher_suites) { + for (uint16 cipher : cipher_suites) { + if (IsSecureTLSCipherSuite(cipher)) + return true; + } + return false; +} + +// static +bool SSLClientSocket::IsTLSVersionAdequateForHTTP2( + const SSLConfig& ssl_config) { + return ssl_config.version_max >= SSL_PROTOCOL_VERSION_TLS1_2; +} + +// static std::vector<uint8_t> SSLClientSocket::SerializeNextProtos( - const std::vector<std::string>& next_protos) { - // Do a first pass to determine the total length. - size_t wire_length = 0; - for (std::vector<std::string>::const_iterator i = next_protos.begin(); - i != next_protos.end(); ++i) { - if (i->size() > 255) { - LOG(WARNING) << "Ignoring overlong NPN/ALPN protocol: " << *i; + const NextProtoVector& next_protos, + bool can_advertise_http2) { + std::vector<uint8_t> wire_protos; + for (const NextProto next_proto : next_protos) { + if (!can_advertise_http2 && kProtoSPDY4MinimumVersion <= next_proto && + next_proto <= kProtoSPDY4MaximumVersion) { continue; } - if (i->size() == 0) { - LOG(WARNING) << "Ignoring empty NPN/ALPN protocol"; + const std::string proto = NextProtoToString(next_proto); + if (proto.size() > 255) { + LOG(WARNING) << "Ignoring overlong NPN/ALPN protocol: " << proto; continue; } - wire_length += i->size(); - wire_length++; - } - - // Allocate memory for the result and fill it in. - std::vector<uint8_t> wire_protos; - wire_protos.reserve(wire_length); - for (std::vector<std::string>::const_iterator i = next_protos.begin(); - i != next_protos.end(); i++) { - if (i->size() == 0 || i->size() > 255) + if (proto.size() == 0) { + LOG(WARNING) << "Ignoring empty NPN/ALPN protocol"; continue; - wire_protos.push_back(i->size()); - wire_protos.resize(wire_protos.size() + i->size()); - memcpy(&wire_protos[wire_protos.size() - i->size()], - i->data(), i->size()); + } + wire_protos.push_back(proto.size()); + for (const char ch : proto) { + wire_protos.push_back(static_cast<uint8_t>(ch)); + } } - DCHECK_EQ(wire_protos.size(), wire_length); return wire_protos; } -void SSLClientSocket::RecordNegotiationExtension() { - if (negotiation_extension_ == kExtensionUnknown) - return; - std::string proto; - SSLClientSocket::NextProtoStatus status = GetNextProto(&proto); - if (status == kNextProtoUnsupported) - return; - // Convert protocol into numerical value for histogram. - NextProto protocol_negotiated = SSLClientSocket::NextProtoFromString(proto); - base::HistogramBase::Sample sample = - static_cast<base::HistogramBase::Sample>(protocol_negotiated); - // In addition to the protocol negotiated, we want to record which TLS - // extension was used, and in case of NPN, whether there was overlap between - // server and client list of supported protocols. - if (negotiation_extension_ == kExtensionNPN) { - if (status == kNextProtoNoOverlap) { - sample += 1000; - } else { - sample += 500; - } - } else { - DCHECK_EQ(kExtensionALPN, negotiation_extension_); - } - UMA_HISTOGRAM_SPARSE_SLOWLY("Net.SSLProtocolNegotiation", sample); -} - } // namespace net diff --git a/chromium/net/socket/ssl_client_socket.h b/chromium/net/socket/ssl_client_socket.h index 7adfa8c626a..6774f150eb3 100644 --- a/chromium/net/socket/ssl_client_socket.h +++ b/chromium/net/socket/ssl_client_socket.h @@ -13,14 +13,14 @@ #include "net/base/net_errors.h" #include "net/socket/ssl_socket.h" #include "net/socket/stream_socket.h" +#include "net/ssl/ssl_failure_state.h" namespace net { +class CertPolicyEnforcer; class CertVerifier; class ChannelIDService; class CTVerifier; -class HostPortPair; -class ServerBoundCertService; class SSLCertRequestInfo; struct SSLConfig; class SSLInfo; @@ -34,23 +34,27 @@ struct SSLClientSocketContext { : cert_verifier(NULL), channel_id_service(NULL), transport_security_state(NULL), - cert_transparency_verifier(NULL) {} + cert_transparency_verifier(NULL), + cert_policy_enforcer(NULL) {} SSLClientSocketContext(CertVerifier* cert_verifier_arg, ChannelIDService* channel_id_service_arg, TransportSecurityState* transport_security_state_arg, CTVerifier* cert_transparency_verifier_arg, + CertPolicyEnforcer* cert_policy_enforcer_arg, const std::string& ssl_session_cache_shard_arg) : cert_verifier(cert_verifier_arg), channel_id_service(channel_id_service_arg), transport_security_state(transport_security_state_arg), cert_transparency_verifier(cert_transparency_verifier_arg), + cert_policy_enforcer(cert_policy_enforcer_arg), ssl_session_cache_shard(ssl_session_cache_shard_arg) {} CertVerifier* cert_verifier; ChannelIDService* channel_id_service; TransportSecurityState* transport_security_state; CTVerifier* cert_transparency_verifier; + CertPolicyEnforcer* cert_policy_enforcer; // ssl_session_cache_shard is an opaque string that identifies a shard of the // SSL session cache. SSL sockets with the same ssl_session_cache_shard may // resume each other's SSL sessions but we'll never sessions between shards. @@ -90,39 +94,6 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { bool WasNpnNegotiated() const override; NextProto GetNegotiatedProtocol() const override; - // Computes a unique key string for the SSL session cache. - virtual std::string GetSessionCacheKey() const = 0; - - // Returns true if there is a cache entry in the SSL session cache - // for the cache key of the SSL socket. - // - // The cache key consists of a host and port concatenated with a session - // cache shard. These two strings are passed to the constructor of most - // subclasses of SSLClientSocket. - virtual bool InSessionCache() const = 0; - - // Sets |callback| to be run when the handshake has fully completed. - // For example, in the case of False Start, Connect() will return - // early, before the peer's TLS Finished message has been verified, - // in order to allow the caller to call Write() and send application - // data with the client's Finished message. - // In such situations, |callback| will be invoked sometime after - // Connect() - either during a Write() or Read() call, and before - // invoking the Read() or Write() callback. - // Otherwise, during a traditional TLS connection (i.e. no False - // Start), this will be called right before the Connect() callback - // is called. - // - // Note that it's not valid to mutate this socket during such - // callbacks, including deleting the socket. - // - // TODO(mshelley): Provide additional details about whether or not - // the handshake actually succeeded or not. This can be inferred - // from the result to Connect()/Read()/Write(), but may be useful - // to inform here as well. - virtual void SetHandshakeCompletionCallback( - const base::Closure& callback) = 0; - // Gets the SSL CertificateRequest info of the socket after Connect failed // with ERR_SSL_CLIENT_AUTH_CERT_NEEDED. virtual void GetSSLCertRequestInfo( @@ -135,7 +106,7 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { // kNextProtoNegotiated: *proto is set to the negotiated protocol. // kNextProtoNoOverlap: *proto is set to the first protocol in the // supported list. - virtual NextProtoStatus GetNextProto(std::string* proto) = 0; + virtual NextProtoStatus GetNextProto(std::string* proto) const = 0; static NextProto NextProtoFromString(const std::string& proto_string); @@ -143,46 +114,45 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { static const char* NextProtoStatusToString(const NextProtoStatus status); + // Returns true if |error| is OK or |load_flags| ignores certificate errors + // and |error| is a certificate error. static bool IgnoreCertError(int error, int load_flags); // ClearSessionCache clears the SSL session cache, used to resume SSL // sessions. static void ClearSessionCache(); - virtual bool set_was_npn_negotiated(bool negotiated); - - virtual bool was_spdy_negotiated() const; - - virtual bool set_was_spdy_negotiated(bool negotiated); - - virtual void set_protocol_negotiated(NextProto protocol_negotiated); - - void set_negotiation_extension(SSLNegotiationExtension negotiation_extension); + // Get the maximum SSL version supported by the underlying library and + // cryptographic implementation. + static uint16 GetMaxSupportedSSLVersion(); // Returns the ChannelIDService used by this socket, or NULL if // channel ids are not supported. virtual ChannelIDService* GetChannelIDService() const = 0; - // Returns true if a channel ID was sent on this connection. - // This may be useful for protocols, like SPDY, which allow the same - // connection to be shared between multiple domains, each of which need - // a channel ID. - // - // Public for ssl_client_socket_openssl_unittest.cc. - virtual bool WasChannelIDSent() const; - - // Record which TLS extension was used to negotiate protocol and protocol - // chosen in a UMA histogram. - void RecordNegotiationExtension(); + // Returns the state of the handshake when it failed, or |SSL_FAILURE_NONE| if + // the handshake succeeded. This is used to classify causes of the TLS version + // fallback. + virtual SSLFailureState GetSSLFailureState() const = 0; protected: - virtual void set_channel_id_sent(bool channel_id_sent); + void set_negotiation_extension( + SSLNegotiationExtension negotiation_extension) { + negotiation_extension_ = negotiation_extension; + } + + void set_signed_cert_timestamps_received( + bool signed_cert_timestamps_received) { + signed_cert_timestamps_received_ = signed_cert_timestamps_received; + } - virtual void set_signed_cert_timestamps_received( - bool signed_cert_timestamps_received); + void set_stapled_ocsp_response_received(bool stapled_ocsp_response_received) { + stapled_ocsp_response_received_ = stapled_ocsp_response_received; + } - virtual void set_stapled_ocsp_response_received( - bool stapled_ocsp_response_received); + // Record which TLS extension was used to negotiate protocol and protocol + // chosen in a UMA histogram. + void RecordNegotiationExtension(); // Records histograms for channel id support during full handshakes - resumed // handshakes are ignored. @@ -192,18 +162,28 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { bool channel_id_enabled, bool supports_ecc); - // Records ConnectionType histograms for a successful SSL connection. - static void RecordConnectionTypeMetrics(int ssl_version); - // Returns whether TLS channel ID is enabled. static bool IsChannelIDEnabled( const SSLConfig& ssl_config, ChannelIDService* channel_id_service); + // Determine if there is at least one enabled cipher suite that satisfies + // Section 9.2 of the HTTP/2 specification. Note that the server might still + // pick an inadequate cipher suite. + static bool HasCipherAdequateForHTTP2( + const std::vector<uint16>& cipher_suites); + + // Determine if the TLS version required by Section 9.2 of the HTTP/2 + // specification is enabled. Note that the server might still pick an + // inadequate TLS version. + static bool IsTLSVersionAdequateForHTTP2(const SSLConfig& ssl_config); + // Serializes |next_protos| in the wire format for ALPN: protocols are listed - // in order, each prefixed by a one-byte length. + // in order, each prefixed by a one-byte length. Any HTTP/2 protocols in + // |next_protos| are ignored if |can_advertise_http2| is false. static std::vector<uint8_t> SerializeNextProtos( - const std::vector<std::string>& next_protos); + const NextProtoVector& next_protos, + bool can_advertise_http2); // For unit testing only. // Returns the unverified certificate chain as presented by server. @@ -213,6 +193,7 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { const = 0; private: + FRIEND_TEST_ALL_PREFIXES(SSLClientSocket, SerializeNextProtos); // For signed_cert_timestamps_received_ and stapled_ocsp_response_received_. FRIEND_TEST_ALL_PREFIXES(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledTLSExtension); @@ -223,14 +204,6 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { FRIEND_TEST_ALL_PREFIXES(SSLClientSocketTest, VerifyServerChainProperlyOrdered); - // True if NPN was responded to, independent of selecting SPDY or HTTP. - bool was_npn_negotiated_; - // True if NPN successfully negotiated SPDY. - bool was_spdy_negotiated_; - // Protocol that we negotiated with the server. - NextProto protocol_negotiated_; - // True if a channel ID was sent. - bool channel_id_sent_; // True if SCTs were received via a TLS extension. bool signed_cert_timestamps_received_; // True if a stapled OCSP response was received. diff --git a/chromium/net/socket/ssl_client_socket_nss.cc b/chromium/net/socket/ssl_client_socket_nss.cc index 08cf2c55f51..7c5264db81b 100644 --- a/chromium/net/socket/ssl_client_socket_nss.cc +++ b/chromium/net/socket/ssl_client_socket_nss.cc @@ -69,9 +69,7 @@ #include "base/callback_helpers.h" #include "base/compiler_specific.h" #include "base/logging.h" -#include "base/memory/singleton.h" #include "base/metrics/histogram.h" -#include "base/profiler/scoped_tracker.h" #include "base/single_thread_task_runner.h" #include "base/stl_util.h" #include "base/strings/string_number_conversions.h" @@ -89,8 +87,8 @@ #include "net/base/dns_util.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" #include "net/cert/asn1_util.h" +#include "net/cert/cert_policy_enforcer.h" #include "net/cert/cert_status_flags.h" #include "net/cert/cert_verifier.h" #include "net/cert/ct_ev_whitelist.h" @@ -98,31 +96,20 @@ #include "net/cert/ct_verify_result.h" #include "net/cert/scoped_nss_types.h" #include "net/cert/sct_status_flags.h" -#include "net/cert/single_request_cert_verifier.h" #include "net/cert/x509_certificate_net_log_param.h" #include "net/cert/x509_util.h" +#include "net/cert_net/nss_ocsp.h" #include "net/http/transport_security_state.h" -#include "net/ocsp/nss_ocsp.h" +#include "net/log/net_log.h" #include "net/socket/client_socket_handle.h" #include "net/socket/nss_ssl_util.h" #include "net/ssl/ssl_cert_request_info.h" +#include "net/ssl/ssl_cipher_suite_names.h" #include "net/ssl/ssl_connection_status_flags.h" +#include "net/ssl/ssl_failure_state.h" #include "net/ssl/ssl_info.h" -#if defined(OS_WIN) -#include <windows.h> -#include <wincrypt.h> - -#include "base/win/windows_version.h" -#elif defined(OS_MACOSX) -#include <Security/SecBase.h> -#include <Security/SecCertificate.h> -#include <Security/SecIdentity.h> - -#include "base/mac/mac_logging.h" -#include "base/synchronization/lock.h" -#include "crypto/mac_security_services_lock.h" -#elif defined(USE_NSS) +#if defined(USE_NSS_CERTS) #include <dlfcn.h> #endif @@ -148,6 +135,14 @@ namespace net { } while (0) #endif +#if !defined(CKM_AES_GCM) +#define CKM_AES_GCM 0x00001087 +#endif + +#if !defined(CKM_NSS_CHACHA20_POLY1305) +#define CKM_NSS_CHACHA20_POLY1305 (CKM_NSS + 26) +#endif + namespace { // SSL plaintext fragments are shorter than 16KB. Although the record layer @@ -163,113 +158,6 @@ const int kSendBufferSize = 17 * 1024; // overlap with any value of the net::Error range, including net::OK). const int kNoPendingReadResult = 1; -#if defined(OS_WIN) -// CERT_OCSP_RESPONSE_PROP_ID is only implemented on Vista+, but it can be -// set on Windows XP without error. There is some overhead from the server -// sending the OCSP response if it supports the extension, for the subset of -// XP clients who will request it but be unable to use it, but this is an -// acceptable trade-off for simplicity of implementation. -bool IsOCSPStaplingSupported() { - return true; -} -#elif defined(USE_NSS) -typedef SECStatus -(*CacheOCSPResponseFromSideChannelFunction)( - CERTCertDBHandle *handle, CERTCertificate *cert, PRTime time, - SECItem *encodedResponse, void *pwArg); - -// On Linux, we dynamically link against the system version of libnss3.so. In -// order to continue working on systems without up-to-date versions of NSS we -// lookup CERT_CacheOCSPResponseFromSideChannel with dlsym. - -// RuntimeLibNSSFunctionPointers is a singleton which caches the results of any -// runtime symbol resolution that we need. -class RuntimeLibNSSFunctionPointers { - public: - CacheOCSPResponseFromSideChannelFunction - GetCacheOCSPResponseFromSideChannelFunction() { - return cache_ocsp_response_from_side_channel_; - } - - static RuntimeLibNSSFunctionPointers* GetInstance() { - return Singleton<RuntimeLibNSSFunctionPointers>::get(); - } - - private: - friend struct DefaultSingletonTraits<RuntimeLibNSSFunctionPointers>; - - RuntimeLibNSSFunctionPointers() { - cache_ocsp_response_from_side_channel_ = - (CacheOCSPResponseFromSideChannelFunction) - dlsym(RTLD_DEFAULT, "CERT_CacheOCSPResponseFromSideChannel"); - } - - CacheOCSPResponseFromSideChannelFunction - cache_ocsp_response_from_side_channel_; -}; - -CacheOCSPResponseFromSideChannelFunction -GetCacheOCSPResponseFromSideChannelFunction() { - return RuntimeLibNSSFunctionPointers::GetInstance() - ->GetCacheOCSPResponseFromSideChannelFunction(); -} - -bool IsOCSPStaplingSupported() { - return GetCacheOCSPResponseFromSideChannelFunction() != NULL; -} -#else -// TODO(agl): Figure out if we can plumb the OCSP response into Mac's system -// certificate validation functions. -bool IsOCSPStaplingSupported() { - return false; -} -#endif - -#if defined(OS_WIN) - -// This callback is intended to be used with CertFindChainInStore. In addition -// to filtering by extended/enhanced key usage, we do not show expired -// certificates and require digital signature usage in the key usage -// extension. -// -// This matches our behavior on Mac OS X and that of NSS. It also matches the -// default behavior of IE8. See http://support.microsoft.com/kb/890326 and -// http://blogs.msdn.com/b/askie/archive/2009/06/09/my-expired-client-certificates-no-longer-display-when-connecting-to-my-web-server-using-ie8.aspx -BOOL WINAPI ClientCertFindCallback(PCCERT_CONTEXT cert_context, - void* find_arg) { - VLOG(1) << "Calling ClientCertFindCallback from _nss"; - // Verify the certificate's KU is good. - BYTE key_usage; - if (CertGetIntendedKeyUsage(X509_ASN_ENCODING, cert_context->pCertInfo, - &key_usage, 1)) { - if (!(key_usage & CERT_DIGITAL_SIGNATURE_KEY_USAGE)) - return FALSE; - } else { - DWORD err = GetLastError(); - // If |err| is non-zero, it's an actual error. Otherwise the extension - // just isn't present, and we treat it as if everything was allowed. - if (err) { - DLOG(ERROR) << "CertGetIntendedKeyUsage failed: " << err; - return FALSE; - } - } - - // Verify the current time is within the certificate's validity period. - if (CertVerifyTimeValidity(NULL, cert_context->pCertInfo) != 0) - return FALSE; - - // Verify private key metadata is associated with this certificate. - DWORD size = 0; - if (!CertGetCertificateContextProperty( - cert_context, CERT_KEY_PROV_INFO_PROP_ID, NULL, &size)) { - return FALSE; - } - - return TRUE; -} - -#endif - // Helper functions to make it possible to log events from within the // SSLClientSocketNSS::Core. void AddLogEvent(const base::WeakPtr<BoundNetLog>& net_log, @@ -670,28 +558,11 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { // authentication. // See the documentation in third_party/nss/ssl/ssl.h for the meanings of // the arguments. -#if defined(NSS_PLATFORM_CLIENT_AUTH) - // When NSS has been integrated with awareness of the underlying system - // cryptographic libraries, this callback allows the caller to supply a - // native platform certificate and key for use by NSS. At most, one of - // either (result_certs, result_private_key) or (result_nss_certificate, - // result_nss_private_key) should be set. - // |arg| contains a pointer to the current SSLClientSocketNSS::Core. - static SECStatus PlatformClientAuthHandler( - void* arg, - PRFileDesc* socket, - CERTDistNames* ca_names, - CERTCertList** result_certs, - void** result_private_key, - CERTCertificate** result_nss_certificate, - SECKEYPrivateKey** result_nss_private_key); -#else static SECStatus ClientAuthHandler(void* arg, PRFileDesc* socket, CERTDistNames* ca_names, CERTCertificate** result_certificate, SECKEYPrivateKey** result_private_key); -#endif // Called by NSS to determine if we can False Start. // |arg| contains a pointer to the current SSLClientSocketNSS::Core. @@ -699,12 +570,14 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { void* arg, PRBool* can_false_start); - // Called by NSS once the handshake has completed. + // Called by NSS each time a handshake completely finishes. // |arg| contains a pointer to the current SSLClientSocketNSS::Core. static void HandshakeCallback(PRFileDesc* socket, void* arg); - // Called once the handshake has succeeded. - void HandshakeSucceeded(); + // Called once for each successful handshake. If the initial handshake false + // starts, it is called when it false starts and not when it completely + // finishes. is_initial is true if this is the initial handshake. + void HandshakeSucceeded(bool is_initial); // Handles an NSS error generated while handshaking or performing IO. // Returns a network error code mapped from the original NSS error. @@ -766,6 +639,9 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { // Record TLS extension used for protocol negotiation (NPN or ALPN). void UpdateExtensionUsed(); + // Returns true if renegotiations are allowed. + bool IsRenegotiationAllowed() const; + //////////////////////////////////////////////////////////////////////////// // Methods that are ONLY called on the network task runner: //////////////////////////////////////////////////////////////////////////// @@ -870,7 +746,8 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { bool channel_id_needed_; // True if the handshake state machine was interrupted for client auth. bool client_auth_cert_needed_; - // True if NSS has False Started. + // True if NSS has False Started in the initial handshake, but the initial + // handshake has not yet completely finished.. bool false_started_; // True if NSS has called HandshakeCallback. bool handshake_callback_called_; @@ -974,8 +851,16 @@ bool SSLClientSocketNSS::Core::Init(PRFileDesc* socket, SECStatus rv = SECSuccess; if (!ssl_config_.next_protos.empty()) { + // TODO(bnc): Check ssl_config_.disabled_cipher_suites. + const bool adequate_encryption = + PK11_TokenExists(CKM_AES_GCM) || + PK11_TokenExists(CKM_NSS_CHACHA20_POLY1305); + const bool adequate_key_agreement = PK11_TokenExists(CKM_DH_PKCS_DERIVE) || + PK11_TokenExists(CKM_ECDH1_DERIVE); std::vector<uint8_t> wire_protos = - SerializeNextProtos(ssl_config_.next_protos); + SerializeNextProtos(ssl_config_.next_protos, + adequate_encryption && adequate_key_agreement && + IsTLSVersionAdequateForHTTP2(ssl_config_)); rv = SSL_SetNextProtoNego( nss_fd_, wire_protos.empty() ? NULL : &wire_protos[0], wire_protos.size()); @@ -996,14 +881,8 @@ bool SSLClientSocketNSS::Core::Init(PRFileDesc* socket, return false; } -#if defined(NSS_PLATFORM_CLIENT_AUTH) - rv = SSL_GetPlatformClientAuthDataHook( - nss_fd_, SSLClientSocketNSS::Core::PlatformClientAuthHandler, - this); -#else rv = SSL_GetClientAuthDataHook( nss_fd_, SSLClientSocketNSS::Core::ClientAuthHandler, this); -#endif if (rv != SECSuccess) { LogFailedNSSFunction(*weak_net_log_, "SSL_GetClientAuthDataHook", ""); return false; @@ -1271,222 +1150,9 @@ SECStatus SSLClientSocketNSS::Core::OwnAuthCertHandler( return SECSuccess; } -#if defined(NSS_PLATFORM_CLIENT_AUTH) -// static -SECStatus SSLClientSocketNSS::Core::PlatformClientAuthHandler( - void* arg, - PRFileDesc* socket, - CERTDistNames* ca_names, - CERTCertList** result_certs, - void** result_private_key, - CERTCertificate** result_nss_certificate, - SECKEYPrivateKey** result_nss_private_key) { - Core* core = reinterpret_cast<Core*>(arg); - DCHECK(core->OnNSSTaskRunner()); - - core->PostOrRunCallback( - FROM_HERE, - base::Bind(&AddLogEvent, core->weak_net_log_, - NetLog::TYPE_SSL_CLIENT_CERT_REQUESTED)); - - core->client_auth_cert_needed_ = !core->ssl_config_.send_client_cert; -#if defined(OS_WIN) - if (core->ssl_config_.send_client_cert) { - if (core->ssl_config_.client_cert) { - PCCERT_CONTEXT cert_context = - core->ssl_config_.client_cert->os_cert_handle(); - - HCRYPTPROV_OR_NCRYPT_KEY_HANDLE crypt_prov = 0; - DWORD key_spec = 0; - BOOL must_free = FALSE; - DWORD flags = 0; - if (base::win::GetVersion() >= base::win::VERSION_VISTA) - flags |= CRYPT_ACQUIRE_PREFER_NCRYPT_KEY_FLAG; - - BOOL acquired_key = CryptAcquireCertificatePrivateKey( - cert_context, flags, NULL, &crypt_prov, &key_spec, &must_free); - - if (acquired_key) { - // Should never get a cached handle back - ownership must always be - // transferred. - CHECK_EQ(must_free, TRUE); - - SECItem der_cert; - der_cert.type = siDERCertBuffer; - der_cert.data = cert_context->pbCertEncoded; - der_cert.len = cert_context->cbCertEncoded; - - // TODO(rsleevi): Error checking for NSS allocation errors. - CERTCertDBHandle* db_handle = CERT_GetDefaultCertDB(); - CERTCertificate* user_cert = CERT_NewTempCertificate( - db_handle, &der_cert, NULL, PR_FALSE, PR_TRUE); - if (!user_cert) { - // Importing the certificate can fail for reasons including a serial - // number collision. See crbug.com/97355. - core->AddCertProvidedEvent(0); - return SECFailure; - } - CERTCertList* cert_chain = CERT_NewCertList(); - CERT_AddCertToListTail(cert_chain, user_cert); - - // Add the intermediates. - X509Certificate::OSCertHandles intermediates = - core->ssl_config_.client_cert->GetIntermediateCertificates(); - for (X509Certificate::OSCertHandles::const_iterator it = - intermediates.begin(); it != intermediates.end(); ++it) { - der_cert.data = (*it)->pbCertEncoded; - der_cert.len = (*it)->cbCertEncoded; - - CERTCertificate* intermediate = CERT_NewTempCertificate( - db_handle, &der_cert, NULL, PR_FALSE, PR_TRUE); - if (!intermediate) { - CERT_DestroyCertList(cert_chain); - core->AddCertProvidedEvent(0); - return SECFailure; - } - CERT_AddCertToListTail(cert_chain, intermediate); - } - PCERT_KEY_CONTEXT key_context = reinterpret_cast<PCERT_KEY_CONTEXT>( - PORT_ZAlloc(sizeof(CERT_KEY_CONTEXT))); - key_context->cbSize = sizeof(*key_context); - // NSS will free this context when no longer in use. - key_context->hCryptProv = crypt_prov; - key_context->dwKeySpec = key_spec; - *result_private_key = key_context; - *result_certs = cert_chain; - - int cert_count = 1 + intermediates.size(); - core->AddCertProvidedEvent(cert_count); - return SECSuccess; - } - LOG(WARNING) << "Client cert found without private key"; - } - - // Send no client certificate. - core->AddCertProvidedEvent(0); - return SECFailure; - } - - core->nss_handshake_state_.cert_authorities.clear(); - - std::vector<CERT_NAME_BLOB> issuer_list(ca_names->nnames); - for (int i = 0; i < ca_names->nnames; ++i) { - issuer_list[i].cbData = ca_names->names[i].len; - issuer_list[i].pbData = ca_names->names[i].data; - core->nss_handshake_state_.cert_authorities.push_back(std::string( - reinterpret_cast<const char*>(ca_names->names[i].data), - static_cast<size_t>(ca_names->names[i].len))); - } - - // Update the network task runner's view of the handshake state now that - // server certificate request has been recorded. - core->PostOrRunCallback( - FROM_HERE, base::Bind(&Core::OnHandshakeStateUpdated, core, - core->nss_handshake_state_)); - - // Tell NSS to suspend the client authentication. We will then abort the - // handshake by returning ERR_SSL_CLIENT_AUTH_CERT_NEEDED. - return SECWouldBlock; -#elif defined(OS_MACOSX) - if (core->ssl_config_.send_client_cert) { - if (core->ssl_config_.client_cert.get()) { - OSStatus os_error = noErr; - SecIdentityRef identity = NULL; - SecKeyRef private_key = NULL; - X509Certificate::OSCertHandles chain; - { - base::AutoLock lock(crypto::GetMacSecurityServicesLock()); - os_error = SecIdentityCreateWithCertificate( - NULL, core->ssl_config_.client_cert->os_cert_handle(), &identity); - } - if (os_error == noErr) { - os_error = SecIdentityCopyPrivateKey(identity, &private_key); - CFRelease(identity); - } - - if (os_error == noErr) { - // TODO(rsleevi): Error checking for NSS allocation errors. - *result_certs = CERT_NewCertList(); - *result_private_key = private_key; - - chain.push_back(core->ssl_config_.client_cert->os_cert_handle()); - const X509Certificate::OSCertHandles& intermediates = - core->ssl_config_.client_cert->GetIntermediateCertificates(); - if (!intermediates.empty()) - chain.insert(chain.end(), intermediates.begin(), intermediates.end()); - - for (size_t i = 0, chain_count = chain.size(); i < chain_count; ++i) { - CSSM_DATA cert_data; - SecCertificateRef cert_ref = chain[i]; - os_error = SecCertificateGetData(cert_ref, &cert_data); - if (os_error != noErr) - break; - - SECItem der_cert; - der_cert.type = siDERCertBuffer; - der_cert.data = cert_data.Data; - der_cert.len = cert_data.Length; - CERTCertificate* nss_cert = CERT_NewTempCertificate( - CERT_GetDefaultCertDB(), &der_cert, NULL, PR_FALSE, PR_TRUE); - if (!nss_cert) { - // In the event of an NSS error, make up an OS error and reuse - // the error handling below. - os_error = errSecCreateChainFailed; - break; - } - CERT_AddCertToListTail(*result_certs, nss_cert); - } - } - - if (os_error == noErr) { - core->AddCertProvidedEvent(chain.size()); - return SECSuccess; - } - - OSSTATUS_LOG(WARNING, os_error) - << "Client cert found, but could not be used"; - if (*result_certs) { - CERT_DestroyCertList(*result_certs); - *result_certs = NULL; - } - if (*result_private_key) - *result_private_key = NULL; - if (private_key) - CFRelease(private_key); - } - - // Send no client certificate. - core->AddCertProvidedEvent(0); - return SECFailure; - } - - core->nss_handshake_state_.cert_authorities.clear(); - - // Retrieve the cert issuers accepted by the server. - std::vector<CertPrincipal> valid_issuers; - int n = ca_names->nnames; - for (int i = 0; i < n; i++) { - core->nss_handshake_state_.cert_authorities.push_back(std::string( - reinterpret_cast<const char*>(ca_names->names[i].data), - static_cast<size_t>(ca_names->names[i].len))); - } - - // Update the network task runner's view of the handshake state now that - // server certificate request has been recorded. - core->PostOrRunCallback( - FROM_HERE, base::Bind(&Core::OnHandshakeStateUpdated, core, - core->nss_handshake_state_)); - - // Tell NSS to suspend the client authentication. We will then abort the - // handshake by returning ERR_SSL_CLIENT_AUTH_CERT_NEEDED. - return SECWouldBlock; -#else - return SECFailure; -#endif -} - -#elif defined(OS_IOS) +#if defined(OS_IOS) +// static SECStatus SSLClientSocketNSS::Core::ClientAuthHandler( void* arg, PRFileDesc* socket, @@ -1509,7 +1175,7 @@ SECStatus SSLClientSocketNSS::Core::ClientAuthHandler( return SECFailure; } -#else // NSS_PLATFORM_CLIENT_AUTH +#else // !OS_IOS // static // Based on Mozilla's NSS_GetClientAuthData. @@ -1577,7 +1243,7 @@ SECStatus SSLClientSocketNSS::Core::ClientAuthHandler( // handshake by returning ERR_SSL_CLIENT_AUTH_CERT_NEEDED. return SECWouldBlock; } -#endif // NSS_PLATFORM_CLIENT_AUTH +#endif // OS_IOS // static SECStatus SSLClientSocketNSS::Core::CanFalseStartCallback( @@ -1600,6 +1266,16 @@ SECStatus SSLClientSocketNSS::Core::CanFalseStartCallback( return SECSuccess; } + SSLChannelInfo channel_info; + SECStatus ok = + SSL_GetChannelInfo(socket, &channel_info, sizeof(channel_info)); + if (ok != SECSuccess || channel_info.length != sizeof(channel_info) || + channel_info.protocolVersion < SSL_LIBRARY_VERSION_TLS_1_2 || + !IsFalseStartableTLSCipherSuite(channel_info.cipherSuite)) { + *can_false_start = PR_FALSE; + return SECSuccess; + } + return SSL_RecommendedCanFalseStart(socket, can_false_start); } @@ -1610,6 +1286,7 @@ void SSLClientSocketNSS::Core::HandshakeCallback( Core* core = reinterpret_cast<Core*>(arg); DCHECK(core->OnNSSTaskRunner()); + bool is_initial = !core->handshake_callback_called_; core->handshake_callback_called_ = true; if (core->false_started_) { core->false_started_ = false; @@ -1626,15 +1303,10 @@ void SSLClientSocketNSS::Core::HandshakeCallback( // called HandshakeSucceeded(), so return now. return; } - core->HandshakeSucceeded(); + core->HandshakeSucceeded(is_initial); } -void SSLClientSocketNSS::Core::HandshakeSucceeded() { - // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION( - "424386 SSLClientSocketNSS::Core::HandshakeSucceeded")); - +void SSLClientSocketNSS::Core::HandshakeSucceeded(bool is_initial) { DCHECK(OnNSSTaskRunner()); PRBool last_handshake_resumed; @@ -1653,6 +1325,22 @@ void SSLClientSocketNSS::Core::HandshakeSucceeded() { UpdateNextProto(); UpdateExtensionUsed(); + if (is_initial && IsRenegotiationAllowed()) { + // For compatibility, do not enforce RFC 5746 support. Per section 4.1, + // enforcement falls largely on the server. + // + // This is done in a callback rather than after SSL_ForceHandshake returns + // because SSL_ForceHandshake will otherwise greedly consume renegotiations + // before returning if Finished and HelloRequest are in the same + // record. + // + // Note that SSL_OptionSet should only be called for an initial + // handshake. See https://crbug.com/125299. + SECStatus rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_RENEGOTIATION, + SSL_RENEGOTIATE_TRANSITIONAL); + DCHECK_EQ(SECSuccess, rv); + } + // Update the network task runners view of the handshake state whenever // a handshake has completed. PostOrRunCallback( @@ -1661,48 +1349,12 @@ void SSLClientSocketNSS::Core::HandshakeSucceeded() { } int SSLClientSocketNSS::Core::HandleNSSError(PRErrorCode nss_error) { - // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION( - "424386 SSLClientSocketNSS::Core::HandleNSSError")); - DCHECK(OnNSSTaskRunner()); - int net_error = MapNSSClientError(nss_error); - -#if defined(OS_WIN) - // On Windows, a handle to the HCRYPTPROV is cached in the X509Certificate - // os_cert_handle() as an optimization. However, if the certificate - // private key is stored on a smart card, and the smart card is removed, - // the cached HCRYPTPROV will not be able to obtain the HCRYPTKEY again, - // preventing client certificate authentication. Because the - // X509Certificate may outlive the individual SSLClientSocketNSS, due to - // caching in X509Certificate, this failure ends up preventing client - // certificate authentication with the same certificate for all future - // attempts, even after the smart card has been re-inserted. By setting - // the CERT_KEY_PROV_HANDLE_PROP_ID to NULL, the cached HCRYPTPROV will - // typically be freed. This allows a new HCRYPTPROV to be obtained from - // the certificate on the next attempt, which should succeed if the smart - // card has been re-inserted, or will typically prompt the user to - // re-insert the smart card if not. - if ((net_error == ERR_SSL_CLIENT_AUTH_CERT_NO_PRIVATE_KEY || - net_error == ERR_SSL_CLIENT_AUTH_SIGNATURE_FAILED) && - ssl_config_.send_client_cert && ssl_config_.client_cert) { - CertSetCertificateContextProperty( - ssl_config_.client_cert->os_cert_handle(), - CERT_KEY_PROV_HANDLE_PROP_ID, 0, NULL); - } -#endif - - return net_error; + return MapNSSClientError(nss_error); } int SSLClientSocketNSS::Core::DoHandshakeLoop(int last_io_result) { - // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION( - "424386 SSLClientSocketNSS::Core::DoHandshakeLoop")); - DCHECK(OnNSSTaskRunner()); int rv = last_io_result; @@ -1739,11 +1391,6 @@ int SSLClientSocketNSS::Core::DoHandshakeLoop(int last_io_result) { } int SSLClientSocketNSS::Core::DoReadLoop(int result) { - // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION( - "424386 SSLClientSocketNSS::Core::DoReadLoop")); - DCHECK(OnNSSTaskRunner()); DCHECK(false_started_ || handshake_callback_called_); DCHECK_EQ(STATE_NONE, next_handshake_state_); @@ -1803,21 +1450,11 @@ int SSLClientSocketNSS::Core::DoWriteLoop(int result) { } int SSLClientSocketNSS::Core::DoHandshake() { - // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION( - "424386 SSLClientSocketNSS::Core::DoHandshake")); - DCHECK(OnNSSTaskRunner()); int net_error = OK; SECStatus rv = SSL_ForceHandshake(nss_fd_); - // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. - tracked_objects::ScopedTracker tracking_profile1( - FROM_HERE_WITH_EXPLICIT_FUNCTION( - "424386 SSLClientSocketNSS::Core::DoHandshake 1")); - // Note: this function may be called multiple times during the handshake, so // even though channel id and client auth are separate else cases, they can // both be used during a single SSL handshake. @@ -1831,18 +1468,10 @@ int SSLClientSocketNSS::Core::DoHandshake() { base::Bind(&AddLogEventWithCallback, weak_net_log_, NetLog::TYPE_SSL_HANDSHAKE_ERROR, CreateNetLogSSLErrorCallback(net_error, 0))); - - // If the handshake already succeeded (because the server requests but - // doesn't require a client cert), we need to invalidate the SSL session - // so that we won't try to resume the non-client-authenticated session in - // the next handshake. This will cause the server to ask for a client - // cert again. - if (rv == SECSuccess && SSL_InvalidateSession(nss_fd_) != SECSuccess) - LOG(WARNING) << "Couldn't invalidate SSL session: " << PR_GetError(); } else if (rv == SECSuccess) { if (!handshake_callback_called_) { false_started_ = true; - HandshakeSucceeded(); + HandshakeSucceeded(true); } } else { PRErrorCode prerr = PR_GetError(); @@ -1864,11 +1493,6 @@ int SSLClientSocketNSS::Core::DoHandshake() { } int SSLClientSocketNSS::Core::DoGetDBCertComplete(int result) { - // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION( - "424386 SSLClientSocketNSS::Core::DoGetDBCertComplete")); - SECStatus rv; PostOrRunCallback( FROM_HERE, @@ -2053,11 +1677,6 @@ int SSLClientSocketNSS::Core::DoPayloadWrite() { // transport socket. Return true if some I/O performed, false // otherwise (error or ERR_IO_PENDING). bool SSLClientSocketNSS::Core::DoTransportIO() { - // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION( - "424386 SSLClientSocketNSS::Core::DoTransportIO")); - DCHECK(OnNSSTaskRunner()); bool network_moved = false; @@ -2234,11 +1853,6 @@ void SSLClientSocketNSS::Core::OnSendComplete(int result) { // callback. For Read() and Write(), that's what we want. But for Connect(), // the caller expects OK (i.e. 0) for success. void SSLClientSocketNSS::Core::DoConnectCallback(int rv) { - // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION( - "424386 SSLClientSocketNSS::Core::DoConnectCallback")); - DCHECK(OnNSSTaskRunner()); DCHECK_NE(rv, ERR_IO_PENDING); DCHECK(!user_connect_callback_.is_null()); @@ -2250,11 +1864,6 @@ void SSLClientSocketNSS::Core::DoConnectCallback(int rv) { } void SSLClientSocketNSS::Core::DoReadCallback(int rv) { - // TODO(vadimt): Remove ScopedTracker below once crbug.com/424386 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION( - "424386 SSLClientSocketNSS::Core::DoReadCallback")); - DCHECK(OnNSSTaskRunner()); DCHECK_NE(ERR_IO_PENDING, rv); DCHECK(!user_read_callback_.is_null()); @@ -2270,10 +1879,6 @@ void SSLClientSocketNSS::Core::DoReadCallback(int rv) { PostOrRunCallback( FROM_HERE, base::Bind(&Core::DidNSSRead, this, rv)); - // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. - tracked_objects::ScopedTracker tracking_profile1( - FROM_HERE_WITH_EXPLICIT_FUNCTION( - "SSLClientSocketNSS::Core::DoReadCallback")); PostOrRunCallback( FROM_HERE, base::Bind(base::ResetAndReturn(&user_read_callback_), rv)); @@ -2435,35 +2040,6 @@ void SSLClientSocketNSS::Core::UpdateStapledOCSPResponse() { nss_handshake_state_.stapled_ocsp_response = std::string( reinterpret_cast<char*>(ocsp_responses->items[0].data), ocsp_responses->items[0].len); - - // TODO(agl): figure out how to plumb an OCSP response into the Mac - // system library and update IsOCSPStaplingSupported for Mac. - if (IsOCSPStaplingSupported()) { - #if defined(OS_WIN) - if (nss_handshake_state_.server_cert) { - CRYPT_DATA_BLOB ocsp_response_blob; - ocsp_response_blob.cbData = ocsp_responses->items[0].len; - ocsp_response_blob.pbData = ocsp_responses->items[0].data; - BOOL ok = CertSetCertificateContextProperty( - nss_handshake_state_.server_cert->os_cert_handle(), - CERT_OCSP_RESPONSE_PROP_ID, - CERT_SET_PROPERTY_IGNORE_PERSIST_ERROR_FLAG, - &ocsp_response_blob); - if (!ok) { - VLOG(1) << "Failed to set OCSP response property: " - << GetLastError(); - } - } - #elif defined(USE_NSS) - CacheOCSPResponseFromSideChannelFunction cache_ocsp_response = - GetCacheOCSPResponseFromSideChannelFunction(); - - cache_ocsp_response( - CERT_GetDefaultCertDB(), - nss_handshake_state_.server_cert_chain[0], PR_Now(), - &ocsp_responses->items[0], NULL); - #endif - } // IsOCSPStaplingSupported() } void SSLClientSocketNSS::Core::UpdateConnectionStatus() { @@ -2477,10 +2053,7 @@ void SSLClientSocketNSS::Core::UpdateConnectionStatus() { if (ok == SECSuccess && channel_info.length == sizeof(channel_info) && channel_info.cipherSuite) { - nss_handshake_state_.ssl_connection_status |= - (static_cast<int>(channel_info.cipherSuite) & - SSL_CONNECTION_CIPHERSUITE_MASK) << - SSL_CONNECTION_CIPHERSUITE_SHIFT; + nss_handshake_state_.ssl_connection_status |= channel_info.cipherSuite; nss_handshake_state_.ssl_connection_status |= (static_cast<int>(channel_info.compressionMethod) & @@ -2488,19 +2061,14 @@ void SSLClientSocketNSS::Core::UpdateConnectionStatus() { SSL_CONNECTION_COMPRESSION_SHIFT; int version = SSL_CONNECTION_VERSION_UNKNOWN; - if (channel_info.protocolVersion < SSL_LIBRARY_VERSION_3_0) { - // All versions less than SSL_LIBRARY_VERSION_3_0 are treated as SSL - // version 2. - version = SSL_CONNECTION_VERSION_SSL2; - } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_3_0) { - version = SSL_CONNECTION_VERSION_SSL3; - } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_TLS_1_0) { + if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_TLS_1_0) { version = SSL_CONNECTION_VERSION_TLS1; } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_TLS_1_1) { version = SSL_CONNECTION_VERSION_TLS1_1; } else if (channel_info.protocolVersion == SSL_LIBRARY_VERSION_TLS_1_2) { version = SSL_CONNECTION_VERSION_TLS1_2; } + DCHECK_NE(SSL_CONNECTION_VERSION_UNKNOWN, version); nss_handshake_state_.ssl_connection_status |= (version & SSL_CONNECTION_VERSION_MASK) << SSL_CONNECTION_VERSION_SHIFT; @@ -2571,6 +2139,20 @@ void SSLClientSocketNSS::Core::UpdateExtensionUsed() { } } +bool SSLClientSocketNSS::Core::IsRenegotiationAllowed() const { + DCHECK(OnNSSTaskRunner()); + + if (nss_handshake_state_.next_proto_status == kNextProtoUnsupported) + return ssl_config_.renego_allowed_default; + + NextProto next_proto = NextProtoFromString(nss_handshake_state_.next_proto); + for (NextProto allowed : ssl_config_.renego_allowed_for_protos) { + if (next_proto == allowed) + return true; + } + return false; +} + void SSLClientSocketNSS::Core::RecordChannelIDSupportOnNSSTaskRunner() { DCHECK(OnNSSTaskRunner()); if (nss_handshake_state_.resumed_handshake) @@ -2703,11 +2285,6 @@ void SSLClientSocketNSS::Core::DidNSSWrite(int result) { } void SSLClientSocketNSS::Core::BufferSendComplete(int result) { - // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION( - "418183 DidCompleteReadWrite => Core::BufferSendComplete")); - if (!OnNSSTaskRunner()) { if (detached_) return; @@ -2751,11 +2328,6 @@ void SSLClientSocketNSS::Core::OnGetChannelIDComplete(int result) { void SSLClientSocketNSS::Core::BufferRecvComplete( IOBuffer* read_buffer, int result) { - // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION( - "418183 DidCompleteReadWrite => SSLClientSocketNSS::Core::...")); - DCHECK(read_buffer); if (!OnNSSTaskRunner()) { @@ -2838,7 +2410,10 @@ SSLClientSocketNSS::SSLClientSocketNSS( nss_fd_(NULL), net_log_(transport_->socket()->NetLog()), transport_security_state_(context.transport_security_state), + policy_enforcer_(context.cert_policy_enforcer), valid_thread_id_(base::kInvalidThreadId) { + DCHECK(cert_verifier_); + EnterFunction(""); InitCore(); LeaveFunction(""); @@ -2860,6 +2435,20 @@ void SSLClientSocket::ClearSessionCache() { SSL_ClearSessionCache(); } +#if !defined(CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256) +#define CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256 (CKM_NSS + 24) +#endif + +// static +uint16 SSLClientSocket::GetMaxSupportedSSLVersion() { + crypto::EnsureNSSInit(); + if (PK11_TokenExists(CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256)) { + return SSL_PROTOCOL_VERSION_TLS1_2; + } else { + return SSL_PROTOCOL_VERSION_TLS1_1; + } +} + bool SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { EnterFunction(""); ssl_info->Reset(); @@ -2880,7 +2469,7 @@ bool SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { server_cert_verify_result_.is_issued_by_known_root; ssl_info->client_cert_sent = ssl_config_.send_client_cert && ssl_config_.client_cert.get(); - ssl_info->channel_id_sent = WasChannelIDSent(); + ssl_info->channel_id_sent = core_->state().channel_id_sent; ssl_info->pinning_failure_log = pinning_failure_log_; PRUint16 cipher_suite = SSLConnectionStatusToCipherSuite( @@ -2903,19 +2492,8 @@ bool SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { return true; } -std::string SSLClientSocketNSS::GetSessionCacheKey() const { - NOTIMPLEMENTED(); - return std::string(); -} - -bool SSLClientSocketNSS::InSessionCache() const { - // For now, always return true so that SSLConnectJobs are never held back. - return true; -} - -void SSLClientSocketNSS::SetHandshakeCompletionCallback( - const base::Closure& callback) { - NOTIMPLEMENTED(); +void SSLClientSocketNSS::GetConnectionAttempts(ConnectionAttempts* out) const { + out->clear(); } void SSLClientSocketNSS::GetSSLCertRequestInfo( @@ -2963,8 +2541,8 @@ int SSLClientSocketNSS::GetTLSUniqueChannelBinding(std::string* out) { return OK; } -SSLClientSocket::NextProtoStatus -SSLClientSocketNSS::GetNextProto(std::string* proto) { +SSLClientSocket::NextProtoStatus SSLClientSocketNSS::GetNextProto( + std::string* proto) const { *proto = core_->state().next_proto; return core_->state().next_proto_status; } @@ -3021,7 +2599,7 @@ void SSLClientSocketNSS::Disconnect() { // Shut down anything that may call us back. core_->Detach(); - verifier_.reset(); + cert_verifier_request_.reset(); transport_->socket()->Disconnect(); // Reset object state. @@ -3135,7 +2713,7 @@ int SSLClientSocketNSS::Init() { EnsureNSSSSLInit(); if (!NSS_IsInitialized()) return ERR_UNEXPECTED; -#if defined(USE_NSS) || defined(OS_IOS) +#if defined(USE_NSS_CERTS) || defined(OS_IOS) if (ssl_config_.cert_io_enabled) { // We must call EnsureNSSHttpIOInit() here, on the IO thread, to get the IO // loop by MessageLoopForIO::current(). @@ -3224,6 +2802,20 @@ int SSLClientSocketNSS::InitializeSSLOptions() { SSL_CipherPrefSet(nss_fd_, *it, PR_FALSE); } + if (!ssl_config_.enable_deprecated_cipher_suites) { + const PRUint16* const ssl_ciphers = SSL_GetImplementedCiphers(); + const PRUint16 num_ciphers = SSL_GetNumImplementedCiphers(); + for (int i = 0; i < num_ciphers; i++) { + SSLCipherSuiteInfo info; + if (SSL_GetCipherSuiteInfo(ssl_ciphers[i], &info, sizeof(info)) != + SECSuccess) { + continue; + } + if (info.symCipher == ssl_calg_rc4) + SSL_CipherPrefSet(nss_fd_, ssl_ciphers[i], PR_FALSE); + } + } + // Support RFC 5077 rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_SESSION_TICKETS, PR_TRUE); if (rv != SECSuccess) { @@ -3236,13 +2828,9 @@ int SSLClientSocketNSS::InitializeSSLOptions() { if (rv != SECSuccess) LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_FALSE_START"); - // We allow servers to request renegotiation. Since we're a client, - // prohibiting this is rather a waste of time. Only servers are in a - // position to prevent renegotiation attacks. - // http://extendedsubset.com/?p=8 - - rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_RENEGOTIATION, - SSL_RENEGOTIATE_TRANSITIONAL); + // By default, renegotiations are rejected. After the initial handshake + // completes, some application protocols may re-enable it. + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_RENEGOTIATION, SSL_RENEGOTIATE_NEVER); if (rv != SECSuccess) { LogFailedNSSFunction( net_log_, "SSL_OptionSet", "SSL_ENABLE_RENEGOTIATION"); @@ -3257,8 +2845,8 @@ int SSLClientSocketNSS::InitializeSSLOptions() { // Request OCSP stapling even on platforms that don't support it, in // order to extract Certificate Transparency information. rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_OCSP_STAPLING, - (IsOCSPStaplingSupported() || - ssl_config_.signed_cert_timestamps_enabled)); + cert_verifier_->SupportsOCSPStapling() || + ssl_config_.signed_cert_timestamps_enabled); if (rv != SECSuccess) { LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_OCSP_STAPLING"); @@ -3320,13 +2908,29 @@ int SSLClientSocketNSS::InitializeSSLPeerName() { // SSL tunnel through a proxy -- GetPeerName returns the proxy's address // rather than the destination server's address in that case. std::string peer_id = host_and_port_.ToString(); - // If the ssl_session_cache_shard_ is non-empty, we append it to the peer id. - // This will cause session cache misses between sockets with different values - // of ssl_session_cache_shard_ and this is used to partition the session cache - // for incognito mode. - if (!ssl_session_cache_shard_.empty()) { - peer_id += "/" + ssl_session_cache_shard_; + // Append |ssl_session_cache_shard_| to the peer id. This is used to partition + // the session cache for incognito mode. + peer_id += "/" + ssl_session_cache_shard_; + peer_id += "/"; + // Shard the session cache based on maximum protocol version. This causes + // fallback connections to use a separate session cache. + switch (ssl_config_.version_max) { + case SSL_PROTOCOL_VERSION_TLS1: + peer_id += "tls1"; + break; + case SSL_PROTOCOL_VERSION_TLS1_1: + peer_id += "tls1.1"; + break; + case SSL_PROTOCOL_VERSION_TLS1_2: + peer_id += "tls1.2"; + break; + default: + NOTREACHED(); } + peer_id += "/"; + if (ssl_config_.enable_deprecated_cipher_suites) + peer_id += "deprecated"; + SECStatus rv = SSL_SetSockPeerID(nss_fd_, const_cast<char*>(peer_id.c_str())); if (rv != SECSuccess) LogFailedNSSFunction(net_log_, "SSL_SetSockPeerID", peer_id.c_str()); @@ -3409,11 +3013,12 @@ int SSLClientSocketNSS::DoHandshakeComplete(int result) { return ERR_SSL_FALLBACK_BEYOND_MINIMUM_VERSION; } + RecordNegotiationExtension(); + // SSL handshake is completed. Let's verify the certificate. GotoState(STATE_VERIFY_CERT); // Done! } - set_channel_id_sent(core_->state().channel_id_sent); set_signed_cert_timestamps_received( !core_->state().sct_list_from_tls_extension.empty()); set_stapled_ocsp_response_received( @@ -3465,22 +3070,19 @@ int SSLClientSocketNSS::DoVerifyCert(int result) { flags |= CertVerifier::VERIFY_CERT_IO_ENABLED; if (ssl_config_.rev_checking_required_local_anchors) flags |= CertVerifier::VERIFY_REV_CHECKING_REQUIRED_LOCAL_ANCHORS; - verifier_.reset(new SingleRequestCertVerifier(cert_verifier_)); - return verifier_->Verify( - core_->state().server_cert.get(), - host_and_port_.host(), - flags, - SSLConfigService::GetCRLSet().get(), - &server_cert_verify_result_, + return cert_verifier_->Verify( + core_->state().server_cert.get(), host_and_port_.host(), + core_->state().stapled_ocsp_response, flags, + SSLConfigService::GetCRLSet().get(), &server_cert_verify_result_, base::Bind(&SSLClientSocketNSS::OnHandshakeIOComplete, base::Unretained(this)), - net_log_); + &cert_verifier_request_, net_log_); } // Derived from AuthCertificateCallback() in // mozilla/source/security/manager/ssl/src/nsNSSCallbacks.cpp. int SSLClientSocketNSS::DoVerifyCertComplete(int result) { - verifier_.reset(); + cert_verifier_request_.reset(); if (!start_cert_verification_time_.is_null()) { base::TimeDelta verify_time = @@ -3501,15 +3103,6 @@ int SSLClientSocketNSS::DoVerifyCertComplete(int result) { // purposes. See https://bugzilla.mozilla.org/show_bug.cgi?id=508081 and // http://crbug.com/15630 for more info. - // TODO(hclam): Skip logging if server cert was expected to be bad because - // |server_cert_verify_result_| doesn't contain all the information about - // the cert. - if (result == OK) { - int ssl_version = - SSLConnectionStatusToVersion(core_->state().ssl_connection_status); - RecordConnectionTypeMetrics(ssl_version); - } - const CertStatus cert_status = server_cert_verify_result_.cert_status; if (transport_security_state_ && (result == OK || @@ -3522,21 +3115,6 @@ int SSLClientSocketNSS::DoVerifyCertComplete(int result) { result = ERR_SSL_PINNED_KEY_NOT_IN_CERT_CHAIN; } - scoped_refptr<ct::EVCertsWhitelist> ev_whitelist = - SSLConfigService::GetEVCertsWhitelist(); - if (server_cert_verify_result_.cert_status & CERT_STATUS_IS_EV) { - if (ev_whitelist.get() && ev_whitelist->IsValid()) { - const SHA256HashValue fingerprint( - X509Certificate::CalculateFingerprint256( - server_cert_verify_result_.verified_cert->os_cert_handle())); - - UMA_HISTOGRAM_BOOLEAN( - "Net.SSL_EVCertificateInWhitelist", - ev_whitelist->ContainsCertificateHash( - std::string(reinterpret_cast<const char*>(fingerprint.data), 8))); - } - } - if (result == OK) { // Only check Certificate Transparency if there were no other errors with // the connection. @@ -3560,20 +3138,31 @@ void SSLClientSocketNSS::VerifyCT() { // Note that this is a completely synchronous operation: The CT Log Verifier // gets all the data it needs for SCT verification and does not do any // external communication. - int result = cert_transparency_verifier_->Verify( + cert_transparency_verifier_->Verify( server_cert_verify_result_.verified_cert.get(), core_->state().stapled_ocsp_response, - core_->state().sct_list_from_tls_extension, - &ct_verify_result_, - net_log_); + core_->state().sct_list_from_tls_extension, &ct_verify_result_, net_log_); // TODO(ekasper): wipe stapled_ocsp_response and sct_list_from_tls_extension // from the state after verification is complete, to conserve memory. - VLOG(1) << "CT Verification complete: result " << result - << " Invalid scts: " << ct_verify_result_.invalid_scts.size() - << " Verified scts: " << ct_verify_result_.verified_scts.size() - << " scts from unknown logs: " - << ct_verify_result_.unknown_logs_scts.size(); + if (!policy_enforcer_) { + server_cert_verify_result_.cert_status &= ~CERT_STATUS_IS_EV; + } else { + if (server_cert_verify_result_.cert_status & CERT_STATUS_IS_EV) { + scoped_refptr<ct::EVCertsWhitelist> ev_whitelist = + SSLConfigService::GetEVCertsWhitelist(); + if (!policy_enforcer_->DoesConformToCTEVPolicy( + server_cert_verify_result_.verified_cert.get(), + ev_whitelist.get(), ct_verify_result_, net_log_)) { + // TODO(eranm): Log via the BoundNetLog, see crbug.com/437766 + VLOG(1) << "EV certificate for " + << server_cert_verify_result_.verified_cert->subject() + .GetDisplayName() + << " does not conform to CT policy, removing EV status."; + server_cert_verify_result_.cert_status &= ~CERT_STATUS_IS_EV; + } + } + } } void SSLClientSocketNSS::EnsureThreadIdAssigned() const { @@ -3620,4 +3209,10 @@ ChannelIDService* SSLClientSocketNSS::GetChannelIDService() const { return channel_id_service_; } +SSLFailureState SSLClientSocketNSS::GetSSLFailureState() const { + if (completed_handshake_) + return SSL_FAILURE_NONE; + return SSL_FAILURE_UNKNOWN; +} + } // namespace net diff --git a/chromium/net/socket/ssl_client_socket_nss.h b/chromium/net/socket/ssl_client_socket_nss.h index 71f09c0b82b..75b47af5cec 100644 --- a/chromium/net/socket/ssl_client_socket_nss.h +++ b/chromium/net/socket/ssl_client_socket_nss.h @@ -20,11 +20,12 @@ #include "net/base/completion_callback.h" #include "net/base/host_port_pair.h" #include "net/base/net_export.h" -#include "net/base/net_log.h" #include "net/base/nss_memio.h" +#include "net/cert/cert_verifier.h" #include "net/cert/cert_verify_result.h" #include "net/cert/ct_verify_result.h" #include "net/cert/x509_certificate.h" +#include "net/log/net_log.h" #include "net/socket/ssl_client_socket.h" #include "net/ssl/channel_id_service.h" #include "net/ssl/ssl_config_service.h" @@ -36,11 +37,11 @@ class SequencedTaskRunner; namespace net { class BoundNetLog; +class CertPolicyEnforcer; class CertVerifier; class ChannelIDService; class CTVerifier; class ClientSocketHandle; -class SingleRequestCertVerifier; class TransportSecurityState; class X509Certificate; @@ -67,11 +68,8 @@ class SSLClientSocketNSS : public SSLClientSocket { ~SSLClientSocketNSS() override; // SSLClientSocket implementation. - std::string GetSessionCacheKey() const override; - bool InSessionCache() const override; - void SetHandshakeCompletionCallback(const base::Closure& callback) override; void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override; - NextProtoStatus GetNextProto(std::string* proto) override; + NextProtoStatus GetNextProto(std::string* proto) const override; // SSLSocket implementation. int ExportKeyingMaterial(const base::StringPiece& label, @@ -94,6 +92,9 @@ class SSLClientSocketNSS : public SSLClientSocket { bool WasEverUsed() const override; bool UsingTCPFastOpen() const override; bool GetSSLInfo(SSLInfo* ssl_info) override; + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} // Socket implementation. int Read(IOBuffer* buf, @@ -104,7 +105,10 @@ class SSLClientSocketNSS : public SSLClientSocket { const CompletionCallback& callback) override; int SetReceiveBufferSize(int32 size) override; int SetSendBufferSize(int32 size) override; + + // SSLClientSocket implementation. ChannelIDService* GetChannelIDService() const override; + SSLFailureState GetSSLFailureState() const override; protected: // SSLClientSocket implementation. @@ -169,7 +173,7 @@ class SSLClientSocketNSS : public SSLClientSocket { CertVerifyResult server_cert_verify_result_; CertVerifier* const cert_verifier_; - scoped_ptr<SingleRequestCertVerifier> verifier_; + scoped_ptr<CertVerifier::Request> cert_verifier_request_; // Certificate Transparency: Verifier and result holder. ct::CTVerifyResult ct_verify_result_; @@ -199,6 +203,8 @@ class SSLClientSocketNSS : public SSLClientSocket { TransportSecurityState* transport_security_state_; + CertPolicyEnforcer* const policy_enforcer_; + // pinning_failure_log contains a message produced by // TransportSecurityState::CheckPublicKeyPins in the event of a // pinning failure. It is a (somewhat) human-readable string. diff --git a/chromium/net/socket/ssl_client_socket_openssl.cc b/chromium/net/socket/ssl_client_socket_openssl.cc index 9fdfe38ccdd..89d2952875b 100644 --- a/chromium/net/socket/ssl_client_socket_openssl.cc +++ b/chromium/net/socket/ssl_client_socket_openssl.cc @@ -10,29 +10,35 @@ #include <errno.h> #include <openssl/bio.h> #include <openssl/err.h> +#include <openssl/mem.h> #include <openssl/ssl.h> +#include <string.h> #include "base/bind.h" #include "base/callback_helpers.h" #include "base/environment.h" #include "base/memory/singleton.h" #include "base/metrics/histogram.h" +#include "base/profiler/scoped_tracker.h" #include "base/strings/string_piece.h" #include "base/synchronization/lock.h" +#include "base/threading/thread_local.h" #include "crypto/ec_private_key.h" #include "crypto/openssl_util.h" #include "crypto/scoped_openssl_types.h" #include "net/base/net_errors.h" +#include "net/cert/cert_policy_enforcer.h" #include "net/cert/cert_verifier.h" #include "net/cert/ct_ev_whitelist.h" #include "net/cert/ct_verifier.h" -#include "net/cert/single_request_cert_verifier.h" #include "net/cert/x509_certificate_net_log_param.h" #include "net/cert/x509_util_openssl.h" #include "net/http/transport_security_state.h" -#include "net/socket/ssl_session_cache_openssl.h" +#include "net/ssl/scoped_openssl_types.h" #include "net/ssl/ssl_cert_request_info.h" +#include "net/ssl/ssl_client_session_cache_openssl.h" #include "net/ssl/ssl_connection_status_flags.h" +#include "net/ssl/ssl_failure_state.h" #include "net/ssl/ssl_info.h" #if defined(OS_WIN) @@ -66,13 +72,14 @@ const int kNoPendingReadResult = 1; // the server supports NPN, choosing "http/1.1" is the best answer. const char kDefaultSupportedNPNProtocol[] = "http/1.1"; +// Default size of the internal BoringSSL buffers. +const int KDefaultOpenSSLBufferSize = 17 * 1024; + void FreeX509Stack(STACK_OF(X509)* ptr) { sk_X509_pop_free(ptr, X509_free); } -typedef crypto::ScopedOpenSSL<X509, X509_free>::Type ScopedX509; -typedef crypto::ScopedOpenSSL<STACK_OF(X509), FreeX509Stack>::Type - ScopedX509Stack; +using ScopedX509Stack = crypto::ScopedOpenSSL<STACK_OF(X509), FreeX509Stack>; #if OPENSSL_VERSION_NUMBER < 0x1000103fL // This method doesn't seem to have made it into the OpenSSL headers. @@ -80,11 +87,10 @@ unsigned long SSL_CIPHER_get_id(const SSL_CIPHER* cipher) { return cipher->id; } #endif // Used for encoding the |connection_status| field of an SSLInfo object. -int EncodeSSLConnectionStatus(int cipher_suite, +int EncodeSSLConnectionStatus(uint16 cipher_suite, int compression, int version) { - return ((cipher_suite & SSL_CONNECTION_CIPHERSUITE_MASK) << - SSL_CONNECTION_CIPHERSUITE_SHIFT) | + return cipher_suite | ((compression & SSL_CONNECTION_COMPRESSION_MASK) << SSL_CONNECTION_COMPRESSION_SHIFT) | ((version & SSL_CONNECTION_VERSION_MASK) << @@ -95,10 +101,6 @@ int EncodeSSLConnectionStatus(int cipher_suite, // this SSL connection. int GetNetSSLVersion(SSL* ssl) { switch (SSL_version(ssl)) { - case SSL2_VERSION: - return SSL_CONNECTION_VERSION_SSL2; - case SSL3_VERSION: - return SSL_CONNECTION_VERSION_SSL3; case TLS1_VERSION: return SSL_CONNECTION_VERSION_TLS1; case TLS1_1_VERSION: @@ -106,6 +108,7 @@ int GetNetSSLVersion(SSL* ssl) { case TLS1_2_VERSION: return SSL_CONNECTION_VERSION_TLS1_2; default: + NOTREACHED(); return SSL_CONNECTION_VERSION_UNKNOWN; } } @@ -146,7 +149,7 @@ class SSLClientSocketOpenSSL::SSLContext { public: static SSLContext* GetInstance() { return Singleton<SSLContext>::get(); } SSL_CTX* ssl_ctx() { return ssl_ctx_.get(); } - SSLSessionCacheOpenSSL* session_cache() { return &session_cache_; } + SSLClientSessionCacheOpenSSL* session_cache() { return &session_cache_; } SSLClientSocketOpenSSL* GetClientSocketFromSSL(const SSL* ssl) { DCHECK(ssl); @@ -163,21 +166,30 @@ class SSLClientSocketOpenSSL::SSLContext { private: friend struct DefaultSingletonTraits<SSLContext>; - SSLContext() { + SSLContext() : session_cache_(SSLClientSessionCacheOpenSSL::Config()) { crypto::EnsureOpenSSLInit(); ssl_socket_data_index_ = SSL_get_ex_new_index(0, 0, 0, 0, 0); DCHECK_NE(ssl_socket_data_index_, -1); ssl_ctx_.reset(SSL_CTX_new(SSLv23_client_method())); - session_cache_.Reset(ssl_ctx_.get(), kDefaultSessionCacheConfig); SSL_CTX_set_cert_verify_callback(ssl_ctx_.get(), CertVerifyCallback, NULL); SSL_CTX_set_cert_cb(ssl_ctx_.get(), ClientCertRequestCallback, NULL); SSL_CTX_set_verify(ssl_ctx_.get(), SSL_VERIFY_PEER, NULL); + // This stops |SSL_shutdown| from generating the close_notify message, which + // is currently not sent on the network. + // TODO(haavardm): Remove setting quiet shutdown once 118366 is fixed. + SSL_CTX_set_quiet_shutdown(ssl_ctx_.get(), 1); // TODO(kristianm): Only select this if ssl_config_.next_proto is not empty. // It would be better if the callback were not a global setting, // but that is an OpenSSL issue. SSL_CTX_set_next_proto_select_cb(ssl_ctx_.get(), SelectNextProtoCallback, NULL); ssl_ctx_->tlsext_channel_id_enabled_new = 1; + SSL_CTX_set_info_callback(ssl_ctx_.get(), InfoCallback); + + // Disable the internal session cache. Session caching is handled + // externally (i.e. by SSLClientSessionCacheOpenSSL). + SSL_CTX_set_session_cache_mode( + ssl_ctx_.get(), SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL); scoped_ptr<base::Environment> env(base::Environment::Create()); std::string ssl_keylog_file; @@ -194,14 +206,6 @@ class SSLClientSocketOpenSSL::SSLContext { } } - static std::string GetSessionCacheKey(const SSL* ssl) { - SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl); - DCHECK(socket); - return socket->GetSessionCacheKey(); - } - - static SSLSessionCacheOpenSSL::Config kDefaultSessionCacheConfig; - static int ClientCertRequestCallback(SSL* ssl, void* arg) { SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl); DCHECK(socket); @@ -225,13 +229,23 @@ class SSLClientSocketOpenSSL::SSLContext { return socket->SelectNextProtoCallback(out, outlen, in, inlen); } + static void InfoCallback(const SSL* ssl, int type, int val) { + SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl); + socket->InfoCallback(type, val); + } + // This is the index used with SSL_get_ex_data to retrieve the owner // SSLClientSocketOpenSSL object from an SSL instance. int ssl_socket_data_index_; - crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free>::Type ssl_ctx_; - // |session_cache_| must be destroyed before |ssl_ctx_|. - SSLSessionCacheOpenSSL session_cache_; + ScopedSSL_CTX ssl_ctx_; + + // TODO(davidben): Use a separate cache per URLRequestContext. + // https://crbug.com/458365 + // + // TODO(davidben): Sessions should be invalidated on fatal + // alerts. https://crbug.com/466352 + SSLClientSessionCacheOpenSSL session_cache_; }; // PeerCertificateChain is a helper object which extracts the certificate @@ -316,21 +330,17 @@ SSLClientSocketOpenSSL::PeerCertificateChain::AsOSChain() const { } // static -SSLSessionCacheOpenSSL::Config - SSLClientSocketOpenSSL::SSLContext::kDefaultSessionCacheConfig = { - &GetSessionCacheKey, // key_func - 1024, // max_entries - 256, // expiration_check_count - 60 * 60, // timeout_seconds -}; - -// static void SSLClientSocket::ClearSessionCache() { SSLClientSocketOpenSSL::SSLContext* context = SSLClientSocketOpenSSL::SSLContext::GetInstance(); context->session_cache()->Flush(); } +// static +uint16 SSLClientSocket::GetMaxSupportedSSLVersion() { + return SSL_PROTOCOL_VERSION_TLS1_2; +} + SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, @@ -345,7 +355,6 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( server_cert_chain_(new PeerCertificateChain(NULL)), completed_connect_(false), was_ever_used_(false), - client_auth_cert_needed_(false), cert_verifier_(context.cert_verifier), cert_transparency_verifier_(context.cert_transparency_verifier), channel_id_service_(context.channel_id_service), @@ -355,39 +364,23 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( host_and_port_(host_and_port), ssl_config_(ssl_config), ssl_session_cache_shard_(context.ssl_session_cache_shard), - trying_cached_session_(false), next_handshake_state_(STATE_NONE), npn_status_(kNextProtoUnsupported), - channel_id_xtn_negotiated_(false), - handshake_succeeded_(false), - marked_session_as_good_(false), + channel_id_sent_(false), + handshake_completed_(false), + certificate_verified_(false), + ssl_failure_state_(SSL_FAILURE_NONE), transport_security_state_(context.transport_security_state), + policy_enforcer_(context.cert_policy_enforcer), net_log_(transport_->socket()->NetLog()), weak_factory_(this) { + DCHECK(cert_verifier_); } SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() { Disconnect(); } -std::string SSLClientSocketOpenSSL::GetSessionCacheKey() const { - std::string result = host_and_port_.ToString(); - result.append("/"); - result.append(ssl_session_cache_shard_); - return result; -} - -bool SSLClientSocketOpenSSL::InSessionCache() const { - SSLContext* context = SSLContext::GetInstance(); - std::string cache_key = GetSessionCacheKey(); - return context->session_cache()->SSLSessionIsInCache(cache_key); -} - -void SSLClientSocketOpenSSL::SetHandshakeCompletionCallback( - const base::Closure& callback) { - handshake_completion_callback_ = callback; -} - void SSLClientSocketOpenSSL::GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) { cert_request_info->host_and_port = host_and_port_; @@ -396,7 +389,7 @@ void SSLClientSocketOpenSSL::GetSSLCertRequestInfo( } SSLClientSocket::NextProtoStatus SSLClientSocketOpenSSL::GetNextProto( - std::string* proto) { + std::string* proto) const { *proto = npn_proto_; return npn_status_; } @@ -406,16 +399,23 @@ SSLClientSocketOpenSSL::GetChannelIDService() const { return channel_id_service_; } +SSLFailureState SSLClientSocketOpenSSL::GetSSLFailureState() const { + return ssl_failure_state_; +} + int SSLClientSocketOpenSSL::ExportKeyingMaterial( const base::StringPiece& label, bool has_context, const base::StringPiece& context, unsigned char* out, unsigned int outlen) { + if (!IsConnected()) + return ERR_SOCKET_NOT_CONNECTED; + crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); int rv = SSL_export_keying_material( ssl_, out, outlen, label.data(), label.size(), - reinterpret_cast<const unsigned char*>(context.data()), - context.length(), context.length() > 0); + reinterpret_cast<const unsigned char*>(context.data()), context.length(), + has_context ? 1 : 0); if (rv != 1) { int ssl_error = SSL_get_error(ssl_, rv); @@ -455,18 +455,12 @@ int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) { user_connect_callback_ = callback; } else { net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SSL_CONNECT, rv); - if (rv < OK) - OnHandshakeCompletion(); } return rv > OK ? OK : rv; } void SSLClientSocketOpenSSL::Disconnect() { - // If a handshake was pending (Connect() had been called), notify interested - // parties that it's been aborted now. If the handshake had already - // completed, this is a no-op. - OnHandshakeCompletion(); if (ssl_) { // Calling SSL_shutdown prevents the session from being marked as // unresumable. @@ -480,7 +474,7 @@ void SSLClientSocketOpenSSL::Disconnect() { } // Shut down anything that may call us back. - verifier_.reset(); + cert_verifier_request_.reset(); transport_->socket()->Disconnect(); // Null all callbacks, delete all buffers. @@ -509,15 +503,17 @@ void SSLClientSocketOpenSSL::Disconnect() { cert_authorities_.clear(); cert_key_types_.clear(); - client_auth_cert_needed_ = false; start_cert_verification_time_ = base::TimeTicks(); npn_status_ = kNextProtoUnsupported; npn_proto_.clear(); - channel_id_xtn_negotiated_ = false; + channel_id_sent_ = false; + handshake_completed_ = false; + certificate_verified_ = false; channel_id_request_handle_.Cancel(); + ssl_failure_state_ = SSL_FAILURE_NONE; } bool SSLClientSocketOpenSSL::IsConnected() const { @@ -538,12 +534,16 @@ bool SSLClientSocketOpenSSL::IsConnectedAndIdle() const { // If an asynchronous operation is still pending. if (user_read_buf_.get() || user_write_buf_.get()) return false; - // If there is data waiting to be sent, or data read from the network that - // has not yet been consumed. - if (BIO_pending(transport_bio_) > 0 || - BIO_wpending(transport_bio_) > 0) { + + // If there is data read from the network that has not yet been consumed, do + // not treat the connection as idle. + // + // Note that this does not check |BIO_pending|, whether there is ciphertext + // that has not yet been flushed to the network. |Write| returns early, so + // this can cause race conditions which cause a socket to not be treated + // reusable when it should be. See https://crbug.com/466147. + if (BIO_wpending(transport_bio_) > 0) return false; - } return transport_->socket()->IsConnectedAndIdle(); } @@ -601,7 +601,7 @@ bool SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { server_cert_verify_result_.public_key_hashes; ssl_info->client_cert_sent = ssl_config_.send_client_cert && ssl_config_.client_cert.get(); - ssl_info->channel_id_sent = WasChannelIDSent(); + ssl_info->channel_id_sent = channel_id_sent_; ssl_info->pinning_failure_log = pinning_failure_log_; AddSCTInfoToSSLInfo(ssl_info); @@ -611,7 +611,7 @@ bool SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { ssl_info->security_bits = SSL_CIPHER_get_bits(cipher, NULL); ssl_info->connection_status = EncodeSSLConnectionStatus( - SSL_CIPHER_get_id(cipher), 0 /* no compression */, + static_cast<uint16>(SSL_CIPHER_get_id(cipher)), 0 /* no compression */, GetNetSSLVersion(ssl_)); if (!SSL_get_secure_renegotiation_support(ssl_)) @@ -630,6 +630,11 @@ bool SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { return true; } +void SSLClientSocketOpenSSL::GetConnectionAttempts( + ConnectionAttempts* out) const { + out->clear(); +} + int SSLClientSocketOpenSSL::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { @@ -645,11 +650,6 @@ int SSLClientSocketOpenSSL::Read(IOBuffer* buf, was_ever_used_ = true; user_read_buf_ = NULL; user_read_buf_len_ = 0; - if (rv <= 0) { - // Failure of a read attempt may indicate a failed false start - // connection. - OnHandshakeCompletion(); - } } return rv; @@ -670,11 +670,6 @@ int SSLClientSocketOpenSSL::Write(IOBuffer* buf, was_ever_used_ = true; user_write_buf_ = NULL; user_write_buf_len_ = 0; - if (rv < 0) { - // Failure of a write attempt may indicate a failed false start - // connection. - OnHandshakeCompletion(); - } } return rv; @@ -702,43 +697,41 @@ int SSLClientSocketOpenSSL::Init() { if (!SSL_set_tlsext_host_name(ssl_, host_and_port_.host().c_str())) return ERR_UNEXPECTED; - // Set an OpenSSL callback to monitor this SSL*'s connection. - SSL_set_info_callback(ssl_, &InfoCallback); + SSL_SESSION* session = context->session_cache()->Lookup(GetSessionCacheKey()); + if (session != nullptr) + SSL_set_session(ssl_, session); - trying_cached_session_ = context->session_cache()->SetSSLSessionWithKey( - ssl_, GetSessionCacheKey()); + send_buffer_ = new GrowableIOBuffer(); + send_buffer_->SetCapacity(KDefaultOpenSSLBufferSize); + recv_buffer_ = new GrowableIOBuffer(); + recv_buffer_->SetCapacity(KDefaultOpenSSLBufferSize); BIO* ssl_bio = NULL; - // 0 => use default buffer sizes. - if (!BIO_new_bio_pair(&ssl_bio, 0, &transport_bio_, 0)) + + // SSLClientSocketOpenSSL retains ownership of the BIO buffers. + if (!BIO_new_bio_pair_external_buf( + &ssl_bio, send_buffer_->capacity(), + reinterpret_cast<uint8_t*>(send_buffer_->data()), &transport_bio_, + recv_buffer_->capacity(), + reinterpret_cast<uint8_t*>(recv_buffer_->data()))) return ERR_UNEXPECTED; DCHECK(ssl_bio); DCHECK(transport_bio_); // Install a callback on OpenSSL's end to plumb transport errors through. - BIO_set_callback(ssl_bio, BIOCallback); + BIO_set_callback(ssl_bio, &SSLClientSocketOpenSSL::BIOCallback); BIO_set_callback_arg(ssl_bio, reinterpret_cast<char*>(this)); SSL_set_bio(ssl_, ssl_bio, ssl_bio); + DCHECK_LT(SSL3_VERSION, ssl_config_.version_min); + DCHECK_LT(SSL3_VERSION, ssl_config_.version_max); + SSL_set_min_version(ssl_, ssl_config_.version_min); + SSL_set_max_version(ssl_, ssl_config_.version_max); + // OpenSSL defaults some options to on, others to off. To avoid ambiguity, // set everything we care about to an absolute value. SslSetClearMask options; - options.ConfigureFlag(SSL_OP_NO_SSLv2, true); - bool ssl3_enabled = (ssl_config_.version_min == SSL_PROTOCOL_VERSION_SSL3); - options.ConfigureFlag(SSL_OP_NO_SSLv3, !ssl3_enabled); - bool tls1_enabled = (ssl_config_.version_min <= SSL_PROTOCOL_VERSION_TLS1 && - ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1); - options.ConfigureFlag(SSL_OP_NO_TLSv1, !tls1_enabled); - bool tls1_1_enabled = - (ssl_config_.version_min <= SSL_PROTOCOL_VERSION_TLS1_1 && - ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1_1); - options.ConfigureFlag(SSL_OP_NO_TLSv1_1, !tls1_1_enabled); - bool tls1_2_enabled = - (ssl_config_.version_min <= SSL_PROTOCOL_VERSION_TLS1_2 && - ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1_2); - options.ConfigureFlag(SSL_OP_NO_TLSv1_2, !tls1_2_enabled); - options.ConfigureFlag(SSL_OP_NO_COMPRESSION, true); // TODO(joth): Set this conditionally, see http://crbug.com/55410 @@ -753,9 +746,11 @@ int SSLClientSocketOpenSSL::Init() { mode.ConfigureFlag(SSL_MODE_RELEASE_BUFFERS, true); mode.ConfigureFlag(SSL_MODE_CBC_RECORD_SPLITTING, true); - mode.ConfigureFlag(SSL_MODE_HANDSHAKE_CUTTHROUGH, + mode.ConfigureFlag(SSL_MODE_ENABLE_FALSE_START, ssl_config_.false_start_enabled); + mode.ConfigureFlag(SSL_MODE_SEND_FALLBACK_SCSV, ssl_config_.version_fallback); + SSL_set_mode(ssl_, mode.set_mask); SSL_clear_mode(ssl_, mode.clear_mask); @@ -768,13 +763,13 @@ int SSLClientSocketOpenSSL::Init() { // disabled by default. Note that !SHA256 and !SHA384 only remove HMAC-SHA256 // and HMAC-SHA384 cipher suites, not GCM cipher suites with SHA256 or SHA384 // as the handshake hash. - std::string command("DEFAULT:!NULL:!aNULL:!IDEA:!FZA:!SRP:!SHA256:!SHA384:" - "!aECDH:!AESGCM+AES256"); + std::string command( + "DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK"); // Walk through all the installed ciphers, seeing if any need to be // appended to the cipher removal |command|. for (size_t i = 0; i < sk_SSL_CIPHER_num(ciphers); ++i) { const SSL_CIPHER* cipher = sk_SSL_CIPHER_value(ciphers, i); - const uint16 id = SSL_CIPHER_get_id(cipher); + const uint16 id = static_cast<uint16>(SSL_CIPHER_get_id(cipher)); // Remove any ciphers with a strength of less than 80 bits. Note the NSS // implementation uses "effective" bits here but OpenSSL does not provide // this detail. This only impacts Triple DES: reports 112 vs. 168 bits, @@ -794,6 +789,9 @@ int SSLClientSocketOpenSSL::Init() { } } + if (!ssl_config_.enable_deprecated_cipher_suites) + command.append(":!RC4"); + // Disable ECDSA cipher suites on platforms that do not support ECDSA // signed certificates, as servers may use the presence of such // ciphersuites as a hint to send an ECDSA certificate. @@ -809,17 +807,26 @@ int SSLClientSocketOpenSSL::Init() { LOG_IF(WARNING, rv != 1) << "SSL_set_cipher_list('" << command << "') " "returned " << rv; - if (ssl_config_.version_fallback) - SSL_enable_fallback_scsv(ssl_); - // TLS channel ids. if (IsChannelIDEnabled(ssl_config_, channel_id_service_)) { SSL_enable_tls_channel_id(ssl_); } if (!ssl_config_.next_protos.empty()) { + // Get list of ciphers that are enabled. + STACK_OF(SSL_CIPHER)* enabled_ciphers = SSL_get_ciphers(ssl_); + DCHECK(enabled_ciphers); + std::vector<uint16> enabled_ciphers_vector; + for (size_t i = 0; i < sk_SSL_CIPHER_num(enabled_ciphers); ++i) { + const SSL_CIPHER* cipher = sk_SSL_CIPHER_value(enabled_ciphers, i); + const uint16 id = static_cast<uint16>(SSL_CIPHER_get_id(cipher)); + enabled_ciphers_vector.push_back(id); + } + std::vector<uint8_t> wire_protos = - SerializeNextProtos(ssl_config_.next_protos); + SerializeNextProtos(ssl_config_.next_protos, + HasCipherAdequateForHTTP2(enabled_ciphers_vector) && + IsTLSVersionAdequateForHTTP2(ssl_config_)); SSL_set_alpn_protos(ssl_, wire_protos.empty() ? NULL : &wire_protos[0], wire_protos.size()); } @@ -829,8 +836,17 @@ int SSLClientSocketOpenSSL::Init() { SSL_enable_ocsp_stapling(ssl_); } - // TODO(davidben): Enable OCSP stapling on platforms which support it and pass - // into the certificate verifier. https://crbug.com/398677 + if (cert_verifier_->SupportsOCSPStapling()) + SSL_enable_ocsp_stapling(ssl_); + + // Enable fastradio padding. + SSL_enable_fastradio_padding(ssl_, + ssl_config_.fastradio_padding_enabled && + ssl_config_.fastradio_padding_eligible); + + // By default, renegotiations are rejected. After the initial handshake + // completes, some application protocols may re-enable it. + SSL_set_reject_peer_renegotiations(ssl_, 1); return OK; } @@ -842,11 +858,6 @@ void SSLClientSocketOpenSSL::DoReadCallback(int rv) { was_ever_used_ = true; user_read_buf_ = NULL; user_read_buf_len_ = 0; - if (rv <= 0) { - // Failure of a read attempt may indicate a failed false start - // connection. - OnHandshakeCompletion(); - } base::ResetAndReturn(&user_read_callback_).Run(rv); } @@ -857,19 +868,9 @@ void SSLClientSocketOpenSSL::DoWriteCallback(int rv) { was_ever_used_ = true; user_write_buf_ = NULL; user_write_buf_len_ = 0; - if (rv < 0) { - // Failure of a write attempt may indicate a failed false start - // connection. - OnHandshakeCompletion(); - } base::ResetAndReturn(&user_write_callback_).Run(rv); } -void SSLClientSocketOpenSSL::OnHandshakeCompletion() { - if (!handshake_completion_callback_.is_null()) - base::ResetAndReturn(&handshake_completion_callback_).Run(); -} - bool SSLClientSocketOpenSSL::DoTransportIO() { bool network_moved = false; int rv; @@ -885,97 +886,145 @@ bool SSLClientSocketOpenSSL::DoTransportIO() { return network_moved; } +// TODO(cbentzel): Remove including "base/threading/thread_local.h" and +// g_first_run_completed once crbug.com/424386 is fixed. +base::LazyInstance<base::ThreadLocalBoolean>::Leaky g_first_run_completed = + LAZY_INSTANCE_INITIALIZER; + int SSLClientSocketOpenSSL::DoHandshake() { crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); - int net_error = OK; - int rv = SSL_do_handshake(ssl_); - - if (client_auth_cert_needed_) { - net_error = ERR_SSL_CLIENT_AUTH_CERT_NEEDED; - // If the handshake already succeeded (because the server requests but - // doesn't require a client cert), we need to invalidate the SSL session - // so that we won't try to resume the non-client-authenticated session in - // the next handshake. This will cause the server to ask for a client - // cert again. - if (rv == 1) { - // Remove from session cache but don't clear this connection. - SSL_SESSION* session = SSL_get_session(ssl_); - if (session) { - int rv = SSL_CTX_remove_session(SSL_get_SSL_CTX(ssl_), session); - LOG_IF(WARNING, !rv) << "Couldn't invalidate SSL session: " << session; - } - } - } else if (rv == 1) { - if (trying_cached_session_ && logging::DEBUG_MODE) { - DVLOG(2) << "Result of session reuse for " << host_and_port_.ToString() - << " is: " << (SSL_session_reused(ssl_) ? "Success" : "Fail"); - } - - if (ssl_config_.version_fallback && - ssl_config_.version_max < ssl_config_.version_fallback_min) { - return ERR_SSL_FALLBACK_BEYOND_MINIMUM_VERSION; - } - // SSL handshake is completed. If NPN wasn't negotiated, see if ALPN was. - if (npn_status_ == kNextProtoUnsupported) { - const uint8_t* alpn_proto = NULL; - unsigned alpn_len = 0; - SSL_get0_alpn_selected(ssl_, &alpn_proto, &alpn_len); - if (alpn_len > 0) { - npn_proto_.assign(reinterpret_cast<const char*>(alpn_proto), alpn_len); - npn_status_ = kNextProtoNegotiated; - set_negotiation_extension(kExtensionALPN); - } - } - - RecordChannelIDSupport(channel_id_service_, - channel_id_xtn_negotiated_, - ssl_config_.channel_id_enabled, - crypto::ECPrivateKey::IsSupported()); + int rv; - uint8_t* ocsp_response; - size_t ocsp_response_len; - SSL_get0_ocsp_response(ssl_, &ocsp_response, &ocsp_response_len); - set_stapled_ocsp_response_received(ocsp_response_len != 0); + // TODO(cbentzel): Leave only 1 call to SSL_do_handshake once crbug.com/424386 + // is fixed. + if (ssl_config_.send_client_cert && ssl_config_.client_cert.get()) { + rv = SSL_do_handshake(ssl_); + } else { + if (g_first_run_completed.Get().Get()) { + // TODO(cbentzel): Remove ScopedTracker below once crbug.com/424386 is + // fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION("424386 SSL_do_handshake()")); - uint8_t* sct_list; - size_t sct_list_len; - SSL_get0_signed_cert_timestamp_list(ssl_, &sct_list, &sct_list_len); - set_signed_cert_timestamps_received(sct_list_len != 0); + rv = SSL_do_handshake(ssl_); + } else { + g_first_run_completed.Get().Set(true); + rv = SSL_do_handshake(ssl_); + } + } - // Verify the certificate. - UpdateServerCert(); - GotoState(STATE_VERIFY_CERT); - } else { + int net_error = OK; + if (rv <= 0) { int ssl_error = SSL_get_error(ssl_, rv); - if (ssl_error == SSL_ERROR_WANT_CHANNEL_ID_LOOKUP) { // The server supports channel ID. Stop to look one up before returning to // the handshake. - channel_id_xtn_negotiated_ = true; GotoState(STATE_CHANNEL_ID_LOOKUP); return OK; } + if (ssl_error == SSL_ERROR_WANT_X509_LOOKUP && + !ssl_config_.send_client_cert) { + return ERR_SSL_CLIENT_AUTH_CERT_NEEDED; + } OpenSSLErrorInfo error_info; net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer, &error_info); - - // If not done, stay in this state if (net_error == ERR_IO_PENDING) { + // If not done, stay in this state GotoState(STATE_HANDSHAKE); + return ERR_IO_PENDING; + } + + LOG(ERROR) << "handshake failed; returned " << rv << ", SSL error code " + << ssl_error << ", net_error " << net_error; + net_log_.AddEvent( + NetLog::TYPE_SSL_HANDSHAKE_ERROR, + CreateNetLogOpenSSLErrorCallback(net_error, ssl_error, error_info)); + + // Classify the handshake failure. This is used to determine causes of the + // TLS version fallback. + + // |cipher| is the current outgoing cipher suite, so it is non-null iff + // ChangeCipherSpec was sent. + const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl_); + if (SSL_get_state(ssl_) == SSL3_ST_CR_SRVR_HELLO_A) { + ssl_failure_state_ = SSL_FAILURE_CLIENT_HELLO; + } else if (cipher && (SSL_CIPHER_get_id(cipher) == + TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256 || + SSL_CIPHER_get_id(cipher) == + TLS1_CK_RSA_WITH_AES_128_GCM_SHA256)) { + ssl_failure_state_ = SSL_FAILURE_BUGGY_GCM; + } else if (cipher && ssl_config_.send_client_cert) { + ssl_failure_state_ = SSL_FAILURE_CLIENT_AUTH; + } else if (ERR_GET_LIB(error_info.error_code) == ERR_LIB_SSL && + ERR_GET_REASON(error_info.error_code) == + SSL_R_OLD_SESSION_VERSION_NOT_RETURNED) { + ssl_failure_state_ = SSL_FAILURE_SESSION_MISMATCH; + } else if (cipher && npn_status_ != kNextProtoUnsupported) { + ssl_failure_state_ = SSL_FAILURE_NEXT_PROTO; } else { - LOG(ERROR) << "handshake failed; returned " << rv - << ", SSL error code " << ssl_error - << ", net_error " << net_error; - net_log_.AddEvent( - NetLog::TYPE_SSL_HANDSHAKE_ERROR, - CreateNetLogOpenSSLErrorCallback(net_error, ssl_error, error_info)); + ssl_failure_state_ = SSL_FAILURE_UNKNOWN; } } + + GotoState(STATE_HANDSHAKE_COMPLETE); return net_error; } +int SSLClientSocketOpenSSL::DoHandshakeComplete(int result) { + if (result < 0) + return result; + + if (ssl_config_.version_fallback && + ssl_config_.version_max < ssl_config_.version_fallback_min) { + return ERR_SSL_FALLBACK_BEYOND_MINIMUM_VERSION; + } + + // SSL handshake is completed. If NPN wasn't negotiated, see if ALPN was. + if (npn_status_ == kNextProtoUnsupported) { + const uint8_t* alpn_proto = NULL; + unsigned alpn_len = 0; + SSL_get0_alpn_selected(ssl_, &alpn_proto, &alpn_len); + if (alpn_len > 0) { + npn_proto_.assign(reinterpret_cast<const char*>(alpn_proto), alpn_len); + npn_status_ = kNextProtoNegotiated; + set_negotiation_extension(kExtensionALPN); + } + } + + RecordNegotiationExtension(); + RecordChannelIDSupport(channel_id_service_, channel_id_sent_, + ssl_config_.channel_id_enabled, + crypto::ECPrivateKey::IsSupported()); + + // Only record OCSP histograms if OCSP was requested. + if (ssl_config_.signed_cert_timestamps_enabled || + cert_verifier_->SupportsOCSPStapling()) { + const uint8_t* ocsp_response; + size_t ocsp_response_len; + SSL_get0_ocsp_response(ssl_, &ocsp_response, &ocsp_response_len); + + set_stapled_ocsp_response_received(ocsp_response_len != 0); + UMA_HISTOGRAM_BOOLEAN("Net.OCSPResponseStapled", ocsp_response_len != 0); + } + + const uint8_t* sct_list; + size_t sct_list_len; + SSL_get0_signed_cert_timestamp_list(ssl_, &sct_list, &sct_list_len); + set_signed_cert_timestamps_received(sct_list_len != 0); + + if (IsRenegotiationAllowed()) + SSL_set_reject_peer_renegotiations(ssl_, 0); + + // Verify the certificate. + UpdateServerCert(); + GotoState(STATE_VERIFY_CERT); + return OK; +} + int SSLClientSocketOpenSSL::DoChannelIDLookup() { + net_log_.AddEvent(NetLog::TYPE_SSL_CHANNEL_ID_REQUESTED); GotoState(STATE_CHANNEL_ID_LOOKUP_COMPLETE); return channel_id_service_->GetOrCreateChannelID( host_and_port_.host(), @@ -1021,7 +1070,8 @@ int SSLClientSocketOpenSSL::DoChannelIDLookupComplete(int result) { } // Return to the handshake. - set_channel_id_sent(true); + channel_id_sent_ = true; + net_log_.AddEvent(NetLog::TYPE_SSL_CHANNEL_ID_PROVIDED); GotoState(STATE_HANDSHAKE); return OK; } @@ -1057,6 +1107,15 @@ int SSLClientSocketOpenSSL::DoVerifyCert(int result) { return ERR_CERT_INVALID; } + std::string ocsp_response; + if (cert_verifier_->SupportsOCSPStapling()) { + const uint8_t* ocsp_response_raw; + size_t ocsp_response_len; + SSL_get0_ocsp_response(ssl_, &ocsp_response_raw, &ocsp_response_len); + ocsp_response.assign(reinterpret_cast<const char*>(ocsp_response_raw), + ocsp_response_len); + } + start_cert_verification_time_ = base::TimeTicks::Now(); int flags = 0; @@ -1068,22 +1127,18 @@ int SSLClientSocketOpenSSL::DoVerifyCert(int result) { flags |= CertVerifier::VERIFY_CERT_IO_ENABLED; if (ssl_config_.rev_checking_required_local_anchors) flags |= CertVerifier::VERIFY_REV_CHECKING_REQUIRED_LOCAL_ANCHORS; - verifier_.reset(new SingleRequestCertVerifier(cert_verifier_)); - return verifier_->Verify( - server_cert_.get(), - host_and_port_.host(), - flags, + return cert_verifier_->Verify( + server_cert_.get(), host_and_port_.host(), ocsp_response, flags, // TODO(davidben): Route the CRLSet through SSLConfig so // SSLClientSocket doesn't depend on SSLConfigService. - SSLConfigService::GetCRLSet().get(), - &server_cert_verify_result_, + SSLConfigService::GetCRLSet().get(), &server_cert_verify_result_, base::Bind(&SSLClientSocketOpenSSL::OnHandshakeIOComplete, base::Unretained(this)), - net_log_); + &cert_verifier_request_, net_log_); } int SSLClientSocketOpenSSL::DoVerifyCertComplete(int result) { - verifier_.reset(); + cert_verifier_request_.reset(); if (!start_cert_verification_time_.is_null()) { base::TimeDelta verify_time = @@ -1095,8 +1150,15 @@ int SSLClientSocketOpenSSL::DoVerifyCertComplete(int result) { } } - if (result == OK) - RecordConnectionTypeMetrics(GetNetSSLVersion(ssl_)); + if (result == OK) { + if (SSL_session_reused(ssl_)) { + // Record whether or not the server tried to resume a session for a + // different version. See https://crbug.com/441456. + UMA_HISTOGRAM_BOOLEAN( + "Net.SSLSessionVersionMatch", + SSL_version(ssl_) == SSL_get_session(ssl_)->ssl_version); + } + } const CertStatus cert_status = server_cert_verify_result_.cert_status; if (transport_security_state_ && @@ -1110,46 +1172,26 @@ int SSLClientSocketOpenSSL::DoVerifyCertComplete(int result) { result = ERR_SSL_PINNED_KEY_NOT_IN_CERT_CHAIN; } - scoped_refptr<ct::EVCertsWhitelist> ev_whitelist = - SSLConfigService::GetEVCertsWhitelist(); - if (server_cert_verify_result_.cert_status & CERT_STATUS_IS_EV) { - if (ev_whitelist.get() && ev_whitelist->IsValid()) { - const SHA256HashValue fingerprint( - X509Certificate::CalculateFingerprint256( - server_cert_verify_result_.verified_cert->os_cert_handle())); - - UMA_HISTOGRAM_BOOLEAN( - "Net.SSL_EVCertificateInWhitelist", - ev_whitelist->ContainsCertificateHash( - std::string(reinterpret_cast<const char*>(fingerprint.data), 8))); - } - } - if (result == OK) { // Only check Certificate Transparency if there were no other errors with // the connection. VerifyCT(); - // TODO(joth): Work out if we need to remember the intermediate CA certs - // when the server sends them to us, and do so here. - SSLContext::GetInstance()->session_cache()->MarkSSLSessionAsGood(ssl_); - marked_session_as_good_ = true; - CheckIfHandshakeFinished(); + DCHECK(!certificate_verified_); + certificate_verified_ = true; + MaybeCacheSession(); } else { DVLOG(1) << "DoVerifyCertComplete error " << ErrorToString(result) << " (" << result << ")"; } completed_connect_ = true; - // Exit DoHandshakeLoop and return the result to the caller to Connect. DCHECK_EQ(STATE_NONE, next_handshake_state_); return result; } void SSLClientSocketOpenSSL::DoConnectCallback(int rv) { - if (rv < OK) - OnHandshakeCompletion(); if (!user_connect_callback_.is_null()) { CompletionCallback c = user_connect_callback_; user_connect_callback_.Reset(); @@ -1160,7 +1202,6 @@ void SSLClientSocketOpenSSL::DoConnectCallback(int rv) { void SSLClientSocketOpenSSL::UpdateServerCert() { server_cert_chain_->Reset(SSL_get_peer_cert_chain(ssl_)); server_cert_ = server_cert_chain_->AsOSChain(); - if (server_cert_.get()) { net_log_.AddEvent( NetLog::TYPE_SSL_CERTIFICATES_RECEIVED, @@ -1173,7 +1214,7 @@ void SSLClientSocketOpenSSL::VerifyCT() { if (!cert_transparency_verifier_) return; - uint8_t* ocsp_response_raw; + const uint8_t* ocsp_response_raw; size_t ocsp_response_len; SSL_get0_ocsp_response(ssl_, &ocsp_response_raw, &ocsp_response_len); std::string ocsp_response; @@ -1182,7 +1223,7 @@ void SSLClientSocketOpenSSL::VerifyCT() { ocsp_response_len); } - uint8_t* sct_list_raw; + const uint8_t* sct_list_raw; size_t sct_list_len; SSL_get0_signed_cert_timestamp_list(ssl_, &sct_list_raw, &sct_list_len); std::string sct_list; @@ -1192,15 +1233,28 @@ void SSLClientSocketOpenSSL::VerifyCT() { // Note that this is a completely synchronous operation: The CT Log Verifier // gets all the data it needs for SCT verification and does not do any // external communication. - int result = cert_transparency_verifier_->Verify( - server_cert_verify_result_.verified_cert.get(), - ocsp_response, sct_list, &ct_verify_result_, net_log_); + cert_transparency_verifier_->Verify( + server_cert_verify_result_.verified_cert.get(), ocsp_response, sct_list, + &ct_verify_result_, net_log_); - VLOG(1) << "CT Verification complete: result " << result - << " Invalid scts: " << ct_verify_result_.invalid_scts.size() - << " Verified scts: " << ct_verify_result_.verified_scts.size() - << " scts from unknown logs: " - << ct_verify_result_.unknown_logs_scts.size(); + if (!policy_enforcer_) { + server_cert_verify_result_.cert_status &= ~CERT_STATUS_IS_EV; + } else { + if (server_cert_verify_result_.cert_status & CERT_STATUS_IS_EV) { + scoped_refptr<ct::EVCertsWhitelist> ev_whitelist = + SSLConfigService::GetEVCertsWhitelist(); + if (!policy_enforcer_->DoesConformToCTEVPolicy( + server_cert_verify_result_.verified_cert.get(), + ev_whitelist.get(), ct_verify_result_, net_log_)) { + // TODO(eranm): Log via the BoundNetLog, see crbug.com/437766 + VLOG(1) << "EV certificate for " + << server_cert_verify_result_.verified_cert->subject() + .GetDisplayName() + << " does not conform to CT policy, removing EV status."; + server_cert_verify_result_.cert_status &= ~CERT_STATUS_IS_EV; + } + } + } } void SSLClientSocketOpenSSL::OnHandshakeIOComplete(int result) { @@ -1277,6 +1331,9 @@ int SSLClientSocketOpenSSL::DoHandshakeLoop(int last_io_result) { case STATE_HANDSHAKE: rv = DoHandshake(); break; + case STATE_HANDSHAKE_COMPLETE: + rv = DoHandshakeComplete(rv); + break; case STATE_CHANNEL_ID_LOOKUP: DCHECK_EQ(OK, rv); rv = DoChannelIDLookup(); @@ -1306,7 +1363,6 @@ int SSLClientSocketOpenSSL::DoHandshakeLoop(int last_io_result) { rv = OK; // This causes us to stay in the loop. } } while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE); - return rv; } @@ -1335,6 +1391,9 @@ int SSLClientSocketOpenSSL::DoWriteLoop() { int SSLClientSocketOpenSSL::DoPayloadRead() { crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); + DCHECK_LT(0, user_read_buf_len_); + DCHECK(user_read_buf_.get()); + int rv; if (pending_read_error_ != kNoPendingReadResult) { rv = pending_read_error_; @@ -1354,60 +1413,62 @@ int SSLClientSocketOpenSSL::DoPayloadRead() { } int total_bytes_read = 0; + int ssl_ret; do { - rv = SSL_read(ssl_, user_read_buf_->data() + total_bytes_read, - user_read_buf_len_ - total_bytes_read); - if (rv > 0) - total_bytes_read += rv; - } while (total_bytes_read < user_read_buf_len_ && rv > 0); - - if (total_bytes_read == user_read_buf_len_) { - rv = total_bytes_read; - } else { - // Otherwise, an error occurred (rv <= 0). The error needs to be handled - // immediately, while the OpenSSL errors are still available in - // thread-local storage. However, the handled/remapped error code should - // only be returned if no application data was already read; if it was, the - // error code should be deferred until the next call of DoPayloadRead. + ssl_ret = SSL_read(ssl_, user_read_buf_->data() + total_bytes_read, + user_read_buf_len_ - total_bytes_read); + if (ssl_ret > 0) + total_bytes_read += ssl_ret; + } while (total_bytes_read < user_read_buf_len_ && ssl_ret > 0); + + // Although only the final SSL_read call may have failed, the failure needs to + // processed immediately, while the information still available in OpenSSL's + // error queue. + if (ssl_ret <= 0) { + // A zero return from SSL_read may mean any of: + // - The underlying BIO_read returned 0. + // - The peer sent a close_notify. + // - Any arbitrary error. https://crbug.com/466303 // - // If no data was read, |*next_result| will point to the return value of - // this function. If at least some data was read, |*next_result| will point - // to |pending_read_error_|, to be returned in a future call to - // DoPayloadRead() (e.g.: after the current data is handled). - int *next_result = &rv; - if (total_bytes_read > 0) { - pending_read_error_ = rv; - rv = total_bytes_read; - next_result = &pending_read_error_; + // TransportReadComplete converts the first to an ERR_CONNECTION_CLOSED + // error, so it does not occur. The second and third are distinguished by + // SSL_ERROR_ZERO_RETURN. + pending_read_ssl_error_ = SSL_get_error(ssl_, ssl_ret); + if (pending_read_ssl_error_ == SSL_ERROR_ZERO_RETURN) { + pending_read_error_ = 0; + } else if (pending_read_ssl_error_ == SSL_ERROR_WANT_X509_LOOKUP && + !ssl_config_.send_client_cert) { + pending_read_error_ = ERR_SSL_CLIENT_AUTH_CERT_NEEDED; + } else { + pending_read_error_ = MapOpenSSLErrorWithDetails( + pending_read_ssl_error_, err_tracer, &pending_read_error_info_); } - if (client_auth_cert_needed_) { - *next_result = ERR_SSL_CLIENT_AUTH_CERT_NEEDED; - } else if (*next_result < 0) { - pending_read_ssl_error_ = SSL_get_error(ssl_, *next_result); - *next_result = MapOpenSSLErrorWithDetails(pending_read_ssl_error_, - err_tracer, - &pending_read_error_info_); - - // Many servers do not reliably send a close_notify alert when shutting - // down a connection, and instead terminate the TCP connection. This is - // reported as ERR_CONNECTION_CLOSED. Because of this, map the unclean - // shutdown to a graceful EOF, instead of treating it as an error as it - // should be. - if (*next_result == ERR_CONNECTION_CLOSED) - *next_result = 0; - - if (rv > 0 && *next_result == ERR_IO_PENDING) { - // If at least some data was read from SSL_read(), do not treat - // insufficient data as an error to return in the next call to - // DoPayloadRead() - instead, let the call fall through to check - // SSL_read() again. This is because DoTransportIO() may complete - // in between the next call to DoPayloadRead(), and thus it is - // important to check SSL_read() on subsequent invocations to see - // if a complete record may now be read. - *next_result = kNoPendingReadResult; - } - } + // Many servers do not reliably send a close_notify alert when shutting down + // a connection, and instead terminate the TCP connection. This is reported + // as ERR_CONNECTION_CLOSED. Because of this, map the unclean shutdown to a + // graceful EOF, instead of treating it as an error as it should be. + if (pending_read_error_ == ERR_CONNECTION_CLOSED) + pending_read_error_ = 0; + } + + if (total_bytes_read > 0) { + // Return any bytes read to the caller. The error will be deferred to the + // next call of DoPayloadRead. + rv = total_bytes_read; + + // Do not treat insufficient data as an error to return in the next call to + // DoPayloadRead() - instead, let the call fall through to check SSL_read() + // again. This is because DoTransportIO() may complete in between the next + // call to DoPayloadRead(), and thus it is important to check SSL_read() on + // subsequent invocations to see if a complete record may now be read. + if (pending_read_error_ == ERR_IO_PENDING) + pending_read_error_ = kNoPendingReadResult; + } else { + // No bytes were returned. Return the pending read error immediately. + DCHECK_NE(kNoPendingReadResult, pending_read_error_); + rv = pending_read_error_; + pending_read_error_ = kNoPendingReadResult; } if (rv >= 0) { @@ -1427,6 +1488,7 @@ int SSLClientSocketOpenSSL::DoPayloadRead() { int SSLClientSocketOpenSSL::DoPayloadWrite() { crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); int rv = SSL_write(ssl_, user_write_buf_->data(), user_write_buf_len_); + if (rv >= 0) { net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_SENT, rv, user_write_buf_->data()); @@ -1450,20 +1512,20 @@ int SSLClientSocketOpenSSL::BufferSend(void) { if (transport_send_busy_) return ERR_IO_PENDING; - if (!send_buffer_.get()) { - // Get a fresh send buffer out of the send BIO. - size_t max_read = BIO_pending(transport_bio_); - if (!max_read) - return 0; // Nothing pending in the OpenSSL write BIO. - send_buffer_ = new DrainableIOBuffer(new IOBuffer(max_read), max_read); - int read_bytes = BIO_read(transport_bio_, send_buffer_->data(), max_read); - DCHECK_GT(read_bytes, 0); - CHECK_EQ(static_cast<int>(max_read), read_bytes); - } + size_t buffer_read_offset; + uint8_t* read_buf; + size_t max_read; + int status = BIO_zero_copy_get_read_buf(transport_bio_, &read_buf, + &buffer_read_offset, &max_read); + DCHECK_EQ(status, 1); // Should never fail. + if (!max_read) + return 0; // Nothing pending in the OpenSSL write BIO. + CHECK_EQ(read_buf, reinterpret_cast<uint8_t*>(send_buffer_->StartOfBuffer())); + CHECK_LT(buffer_read_offset, static_cast<size_t>(send_buffer_->capacity())); + send_buffer_->set_offset(buffer_read_offset); int rv = transport_->socket()->Write( - send_buffer_.get(), - send_buffer_->BytesRemaining(), + send_buffer_.get(), max_read, base::Bind(&SSLClientSocketOpenSSL::BufferSendComplete, base::Unretained(this))); if (rv == ERR_IO_PENDING) { @@ -1496,11 +1558,21 @@ int SSLClientSocketOpenSSL::BufferRecv(void) { // fill |transport_bio_| is issued. As long as an SSL client socket cannot // be gracefully shutdown (via SSL close alerts) and re-used for non-SSL // traffic, this over-subscribed Read()ing will not cause issues. - size_t max_write = BIO_ctrl_get_write_guarantee(transport_bio_); + + size_t buffer_write_offset; + uint8_t* write_buf; + size_t max_write; + int status = BIO_zero_copy_get_write_buf(transport_bio_, &write_buf, + &buffer_write_offset, &max_write); + DCHECK_EQ(status, 1); // Should never fail. if (!max_write) return ERR_IO_PENDING; - recv_buffer_ = new IOBuffer(max_write); + CHECK_EQ(write_buf, + reinterpret_cast<uint8_t*>(recv_buffer_->StartOfBuffer())); + CHECK_LT(buffer_write_offset, static_cast<size_t>(recv_buffer_->capacity())); + + recv_buffer_->set_offset(buffer_write_offset); int rv = transport_->socket()->Read( recv_buffer_.get(), max_write, @@ -1515,7 +1587,6 @@ int SSLClientSocketOpenSSL::BufferRecv(void) { } void SSLClientSocketOpenSSL::BufferSendComplete(int result) { - transport_send_busy_ = false; TransportWriteComplete(result); OnSendComplete(result); } @@ -1527,18 +1598,18 @@ void SSLClientSocketOpenSSL::BufferRecvComplete(int result) { void SSLClientSocketOpenSSL::TransportWriteComplete(int result) { DCHECK(ERR_IO_PENDING != result); + int bytes_written = 0; if (result < 0) { // Record the error. Save it to be reported in a future read or write on // transport_bio_'s peer. transport_write_error_ = result; - send_buffer_ = NULL; } else { - DCHECK(send_buffer_.get()); - send_buffer_->DidConsume(result); - DCHECK_GE(send_buffer_->BytesRemaining(), 0); - if (send_buffer_->BytesRemaining() <= 0) - send_buffer_ = NULL; + bytes_written = result; } + DCHECK_GE(send_buffer_->RemainingCapacity(), bytes_written); + int ret = BIO_zero_copy_get_read_buf_done(transport_bio_, bytes_written); + DCHECK_EQ(1, ret); + transport_send_busy_ = false; } int SSLClientSocketOpenSSL::TransportReadComplete(int result) { @@ -1547,18 +1618,18 @@ int SSLClientSocketOpenSSL::TransportReadComplete(int result) { // does not report success. if (result == 0) result = ERR_CONNECTION_CLOSED; + int bytes_read = 0; if (result < 0) { DVLOG(1) << "TransportReadComplete result " << result; // Received an error. Save it to be reported in a future read on // transport_bio_'s peer. transport_read_error_ = result; } else { - DCHECK(recv_buffer_.get()); - int ret = BIO_write(transport_bio_, recv_buffer_->data(), result); - // A write into a memory BIO should always succeed. - DCHECK_EQ(result, ret); + bytes_read = result; } - recv_buffer_ = NULL; + DCHECK_GE(recv_buffer_->RemainingCapacity(), bytes_read); + int ret = BIO_zero_copy_get_write_buf_done(transport_bio_, bytes_read); + DCHECK_EQ(1, ret); transport_recv_busy_ = false; return result; } @@ -1567,6 +1638,8 @@ int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl) { DVLOG(3) << "OpenSSL ClientCertRequestCallback called"; DCHECK(ssl == ssl_); + net_log_.AddEvent(NetLog::TYPE_SSL_CLIENT_CERT_REQUESTED); + // Clear any currently configured certificates. SSL_certs_clear(ssl_); @@ -1577,7 +1650,6 @@ int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl) { if (!ssl_config_.send_client_cert) { // First pass: we know that a client certificate is needed, but we do not // have one at hand. - client_auth_cert_needed_ = true; STACK_OF(X509_NAME) *authorities = SSL_get_client_CA_list(ssl); for (size_t i = 0; i < sk_X509_NAME_num(authorities); i++) { X509_NAME *ca_name = (X509_NAME *)sk_X509_NAME_value(authorities, i); @@ -1597,7 +1669,8 @@ int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl) { static_cast<SSLClientCertType>(client_cert_types[i])); } - return -1; // Suspends handshake. + // Suspends handshake. SSL_get_error will return SSL_ERROR_WANT_X509_LOOKUP. + return -1; } // Second pass: a client certificate should have been selected. @@ -1644,11 +1717,17 @@ int SSLClientSocketOpenSSL::ClientCertRequestCallback(SSL* ssl) { LOG(WARNING) << "Failed to set client certificate"; return -1; } + + int cert_count = 1 + sk_X509_num(chain.get()); + net_log_.AddEvent(NetLog::TYPE_SSL_CLIENT_CERT_PROVIDED, + NetLog::IntegerCallback("cert_count", cert_count)); return 1; } #endif // defined(OS_IOS) // Send no client certificate. + net_log_.AddEvent(NetLog::TYPE_SSL_CLIENT_CERT_PROVIDED, + NetLog::IntegerCallback("cert_count", 0)); return 1; } @@ -1700,11 +1779,10 @@ int SSLClientSocketOpenSSL::SelectNextProtoCallback(unsigned char** out, // For each protocol in server preference order, see if we support it. for (unsigned int i = 0; i < inlen; i += in[i] + 1) { - for (std::vector<std::string>::const_iterator - j = ssl_config_.next_protos.begin(); - j != ssl_config_.next_protos.end(); ++j) { - if (in[i] == j->size() && - memcmp(&in[i + 1], j->data(), in[i]) == 0) { + for (NextProto next_proto : ssl_config_.next_protos) { + const std::string proto = NextProtoToString(next_proto); + if (in[i] == proto.size() && + memcmp(&in[i + 1], proto.data(), in[i]) == 0) { // We found a match. *out = const_cast<unsigned char*>(in) + i + 1; *outlen = in[i]; @@ -1718,9 +1796,10 @@ int SSLClientSocketOpenSSL::SelectNextProtoCallback(unsigned char** out, // If we didn't find a protocol, we select the first one from our list. if (npn_status_ == kNextProtoNoOverlap) { - *out = reinterpret_cast<uint8*>(const_cast<char*>( - ssl_config_.next_protos[0].data())); - *outlen = ssl_config_.next_protos[0].size(); + // NextProtoToString returns a pointer to a static string. + const char* proto = NextProtoToString(ssl_config_.next_protos[0]); + *out = reinterpret_cast<unsigned char*>(const_cast<char*>(proto)); + *outlen = strlen(proto); } npn_proto_.assign(reinterpret_cast<const char*>(*out), *outlen); @@ -1774,30 +1853,26 @@ long SSLClientSocketOpenSSL::BIOCallback( bio, cmd, argp, argi, argl, retvalue); } -// static -void SSLClientSocketOpenSSL::InfoCallback(const SSL* ssl, - int type, - int /*val*/) { - if (type == SSL_CB_HANDSHAKE_DONE) { - SSLClientSocketOpenSSL* ssl_socket = - SSLContext::GetInstance()->GetClientSocketFromSSL(ssl); - ssl_socket->handshake_succeeded_ = true; - ssl_socket->CheckIfHandshakeFinished(); - } -} - -// Determines if both the handshake and certificate verification have completed -// successfully, and calls the handshake completion callback if that is the -// case. -// -// CheckIfHandshakeFinished is called twice per connection: once after -// MarkSSLSessionAsGood, when the certificate has been verified, and -// once via an OpenSSL callback when the handshake has completed. On the -// second call, when the certificate has been verified and the handshake -// has completed, the connection's handshake completion callback is run. -void SSLClientSocketOpenSSL::CheckIfHandshakeFinished() { - if (handshake_succeeded_ && marked_session_as_good_) - OnHandshakeCompletion(); +void SSLClientSocketOpenSSL::MaybeCacheSession() { + // Only cache the session once both the handshake has completed and the + // certificate has been verified. + if (!handshake_completed_ || !certificate_verified_ || + SSL_session_reused(ssl_)) { + return; + } + + SSLContext::GetInstance()->session_cache()->Insert(GetSessionCacheKey(), + SSL_get_session(ssl_)); +} + +void SSLClientSocketOpenSSL::InfoCallback(int type, int val) { + // Note that SSL_CB_HANDSHAKE_DONE may be signaled multiple times if the + // socket renegotiates. + if (type != SSL_CB_HANDSHAKE_DONE || handshake_completed_) + return; + + handshake_completed_ = true; + MaybeCacheSession(); } void SSLClientSocketOpenSSL::AddSCTInfoToSSLInfo(SSLInfo* ssl_info) const { @@ -1822,6 +1897,47 @@ void SSLClientSocketOpenSSL::AddSCTInfoToSSLInfo(SSLInfo* ssl_info) const { } } +std::string SSLClientSocketOpenSSL::GetSessionCacheKey() const { + std::string result = host_and_port_.ToString(); + result.append("/"); + result.append(ssl_session_cache_shard_); + + // Shard the session cache based on maximum protocol version. This causes + // fallback connections to use a separate session cache. + result.append("/"); + switch (ssl_config_.version_max) { + case SSL_PROTOCOL_VERSION_TLS1: + result.append("tls1"); + break; + case SSL_PROTOCOL_VERSION_TLS1_1: + result.append("tls1.1"); + break; + case SSL_PROTOCOL_VERSION_TLS1_2: + result.append("tls1.2"); + break; + default: + NOTREACHED(); + } + + result.append("/"); + if (ssl_config_.enable_deprecated_cipher_suites) + result.append("deprecated"); + + return result; +} + +bool SSLClientSocketOpenSSL::IsRenegotiationAllowed() const { + if (npn_status_ == kNextProtoUnsupported) + return ssl_config_.renego_allowed_default; + + NextProto next_proto = NextProtoFromString(npn_proto_); + for (NextProto allowed : ssl_config_.renego_allowed_for_protos) { + if (next_proto == allowed) + return true; + } + return false; +} + scoped_refptr<X509Certificate> SSLClientSocketOpenSSL::GetUnverifiedServerCertificateChain() const { return server_cert_; diff --git a/chromium/net/socket/ssl_client_socket_openssl.h b/chromium/net/socket/ssl_client_socket_openssl.h index 53d33c4c8c7..452936a4bfb 100644 --- a/chromium/net/socket/ssl_client_socket_openssl.h +++ b/chromium/net/socket/ssl_client_socket_openssl.h @@ -12,6 +12,7 @@ #include "base/memory/weak_ptr.h" #include "net/base/completion_callback.h" #include "net/base/io_buffer.h" +#include "net/cert/cert_verifier.h" #include "net/cert/cert_verify_result.h" #include "net/cert/ct_verify_result.h" #include "net/socket/client_socket_handle.h" @@ -20,6 +21,7 @@ #include "net/ssl/openssl_ssl_util.h" #include "net/ssl/ssl_client_cert_type.h" #include "net/ssl/ssl_config_service.h" +#include "net/ssl/ssl_failure_state.h" // Avoid including misc OpenSSL headers, i.e.: // <openssl/bio.h> @@ -37,7 +39,6 @@ namespace net { class CertVerifier; class CTVerifier; -class SingleRequestCertVerifier; class SSLCertRequestInfo; class SSLInfo; @@ -60,12 +61,10 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { } // SSLClientSocket implementation. - std::string GetSessionCacheKey() const override; - bool InSessionCache() const override; - void SetHandshakeCompletionCallback(const base::Closure& callback) override; void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override; - NextProtoStatus GetNextProto(std::string* proto) override; + NextProtoStatus GetNextProto(std::string* proto) const override; ChannelIDService* GetChannelIDService() const override; + SSLFailureState GetSSLFailureState() const override; // SSLSocket implementation. int ExportKeyingMaterial(const base::StringPiece& label, @@ -88,6 +87,9 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { bool WasEverUsed() const override; bool UsingTCPFastOpen() const override; bool GetSSLInfo(SSLInfo* ssl_info) override; + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} // Socket implementation. int Read(IOBuffer* buf, @@ -114,10 +116,9 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { void DoReadCallback(int result); void DoWriteCallback(int result); - void OnHandshakeCompletion(); - bool DoTransportIO(); int DoHandshake(); + int DoHandshakeComplete(int result); int DoChannelIDLookup(); int DoChannelIDLookupComplete(int result); int DoVerifyCert(int result); @@ -171,11 +172,17 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { const char *argp, int argi, long argl, long retvalue); - // Callback that is used to obtain information about the state of the SSL - // handshake. - static void InfoCallback(const SSL* ssl, int type, int val); + // Called after the initial handshake completes and after the server + // certificate has been verified. The order of handshake completion and + // certificate verification depends on whether the connection was false + // started. After both have happened (thus calling this twice), the session is + // safe to cache and will be cached. + void MaybeCacheSession(); - void CheckIfHandshakeFinished(); + // Callback from the SSL layer when the internal state machine progresses. It + // is used to listen for when the handshake completes entirely; |Connect| may + // return early if false starting. + void InfoCallback(int type, int val); // Adds the SignedCertificateTimestamps from ct_verify_result_ to |ssl_info|. // SCTs are held in three separate vectors in ct_verify_result, each @@ -184,11 +191,20 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { // the |ssl_info|.signed_certificate_timestamps list. void AddSCTInfoToSSLInfo(SSLInfo* ssl_info) const; + // Returns a unique key string for the SSL session cache for + // this socket. + std::string GetSessionCacheKey() const; + + // Returns true if renegotiations are allowed. + bool IsRenegotiationAllowed() const; + bool transport_send_busy_; bool transport_recv_busy_; - scoped_refptr<DrainableIOBuffer> send_buffer_; - scoped_refptr<IOBuffer> recv_buffer_; + // Buffers which are shared by BoringSSL and SSLClientSocketOpenSSL. + // GrowableIOBuffer is used to keep ownership and setting offset. + scoped_refptr<GrowableIOBuffer> send_buffer_; + scoped_refptr<GrowableIOBuffer> recv_buffer_; CompletionCallback user_connect_callback_; CompletionCallback user_read_callback_; @@ -236,9 +252,6 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { // network. bool was_ever_used_; - // Stores client authentication information between ClientAuthHandler and - // GetSSLCertRequestInfo calls. - bool client_auth_cert_needed_; // List of DER-encoded X.509 DistinguishedName of certificate authorities // allowed by the server. std::vector<std::string> cert_authorities_; @@ -247,7 +260,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { std::vector<SSLClientCertType> cert_key_types_; CertVerifier* const cert_verifier_; - scoped_ptr<SingleRequestCertVerifier> verifier_; + scoped_ptr<CertVerifier::Request> cert_verifier_request_; base::TimeTicks start_cert_verification_time_; // Certificate Transparency: Verifier and result holder. @@ -257,12 +270,6 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { // The service for retrieving Channel ID keys. May be NULL. ChannelIDService* channel_id_service_; - // Callback that is invoked when the connection finishes. - // - // Note: this callback will be run in Disconnect(). It will not alter - // any member variables of the SSLClientSocketOpenSSL. - base::Closure handshake_completion_callback_; - // OpenSSL stuff SSL* ssl_; BIO* transport_bio_; @@ -275,12 +282,10 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { // resume on the socket with a different value. const std::string ssl_session_cache_shard_; - // Used for session cache diagnostics. - bool trying_cached_session_; - enum State { STATE_NONE, STATE_HANDSHAKE, + STATE_HANDSHAKE_COMPLETE, STATE_CHANNEL_ID_LOOKUP, STATE_CHANNEL_ID_LOOKUP_COMPLETE, STATE_VERIFY_CERT, @@ -292,18 +297,20 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { // Written by the |channel_id_service_|. std::string channel_id_private_key_; std::string channel_id_cert_; - // True if channel ID extension was negotiated. - bool channel_id_xtn_negotiated_; - // True if InfoCallback has been run with result = SSL_CB_HANDSHAKE_DONE. - bool handshake_succeeded_; - // True if MarkSSLSessionAsGood has been called for this socket's - // SSL session. - bool marked_session_as_good_; + // True if a channel ID was sent. + bool channel_id_sent_; + // True if the initial handshake has completed. + bool handshake_completed_; + // True if the initial handshake's certificate has been verified. + bool certificate_verified_; // The request handle for |channel_id_service_|. ChannelIDService::RequestHandle channel_id_request_handle_; + SSLFailureState ssl_failure_state_; TransportSecurityState* transport_security_state_; + CertPolicyEnforcer* const policy_enforcer_; + // pinning_failure_log contains a message produced by // TransportSecurityState::CheckPublicKeyPins in the event of a // pinning failure. It is a (somewhat) human-readable string. diff --git a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc index 8a6a8828810..0a08f8c7562 100644 --- a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc @@ -23,14 +23,13 @@ #include "net/base/address_list.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" -#include "net/base/net_log_unittest.h" #include "net/base/test_completion_callback.h" #include "net/base/test_data_directory.h" #include "net/cert/mock_cert_verifier.h" #include "net/cert/test_root_certs.h" #include "net/dns/host_resolver.h" #include "net/http/transport_security_state.h" +#include "net/log/net_log.h" #include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_handle.h" #include "net/socket/socket_test_util.h" @@ -50,8 +49,6 @@ namespace { // These client auth tests are currently dependent on OpenSSL's struct X509. #if defined(USE_OPENSSL_CERTS) -const SSLConfig kDefaultSSLConfig; - // Loads a PEM-encoded private key file into a scoped EVP_PKEY object. // |filepath| is the private key file path. // |*pkey| is reset to the new EVP_PKEY on success, untouched otherwise. @@ -94,9 +91,7 @@ class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest { key_store_ = OpenSSLClientKeyStore::GetInstance(); } - virtual ~SSLClientSocketOpenSSLClientAuthTest() { - key_store_->Flush(); - } + ~SSLClientSocketOpenSSLClientAuthTest() override { key_store_->Flush(); } protected: scoped_ptr<SSLClientSocket> CreateSSLClientSocket( @@ -156,7 +151,7 @@ class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest { // Returns true on succes, false otherwise. Success means that the socket // could be created and its Connect() was called, not that the connection // itself was a success. - bool CreateAndConnectSSLClientSocket(SSLConfig& ssl_config, + bool CreateAndConnectSSLClientSocket(const SSLConfig& ssl_config, int* result) { sock_ = CreateSSLClientSocket(transport_.Pass(), test_server_->host_port_pair(), @@ -188,7 +183,7 @@ class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest { scoped_ptr<SpawnedTestServer> test_server_; AddressList addr_; TestCompletionCallback callback_; - CapturingNetLog log_; + NetLog log_; scoped_ptr<StreamSocket> transport_; scoped_ptr<SSLClientSocket> sock_; }; @@ -202,10 +197,9 @@ TEST_F(SSLClientSocketOpenSSLClientAuthTest, NoCert) { ASSERT_TRUE(ConnectToTestServer(ssl_options)); base::FilePath certs_dir = GetTestCertsDirectory(); - SSLConfig ssl_config = kDefaultSSLConfig; int rv; - ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); + ASSERT_TRUE(CreateAndConnectSSLClientSocket(SSLConfig(), &rv)); EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, rv); EXPECT_FALSE(sock_->IsConnected()); @@ -222,7 +216,7 @@ TEST_F(SSLClientSocketOpenSSLClientAuthTest, SendEmptyCert) { ASSERT_TRUE(ConnectToTestServer(ssl_options)); base::FilePath certs_dir = GetTestCertsDirectory(); - SSLConfig ssl_config = kDefaultSSLConfig; + SSLConfig ssl_config; ssl_config.send_client_cert = true; ssl_config.client_cert = NULL; @@ -244,7 +238,7 @@ TEST_F(SSLClientSocketOpenSSLClientAuthTest, SendGoodCert) { ASSERT_TRUE(ConnectToTestServer(ssl_options)); base::FilePath certs_dir = GetTestCertsDirectory(); - SSLConfig ssl_config = kDefaultSSLConfig; + SSLConfig ssl_config; ssl_config.send_client_cert = true; ssl_config.client_cert = ImportCertFromFile(certs_dir, "client_1.pem"); diff --git a/chromium/net/socket/ssl_client_socket_pool.cc b/chromium/net/socket/ssl_client_socket_pool.cc index 56df1d85d09..f8db1d05041 100644 --- a/chromium/net/socket/ssl_client_socket_pool.cc +++ b/chromium/net/socket/ssl_client_socket_pool.cc @@ -9,7 +9,7 @@ #include "base/metrics/field_trial.h" #include "base/metrics/histogram.h" #include "base/metrics/sparse_histogram.h" -#include "base/stl_util.h" +#include "base/profiler/scoped_tracker.h" #include "base/values.h" #include "net/base/host_port_pair.h" #include "net/base/net_errors.h" @@ -34,7 +34,6 @@ SSLSocketParams::SSLSocketParams( const SSLConfig& ssl_config, PrivacyMode privacy_mode, int load_flags, - bool force_spdy_over_ssl, bool want_spdy_over_npn) : direct_params_(direct_params), socks_proxy_params_(socks_proxy_params), @@ -43,7 +42,6 @@ SSLSocketParams::SSLSocketParams( ssl_config_(ssl_config), privacy_mode_(privacy_mode), load_flags_(load_flags), - force_spdy_over_ssl_(force_spdy_over_ssl), want_spdy_over_npn_(want_spdy_over_npn), ignore_limits_(false) { if (direct_params_.get()) { @@ -95,77 +93,6 @@ SSLSocketParams::GetHttpProxyConnectionParams() const { return http_proxy_params_; } -SSLConnectJobMessenger::SocketAndCallback::SocketAndCallback( - SSLClientSocket* ssl_socket, - const base::Closure& job_resumption_callback) - : socket(ssl_socket), callback(job_resumption_callback) { -} - -SSLConnectJobMessenger::SocketAndCallback::~SocketAndCallback() { -} - -SSLConnectJobMessenger::SSLConnectJobMessenger( - const base::Closure& messenger_finished_callback) - : messenger_finished_callback_(messenger_finished_callback), - weak_factory_(this) { -} - -SSLConnectJobMessenger::~SSLConnectJobMessenger() { -} - -void SSLConnectJobMessenger::RemovePendingSocket(SSLClientSocket* ssl_socket) { - // Sockets do not need to be removed from connecting_sockets_ because - // OnSSLHandshakeCompleted will do this. - for (SSLPendingSocketsAndCallbacks::iterator it = - pending_sockets_and_callbacks_.begin(); - it != pending_sockets_and_callbacks_.end(); - ++it) { - if (it->socket == ssl_socket) { - pending_sockets_and_callbacks_.erase(it); - return; - } - } -} - -bool SSLConnectJobMessenger::CanProceed(SSLClientSocket* ssl_socket) { - // If there are no connecting sockets, allow the connection to proceed. - return connecting_sockets_.empty(); -} - -void SSLConnectJobMessenger::MonitorConnectionResult( - SSLClientSocket* ssl_socket) { - connecting_sockets_.push_back(ssl_socket); - ssl_socket->SetHandshakeCompletionCallback( - base::Bind(&SSLConnectJobMessenger::OnSSLHandshakeCompleted, - weak_factory_.GetWeakPtr())); -} - -void SSLConnectJobMessenger::AddPendingSocket(SSLClientSocket* ssl_socket, - const base::Closure& callback) { - DCHECK(!connecting_sockets_.empty()); - pending_sockets_and_callbacks_.push_back( - SocketAndCallback(ssl_socket, callback)); -} - -void SSLConnectJobMessenger::OnSSLHandshakeCompleted() { - connecting_sockets_.clear(); - SSLPendingSocketsAndCallbacks temp_list; - temp_list.swap(pending_sockets_and_callbacks_); - base::Closure messenger_finished_callback = messenger_finished_callback_; - messenger_finished_callback.Run(); - RunAllCallbacks(temp_list); -} - -void SSLConnectJobMessenger::RunAllCallbacks( - const SSLPendingSocketsAndCallbacks& pending_sockets_and_callbacks) { - for (std::vector<SocketAndCallback>::const_iterator it = - pending_sockets_and_callbacks.begin(); - it != pending_sockets_and_callbacks.end(); - ++it) { - it->callback.Run(); - } -} - // Timeout for the SSL handshake portion of the connect. static const int kSSLHandshakeTimeoutInSeconds = 30; @@ -177,9 +104,7 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name, SOCKSClientSocketPool* socks_pool, HttpProxyClientSocketPool* http_proxy_pool, ClientSocketFactory* client_socket_factory, - HostResolver* host_resolver, const SSLClientSocketContext& context, - const GetMessengerCallback& get_messenger_callback, Delegate* delegate, NetLog* net_log) : ConnectJob(group_name, @@ -192,24 +117,19 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name, socks_pool_(socks_pool), http_proxy_pool_(http_proxy_pool), client_socket_factory_(client_socket_factory), - host_resolver_(host_resolver), context_(context.cert_verifier, context.channel_id_service, context.transport_security_state, context.cert_transparency_verifier, + context.cert_policy_enforcer, (params->privacy_mode() == PRIVACY_MODE_ENABLED ? "pm/" + context.ssl_session_cache_shard : context.ssl_session_cache_shard)), - io_callback_( - base::Bind(&SSLConnectJob::OnIOComplete, base::Unretained(this))), - messenger_(NULL), - get_messenger_callback_(get_messenger_callback), - weak_factory_(this) { + callback_( + base::Bind(&SSLConnectJob::OnIOComplete, base::Unretained(this))) { } SSLConnectJob::~SSLConnectJob() { - if (ssl_socket_.get() && messenger_) - messenger_->RemovePendingSocket(ssl_socket_.get()); } LoadState SSLConnectJob::GetLoadState() const { @@ -224,8 +144,6 @@ LoadState SSLConnectJob::GetLoadState() const { case STATE_SOCKS_CONNECT_COMPLETE: case STATE_TUNNEL_CONNECT: return transport_socket_handle_->GetLoadState(); - case STATE_CREATE_SSL_SOCKET: - case STATE_CHECK_FOR_RESUME: case STATE_SSL_CONNECT: case STATE_SSL_CONNECT_COMPLETE: return LOAD_STATE_SSL_HANDSHAKE; @@ -245,6 +163,10 @@ void SSLConnectJob::GetAdditionalErrorState(ClientSocketHandle* handle) { handle->set_ssl_error_response_info(error_response_info_); if (!connect_timing_.ssl_start.is_null()) handle->set_is_ssl_error(true); + if (ssl_socket_) + handle->set_ssl_failure_state(ssl_socket_->GetSSLFailureState()); + + handle->set_connection_attempts(connection_attempts_); } void SSLConnectJob::OnIOComplete(int result) { @@ -282,12 +204,6 @@ int SSLConnectJob::DoLoop(int result) { case STATE_TUNNEL_CONNECT_COMPLETE: rv = DoTunnelConnectComplete(rv); break; - case STATE_CREATE_SSL_SOCKET: - rv = DoCreateSSLSocket(); - break; - case STATE_CHECK_FOR_RESUME: - rv = DoCheckForResume(); - break; case STATE_SSL_CONNECT: DCHECK_EQ(OK, rv); rv = DoSSLConnect(); @@ -312,17 +228,16 @@ int SSLConnectJob::DoTransportConnect() { transport_socket_handle_.reset(new ClientSocketHandle()); scoped_refptr<TransportSocketParams> direct_params = params_->GetDirectConnectionParams(); - return transport_socket_handle_->Init(group_name(), - direct_params, - priority(), - io_callback_, - transport_pool_, - net_log()); + return transport_socket_handle_->Init(group_name(), direct_params, priority(), + callback_, transport_pool_, net_log()); } int SSLConnectJob::DoTransportConnectComplete(int result) { - if (result == OK) - next_state_ = STATE_CREATE_SSL_SOCKET; + connection_attempts_ = transport_socket_handle_->connection_attempts(); + if (result == OK) { + next_state_ = STATE_SSL_CONNECT; + transport_socket_handle_->socket()->GetPeerAddress(&server_address_); + } return result; } @@ -333,17 +248,14 @@ int SSLConnectJob::DoSOCKSConnect() { transport_socket_handle_.reset(new ClientSocketHandle()); scoped_refptr<SOCKSSocketParams> socks_proxy_params = params_->GetSocksProxyConnectionParams(); - return transport_socket_handle_->Init(group_name(), - socks_proxy_params, - priority(), - io_callback_, - socks_pool_, + return transport_socket_handle_->Init(group_name(), socks_proxy_params, + priority(), callback_, socks_pool_, net_log()); } int SSLConnectJob::DoSOCKSConnectComplete(int result) { if (result == OK) - next_state_ = STATE_CREATE_SSL_SOCKET; + next_state_ = STATE_SSL_CONNECT; return result; } @@ -355,11 +267,8 @@ int SSLConnectJob::DoTunnelConnect() { transport_socket_handle_.reset(new ClientSocketHandle()); scoped_refptr<HttpProxySocketParams> http_proxy_params = params_->GetHttpProxyConnectionParams(); - return transport_socket_handle_->Init(group_name(), - http_proxy_params, - priority(), - io_callback_, - http_proxy_pool_, + return transport_socket_handle_->Init(group_name(), http_proxy_params, + priority(), callback_, http_proxy_pool_, net_log()); } @@ -372,18 +281,22 @@ int SSLConnectJob::DoTunnelConnectComplete(int result) { } else if (result == ERR_PROXY_AUTH_REQUESTED || result == ERR_HTTPS_PROXY_TUNNEL_RESPONSE) { StreamSocket* socket = transport_socket_handle_->socket(); - HttpProxyClientSocket* tunnel_socket = - static_cast<HttpProxyClientSocket*>(socket); + ProxyClientSocket* tunnel_socket = static_cast<ProxyClientSocket*>(socket); error_response_info_ = *tunnel_socket->GetConnectResponseInfo(); } if (result < 0) return result; - next_state_ = STATE_CREATE_SSL_SOCKET; + + next_state_ = STATE_SSL_CONNECT; return result; } -int SSLConnectJob::DoCreateSSLSocket() { - next_state_ = STATE_CHECK_FOR_RESUME; +int SSLConnectJob::DoSSLConnect() { + // TODO(pkasting): Remove ScopedTracker below once crbug.com/462815 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION("462815 SSLConnectJob::DoSSLConnect")); + + next_state_ = STATE_SSL_CONNECT_COMPLETE; // Reset the timeout to just the time allowed for the SSL handshake. ResetTimer(base::TimeDelta::FromSeconds(kSSLHandshakeTimeoutInSeconds)); @@ -402,88 +315,41 @@ int SSLConnectJob::DoCreateSSLSocket() { connect_timing_.dns_end = socket_connect_timing.dns_end; } + connect_timing_.ssl_start = base::TimeTicks::Now(); + ssl_socket_ = client_socket_factory_->CreateSSLClientSocket( transport_socket_handle_.Pass(), params_->host_and_port(), params_->ssl_config(), context_); - - if (!ssl_socket_->InSessionCache()) - messenger_ = get_messenger_callback_.Run(ssl_socket_->GetSessionCacheKey()); - - return OK; -} - -int SSLConnectJob::DoCheckForResume() { - next_state_ = STATE_SSL_CONNECT; - - if (!messenger_) - return OK; - - if (messenger_->CanProceed(ssl_socket_.get())) { - messenger_->MonitorConnectionResult(ssl_socket_.get()); - // The SSLConnectJob no longer needs access to the messenger after this - // point. - messenger_ = NULL; - return OK; - } - - messenger_->AddPendingSocket(ssl_socket_.get(), - base::Bind(&SSLConnectJob::ResumeSSLConnection, - weak_factory_.GetWeakPtr())); - - return ERR_IO_PENDING; -} - -int SSLConnectJob::DoSSLConnect() { - next_state_ = STATE_SSL_CONNECT_COMPLETE; - - connect_timing_.ssl_start = base::TimeTicks::Now(); - - return ssl_socket_->Connect(io_callback_); + return ssl_socket_->Connect(callback_); } int SSLConnectJob::DoSSLConnectComplete(int result) { + // TODO(rvargas): Remove ScopedTracker below once crbug.com/462784 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "462784 SSLConnectJob::DoSSLConnectComplete")); + connect_timing_.ssl_end = base::TimeTicks::Now(); - SSLClientSocket::NextProtoStatus status = - SSLClientSocket::kNextProtoUnsupported; - std::string proto; - // GetNextProto will fail and and trigger a NOTREACHED if we pass in a socket - // that hasn't had SSL_ImportFD called on it. If we get a certificate error - // here, then we know that we called SSL_ImportFD. - if (result == OK || IsCertificateError(result)) { - status = ssl_socket_->GetNextProto(&proto); - ssl_socket_->RecordNegotiationExtension(); + if (result != OK && !server_address_.address().empty()) { + connection_attempts_.push_back(ConnectionAttempt(server_address_, result)); + server_address_ = IPEndPoint(); } - // If we want spdy over npn, make sure it succeeded. - if (status == SSLClientSocket::kNextProtoNegotiated) { - ssl_socket_->set_was_npn_negotiated(true); - NextProto protocol_negotiated = - SSLClientSocket::NextProtoFromString(proto); - ssl_socket_->set_protocol_negotiated(protocol_negotiated); - // If we negotiated a SPDY version, it must have been present in - // SSLConfig::next_protos. - // TODO(mbelshe): Verify this. - if (protocol_negotiated >= kProtoSPDYMinimumVersion && - protocol_negotiated <= kProtoSPDYMaximumVersion) { - ssl_socket_->set_was_spdy_negotiated(true); - } - } - if (params_->want_spdy_over_npn() && !ssl_socket_->was_spdy_negotiated()) + // If we want SPDY over ALPN/NPN, make sure it succeeded. + if (params_->want_spdy_over_npn() && + !NextProtoIsSPDY(ssl_socket_->GetNegotiatedProtocol())) { return ERR_NPN_NEGOTIATION_FAILED; - - // Spdy might be turned on by default, or it might be over npn. - bool using_spdy = params_->force_spdy_over_ssl() || - params_->want_spdy_over_npn(); + } if (result == OK || ssl_socket_->IgnoreCertError(result, params_->load_flags())) { DCHECK(!connect_timing_.ssl_start.is_null()); base::TimeDelta connect_duration = connect_timing_.ssl_end - connect_timing_.ssl_start; - if (using_spdy) { + if (params_->want_spdy_over_npn()) { UMA_HISTOGRAM_CUSTOM_TIMES("Net.SpdyConnectionLatency_2", connect_duration, base::TimeDelta::FromMilliseconds(1), @@ -498,17 +364,17 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { 100); SSLInfo ssl_info; - ssl_socket_->GetSSLInfo(&ssl_info); + bool has_ssl_info = ssl_socket_->GetSSLInfo(&ssl_info); + DCHECK(has_ssl_info); + + UMA_HISTOGRAM_ENUMERATION("Net.SSLVersion", SSLConnectionStatusToVersion( + ssl_info.connection_status), + SSL_CONNECTION_VERSION_MAX); UMA_HISTOGRAM_SPARSE_SLOWLY("Net.SSL_CipherSuite", SSLConnectionStatusToCipherSuite( ssl_info.connection_status)); - UMA_HISTOGRAM_BOOLEAN( - "Net.RenegotiationExtensionSupported", - (ssl_info.connection_status & - SSL_CONNECTION_NO_RENEGOTIATION_EXTENSION) == 0); - if (ssl_info.handshake_type == SSLInfo::HANDSHAKE_RESUME) { UMA_HISTOGRAM_CUSTOM_TIMES("Net.SSL_Connection_Latency_Resume_Handshake", connect_duration, @@ -551,6 +417,12 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { } } + UMA_HISTOGRAM_SPARSE_SLOWLY("Net.SSL_Connection_Error", std::abs(result)); + if (params_->ssl_config().fastradio_padding_eligible) { + UMA_HISTOGRAM_SPARSE_SLOWLY("Net.SSL_Connection_Error_FastRadioPadding", + std::abs(result)); + } + if (result == OK || IsCertificateError(result)) { SetSocket(ssl_socket_.Pass()); } else if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) { @@ -562,12 +434,6 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { return result; } -void SSLConnectJob::ResumeSSLConnection() { - DCHECK_EQ(next_state_, STATE_SSL_CONNECT); - messenger_ = NULL; - OnIOComplete(OK); -} - SSLConnectJob::State SSLConnectJob::GetInitialState( SSLSocketParams::ConnectionType connection_type) { switch (connection_type) { @@ -592,17 +458,13 @@ SSLClientSocketPool::SSLConnectJobFactory::SSLConnectJobFactory( SOCKSClientSocketPool* socks_pool, HttpProxyClientSocketPool* http_proxy_pool, ClientSocketFactory* client_socket_factory, - HostResolver* host_resolver, const SSLClientSocketContext& context, - const SSLConnectJob::GetMessengerCallback& get_messenger_callback, NetLog* net_log) : transport_pool_(transport_pool), socks_pool_(socks_pool), http_proxy_pool_(http_proxy_pool), client_socket_factory_(client_socket_factory), - host_resolver_(host_resolver), context_(context), - get_messenger_callback_(get_messenger_callback), net_log_(net_log) { base::TimeDelta max_transport_timeout = base::TimeDelta(); base::TimeDelta pool_timeout; @@ -628,19 +490,17 @@ SSLClientSocketPool::SSLConnectJobFactory::~SSLConnectJobFactory() { SSLClientSocketPool::SSLClientSocketPool( int max_sockets, int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, - HostResolver* host_resolver, CertVerifier* cert_verifier, ChannelIDService* channel_id_service, TransportSecurityState* transport_security_state, CTVerifier* cert_transparency_verifier, + CertPolicyEnforcer* cert_policy_enforcer, const std::string& ssl_session_cache_shard, ClientSocketFactory* client_socket_factory, TransportClientSocketPool* transport_pool, SOCKSClientSocketPool* socks_pool, HttpProxyClientSocketPool* http_proxy_pool, SSLConfigService* ssl_config_service, - bool enable_ssl_connect_job_waiting, NetLog* net_log) : transport_pool_(transport_pool), socks_pool_(socks_pool), @@ -648,7 +508,6 @@ SSLClientSocketPool::SSLClientSocketPool( base_(this, max_sockets, max_sockets_per_group, - histograms, ClientSocketPool::unused_idle_socket_timeout(), ClientSocketPool::used_idle_socket_timeout(), new SSLConnectJobFactory( @@ -656,18 +515,14 @@ SSLClientSocketPool::SSLClientSocketPool( socks_pool, http_proxy_pool, client_socket_factory, - host_resolver, SSLClientSocketContext(cert_verifier, channel_id_service, transport_security_state, cert_transparency_verifier, + cert_policy_enforcer, ssl_session_cache_shard), - base::Bind( - &SSLClientSocketPool::GetOrCreateSSLConnectJobMessenger, - base::Unretained(this)), net_log)), - ssl_config_service_(ssl_config_service), - enable_ssl_connect_job_waiting_(enable_ssl_connect_job_waiting) { + ssl_config_service_(ssl_config_service) { if (ssl_config_service_.get()) ssl_config_service_->AddObserver(this); if (transport_pool_) @@ -679,8 +534,6 @@ SSLClientSocketPool::SSLClientSocketPool( } SSLClientSocketPool::~SSLClientSocketPool() { - STLDeleteContainerPairSecondPointers(messenger_map_.begin(), - messenger_map_.end()); if (ssl_config_service_.get()) ssl_config_service_->RemoveObserver(this); } @@ -697,9 +550,7 @@ scoped_ptr<ConnectJob> SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( socks_pool_, http_proxy_pool_, client_socket_factory_, - host_resolver_, context_, - get_messenger_callback_, delegate, net_log_)); } @@ -797,10 +648,6 @@ base::TimeDelta SSLClientSocketPool::ConnectionTimeout() const { return base_.ConnectionTimeout(); } -ClientSocketPoolHistograms* SSLClientSocketPool::histograms() const { - return base_.histograms(); -} - bool SSLClientSocketPool::IsStalled() const { return base_.IsStalled(); } @@ -820,32 +667,6 @@ bool SSLClientSocketPool::CloseOneIdleConnection() { return base_.CloseOneIdleConnectionInHigherLayeredPool(); } -SSLConnectJobMessenger* SSLClientSocketPool::GetOrCreateSSLConnectJobMessenger( - const std::string& cache_key) { - if (!enable_ssl_connect_job_waiting_) - return NULL; - MessengerMap::const_iterator it = messenger_map_.find(cache_key); - if (it == messenger_map_.end()) { - std::pair<MessengerMap::iterator, bool> iter = - messenger_map_.insert(MessengerMap::value_type( - cache_key, - new SSLConnectJobMessenger( - base::Bind(&SSLClientSocketPool::DeleteSSLConnectJobMessenger, - base::Unretained(this), - cache_key)))); - it = iter.first; - } - return it->second; -} - -void SSLClientSocketPool::DeleteSSLConnectJobMessenger( - const std::string& cache_key) { - MessengerMap::iterator it = messenger_map_.find(cache_key); - CHECK(it != messenger_map_.end()); - delete it->second; - messenger_map_.erase(it); -} - void SSLClientSocketPool::OnSSLConfigChanged() { FlushWithError(ERR_NETWORK_CHANGED); } diff --git a/chromium/net/socket/ssl_client_socket_pool.h b/chromium/net/socket/ssl_client_socket_pool.h index c7f613e5149..3069c8ddb03 100644 --- a/chromium/net/socket/ssl_client_socket_pool.h +++ b/chromium/net/socket/ssl_client_socket_pool.h @@ -5,24 +5,22 @@ #ifndef NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_ #define NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_ -#include <map> #include <string> -#include <vector> #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/time/time.h" #include "net/base/privacy_mode.h" -#include "net/dns/host_resolver.h" #include "net/http/http_response_info.h" #include "net/socket/client_socket_pool.h" #include "net/socket/client_socket_pool_base.h" -#include "net/socket/client_socket_pool_histograms.h" +#include "net/socket/connection_attempts.h" #include "net/socket/ssl_client_socket.h" #include "net/ssl/ssl_config_service.h" namespace net { +class CertPolicyEnforcer; class CertVerifier; class ClientSocketFactory; class ConnectJobFactory; @@ -52,7 +50,6 @@ class NET_EXPORT_PRIVATE SSLSocketParams const SSLConfig& ssl_config, PrivacyMode privacy_mode, int load_flags, - bool force_spdy_over_ssl, bool want_spdy_over_npn); // Returns the type of the underlying connection. @@ -74,7 +71,6 @@ class NET_EXPORT_PRIVATE SSLSocketParams const SSLConfig& ssl_config() const { return ssl_config_; } PrivacyMode privacy_mode() const { return privacy_mode_; } int load_flags() const { return load_flags_; } - bool force_spdy_over_ssl() const { return force_spdy_over_ssl_; } bool want_spdy_over_npn() const { return want_spdy_over_npn_; } bool ignore_limits() const { return ignore_limits_; } @@ -89,96 +85,16 @@ class NET_EXPORT_PRIVATE SSLSocketParams const SSLConfig ssl_config_; const PrivacyMode privacy_mode_; const int load_flags_; - const bool force_spdy_over_ssl_; const bool want_spdy_over_npn_; bool ignore_limits_; DISALLOW_COPY_AND_ASSIGN(SSLSocketParams); }; -// SSLConnectJobMessenger handles communication between concurrent -// SSLConnectJobs that share the same SSL session cache key. -// -// SSLConnectJobMessengers tell the session cache when a certain -// connection should be monitored for success or failure, and -// tell SSLConnectJobs when to pause or resume their connections. -class SSLConnectJobMessenger { - public: - struct SocketAndCallback { - SocketAndCallback(SSLClientSocket* ssl_socket, - const base::Closure& job_resumption_callback); - ~SocketAndCallback(); - - SSLClientSocket* socket; - base::Closure callback; - }; - - typedef std::vector<SocketAndCallback> SSLPendingSocketsAndCallbacks; - - // |messenger_finished_callback| is run when a connection monitored by the - // SSLConnectJobMessenger has completed and we are finished with the - // SSLConnectJobMessenger. - explicit SSLConnectJobMessenger( - const base::Closure& messenger_finished_callback); - ~SSLConnectJobMessenger(); - - // Removes |socket| from the set of sockets being monitored. This - // guarantees that |job_resumption_callback| will not be called for - // the socket. - void RemovePendingSocket(SSLClientSocket* ssl_socket); - - // Returns true if |ssl_socket|'s Connect() method should be called. - bool CanProceed(SSLClientSocket* ssl_socket); - - // Configures the SSLConnectJobMessenger to begin monitoring |ssl_socket|'s - // connection status. After a successful connection, or an error, - // the messenger will determine which sockets that have been added - // via AddPendingSocket() to allow to proceed. - void MonitorConnectionResult(SSLClientSocket* ssl_socket); - - // Adds |socket| to the list of sockets waiting to Connect(). When - // the messenger has determined that it's an appropriate time for |socket| - // to connect, it will invoke |callback|. - // - // Note: It is an error to call AddPendingSocket() without having first - // called MonitorConnectionResult() and configuring a socket that WILL - // have Connect() called on it. - void AddPendingSocket(SSLClientSocket* ssl_socket, - const base::Closure& callback); - - private: - // Processes pending callbacks when a socket completes its SSL handshake -- - // either successfully or unsuccessfully. - void OnSSLHandshakeCompleted(); - - // Runs all callbacks stored in |pending_sockets_and_callbacks_|. - void RunAllCallbacks( - const SSLPendingSocketsAndCallbacks& pending_socket_and_callbacks); - - SSLPendingSocketsAndCallbacks pending_sockets_and_callbacks_; - // Note: this field is a vector to allow for future design changes. Currently, - // this vector should only ever have one entry. - std::vector<SSLClientSocket*> connecting_sockets_; - - base::Closure messenger_finished_callback_; - - base::WeakPtrFactory<SSLConnectJobMessenger> weak_factory_; -}; - // SSLConnectJob handles the SSL handshake after setting up the underlying // connection as specified in the params. class SSLConnectJob : public ConnectJob { public: - // Callback to allow the SSLConnectJob to obtain an SSLConnectJobMessenger to - // coordinate connecting. The SSLConnectJob will supply a unique identifer - // (ex: the SSL session cache key), with the expectation that the same - // Messenger will be returned for all such ConnectJobs. - // - // Note: It will only be called for situations where the SSL session cache - // does not already have a candidate session to resume. - typedef base::Callback<SSLConnectJobMessenger*(const std::string&)> - GetMessengerCallback; - // Note: the SSLConnectJob does not own |messenger| so it must outlive the // job. SSLConnectJob(const std::string& group_name, @@ -189,9 +105,7 @@ class SSLConnectJob : public ConnectJob { SOCKSClientSocketPool* socks_pool, HttpProxyClientSocketPool* http_proxy_pool, ClientSocketFactory* client_socket_factory, - HostResolver* host_resolver, const SSLClientSocketContext& context, - const GetMessengerCallback& get_messenger_callback, Delegate* delegate, NetLog* net_log); ~SSLConnectJob() override; @@ -209,8 +123,6 @@ class SSLConnectJob : public ConnectJob { STATE_SOCKS_CONNECT_COMPLETE, STATE_TUNNEL_CONNECT, STATE_TUNNEL_CONNECT_COMPLETE, - STATE_CREATE_SSL_SOCKET, - STATE_CHECK_FOR_RESUME, STATE_SSL_CONNECT, STATE_SSL_CONNECT_COMPLETE, STATE_NONE, @@ -227,14 +139,9 @@ class SSLConnectJob : public ConnectJob { int DoSOCKSConnectComplete(int result); int DoTunnelConnect(); int DoTunnelConnectComplete(int result); - int DoCreateSSLSocket(); - int DoCheckForResume(); int DoSSLConnect(); int DoSSLConnectComplete(int result); - // Tells a waiting SSLConnectJob to resume its SSL connection. - void ResumeSSLConnection(); - // Returns the initial state for the state machine based on the // |connection_type|. static State GetInitialState(SSLSocketParams::ConnectionType connection_type); @@ -249,21 +156,21 @@ class SSLConnectJob : public ConnectJob { SOCKSClientSocketPool* const socks_pool_; HttpProxyClientSocketPool* const http_proxy_pool_; ClientSocketFactory* const client_socket_factory_; - HostResolver* const host_resolver_; const SSLClientSocketContext context_; State next_state_; - CompletionCallback io_callback_; + CompletionCallback callback_; scoped_ptr<ClientSocketHandle> transport_socket_handle_; scoped_ptr<SSLClientSocket> ssl_socket_; - SSLConnectJobMessenger* messenger_; HttpResponseInfo error_response_info_; - GetMessengerCallback get_messenger_callback_; - - base::WeakPtrFactory<SSLConnectJob> weak_factory_; + ConnectionAttempts connection_attempts_; + // The address of the server the connect job is connected to. Populated if + // and only if the connect job is connected *directly* to the server (not + // through an HTTPS CONNECT request or a SOCKS proxy). + IPEndPoint server_address_; DISALLOW_COPY_AND_ASSIGN(SSLConnectJob); }; @@ -279,19 +186,17 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool // try to create an SSL over SOCKS socket, |socks_pool| may be NULL. SSLClientSocketPool(int max_sockets, int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, - HostResolver* host_resolver, CertVerifier* cert_verifier, ChannelIDService* channel_id_service, TransportSecurityState* transport_security_state, CTVerifier* cert_transparency_verifier, + CertPolicyEnforcer* cert_policy_enforcer, const std::string& ssl_session_cache_shard, ClientSocketFactory* client_socket_factory, TransportClientSocketPool* transport_pool, SOCKSClientSocketPool* socks_pool, HttpProxyClientSocketPool* http_proxy_pool, SSLConfigService* ssl_config_service, - bool enable_ssl_connect_job_waiting, NetLog* net_log); ~SSLClientSocketPool() override; @@ -334,8 +239,6 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool base::TimeDelta ConnectionTimeout() const override; - ClientSocketPoolHistograms* histograms() const override; - // LowerLayeredPool implementation. bool IsStalled() const override; @@ -346,16 +249,8 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool // HigherLayeredPool implementation. bool CloseOneIdleConnection() override; - // Gets the SSLConnectJobMessenger for the given ssl session |cache_key|. If - // none exits, it creates one and stores it in |messenger_map_|. - SSLConnectJobMessenger* GetOrCreateSSLConnectJobMessenger( - const std::string& cache_key); - void DeleteSSLConnectJobMessenger(const std::string& cache_key); - private: typedef ClientSocketPoolBase<SSLSocketParams> PoolBase; - // Maps SSLConnectJob cache keys to SSLConnectJobMessenger objects. - typedef std::map<std::string, SSLConnectJobMessenger*> MessengerMap; // SSLConfigService::Observer implementation. @@ -370,9 +265,7 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool SOCKSClientSocketPool* socks_pool, HttpProxyClientSocketPool* http_proxy_pool, ClientSocketFactory* client_socket_factory, - HostResolver* host_resolver, const SSLClientSocketContext& context, - const SSLConnectJob::GetMessengerCallback& get_messenger_callback, NetLog* net_log); ~SSLConnectJobFactory() override; @@ -390,10 +283,8 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool SOCKSClientSocketPool* const socks_pool_; HttpProxyClientSocketPool* const http_proxy_pool_; ClientSocketFactory* const client_socket_factory_; - HostResolver* const host_resolver_; const SSLClientSocketContext context_; base::TimeDelta timeout_; - SSLConnectJob::GetMessengerCallback get_messenger_callback_; NetLog* net_log_; DISALLOW_COPY_AND_ASSIGN(SSLConnectJobFactory); @@ -404,8 +295,6 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool HttpProxyClientSocketPool* const http_proxy_pool_; PoolBase base_; const scoped_refptr<SSLConfigService> ssl_config_service_; - MessengerMap messenger_map_; - bool enable_ssl_connect_job_waiting_; DISALLOW_COPY_AND_ASSIGN(SSLClientSocketPool); }; diff --git a/chromium/net/socket/ssl_client_socket_pool_unittest.cc b/chromium/net/socket/ssl_client_socket_pool_unittest.cc index 202cd8809e5..a9d34472b69 100644 --- a/chromium/net/socket/ssl_client_socket_pool_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_pool_unittest.cc @@ -6,7 +6,6 @@ #include "base/callback.h" #include "base/compiler_specific.h" -#include "base/run_loop.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" #include "base/time/time.h" @@ -25,7 +24,6 @@ #include "net/http/transport_security_state.h" #include "net/proxy/proxy_service.h" #include "net/socket/client_socket_handle.h" -#include "net/socket/client_socket_pool_histograms.h" #include "net/socket/next_proto.h" #include "net/socket/socket_test_util.h" #include "net/spdy/spdy_session.h" @@ -86,38 +84,31 @@ class SSLClientSocketPoolTest http_auth_handler_factory_( HttpAuthHandlerFactory::CreateDefault(&host_resolver_)), session_(CreateNetworkSession()), - direct_transport_socket_params_( - new TransportSocketParams( - HostPortPair("host", 443), - false, - false, - OnHostResolutionCallback(), - TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), - transport_histograms_("MockTCP"), + direct_transport_socket_params_(new TransportSocketParams( + HostPortPair("host", 443), + false, + false, + OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), transport_socket_pool_(kMaxSockets, kMaxSocketsPerGroup, - &transport_histograms_, &socket_factory_), - proxy_transport_socket_params_( - new TransportSocketParams( - HostPortPair("proxy", 443), - false, - false, - OnHostResolutionCallback(), - TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), + proxy_transport_socket_params_(new TransportSocketParams( + HostPortPair("proxy", 443), + false, + false, + OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), socks_socket_params_( new SOCKSSocketParams(proxy_transport_socket_params_, true, HostPortPair("sockshost", 443))), - socks_histograms_("MockSOCKS"), socks_socket_pool_(kMaxSockets, kMaxSocketsPerGroup, - &socks_histograms_, &transport_socket_pool_), http_proxy_socket_params_( new HttpProxySocketParams(proxy_transport_socket_params_, NULL, - GURL("http://host"), std::string(), HostPortPair("host", 80), session_->http_auth_cache(), @@ -125,40 +116,25 @@ class SSLClientSocketPoolTest session_->spdy_session_pool(), true, NULL)), - http_proxy_histograms_("MockHttpProxy"), http_proxy_socket_pool_(kMaxSockets, kMaxSocketsPerGroup, - &http_proxy_histograms_, - &host_resolver_, &transport_socket_pool_, NULL, - NULL, - NULL), - enable_ssl_connect_job_waiting_(false) { + NULL) { scoped_refptr<SSLConfigService> ssl_config_service( new SSLConfigServiceDefaults); ssl_config_service->GetSSLConfig(&ssl_config_); } void CreatePool(bool transport_pool, bool http_proxy_pool, bool socks_pool) { - ssl_histograms_.reset(new ClientSocketPoolHistograms("SSLUnitTest")); pool_.reset(new SSLClientSocketPool( - kMaxSockets, - kMaxSocketsPerGroup, - ssl_histograms_.get(), - NULL /* host_resolver */, - NULL /* cert_verifier */, - NULL /* channel_id_service */, - NULL /* transport_security_state */, - NULL /* cert_transparency_verifier */, - std::string() /* ssl_session_cache_shard */, - &socket_factory_, + kMaxSockets, kMaxSocketsPerGroup, NULL /* cert_verifier */, + NULL /* channel_id_service */, NULL /* transport_security_state */, + NULL /* cert_transparency_verifier */, NULL /* cert_policy_enforcer */, + std::string() /* ssl_session_cache_shard */, &socket_factory_, transport_pool ? &transport_socket_pool_ : NULL, socks_pool ? &socks_socket_pool_ : NULL, - http_proxy_pool ? &http_proxy_socket_pool_ : NULL, - NULL, - enable_ssl_connect_job_waiting_, - NULL)); + http_proxy_pool ? &http_proxy_socket_pool_ : NULL, NULL, NULL)); } scoped_refptr<SSLSocketParams> SSLParams(ProxyServer::Scheme proxy, @@ -172,7 +148,6 @@ class SSLClientSocketPoolTest ssl_config_, PRIVACY_MODE_DISABLED, 0, - false, want_spdy_over_npn)); } @@ -216,487 +191,25 @@ class SSLClientSocketPoolTest const scoped_refptr<HttpNetworkSession> session_; scoped_refptr<TransportSocketParams> direct_transport_socket_params_; - ClientSocketPoolHistograms transport_histograms_; MockTransportClientSocketPool transport_socket_pool_; scoped_refptr<TransportSocketParams> proxy_transport_socket_params_; scoped_refptr<SOCKSSocketParams> socks_socket_params_; - ClientSocketPoolHistograms socks_histograms_; MockSOCKSClientSocketPool socks_socket_pool_; scoped_refptr<HttpProxySocketParams> http_proxy_socket_params_; - ClientSocketPoolHistograms http_proxy_histograms_; HttpProxyClientSocketPool http_proxy_socket_pool_; SSLConfig ssl_config_; - scoped_ptr<ClientSocketPoolHistograms> ssl_histograms_; scoped_ptr<SSLClientSocketPool> pool_; - - bool enable_ssl_connect_job_waiting_; }; -INSTANTIATE_TEST_CASE_P( - NextProto, - SSLClientSocketPoolTest, - testing::Values(kProtoDeprecatedSPDY2, - kProtoSPDY3, kProtoSPDY31, kProtoSPDY4)); - -// Tests that the final socket will connect even if all sockets -// prior to it fail. -// -// All sockets should wait for the first socket to attempt to -// connect. Once it fails to connect, all other sockets should -// attempt to connect. All should fail, except the final socket. -TEST_P(SSLClientSocketPoolTest, AllSocketsFailButLast) { - // Although we request four sockets, the first three socket connect - // failures cause the socket pool to create three more sockets because - // there are pending requests. - StaticSocketDataProvider data1; - StaticSocketDataProvider data2; - StaticSocketDataProvider data3; - StaticSocketDataProvider data4; - StaticSocketDataProvider data5; - StaticSocketDataProvider data6; - StaticSocketDataProvider data7; - socket_factory_.AddSocketDataProvider(&data1); - socket_factory_.AddSocketDataProvider(&data2); - socket_factory_.AddSocketDataProvider(&data3); - socket_factory_.AddSocketDataProvider(&data4); - socket_factory_.AddSocketDataProvider(&data5); - socket_factory_.AddSocketDataProvider(&data6); - socket_factory_.AddSocketDataProvider(&data7); - SSLSocketDataProvider ssl(ASYNC, ERR_SSL_PROTOCOL_ERROR); - ssl.is_in_session_cache = false; - SSLSocketDataProvider ssl2(ASYNC, ERR_SSL_PROTOCOL_ERROR); - ssl2.is_in_session_cache = false; - SSLSocketDataProvider ssl3(ASYNC, ERR_SSL_PROTOCOL_ERROR); - ssl3.is_in_session_cache = false; - SSLSocketDataProvider ssl4(ASYNC, OK); - ssl4.is_in_session_cache = false; - SSLSocketDataProvider ssl5(ASYNC, OK); - ssl5.is_in_session_cache = false; - SSLSocketDataProvider ssl6(ASYNC, OK); - ssl6.is_in_session_cache = false; - SSLSocketDataProvider ssl7(ASYNC, OK); - ssl7.is_in_session_cache = false; - - socket_factory_.AddSSLSocketDataProvider(&ssl); - socket_factory_.AddSSLSocketDataProvider(&ssl2); - socket_factory_.AddSSLSocketDataProvider(&ssl3); - socket_factory_.AddSSLSocketDataProvider(&ssl4); - socket_factory_.AddSSLSocketDataProvider(&ssl5); - socket_factory_.AddSSLSocketDataProvider(&ssl6); - socket_factory_.AddSSLSocketDataProvider(&ssl7); - - enable_ssl_connect_job_waiting_ = true; - CreatePool(true, false, false); - - scoped_refptr<SSLSocketParams> params1 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - scoped_refptr<SSLSocketParams> params2 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - scoped_refptr<SSLSocketParams> params3 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - scoped_refptr<SSLSocketParams> params4 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - ClientSocketHandle handle1; - ClientSocketHandle handle2; - ClientSocketHandle handle3; - ClientSocketHandle handle4; - TestCompletionCallback callback1; - TestCompletionCallback callback2; - TestCompletionCallback callback3; - TestCompletionCallback callback4; - - handle1.Init( - "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog()); - handle2.Init( - "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog()); - handle3.Init( - "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog()); - handle4.Init( - "b", params4, MEDIUM, callback4.callback(), pool_.get(), BoundNetLog()); - - base::RunLoop().RunUntilIdle(); - - // Only the last socket should have connected. - EXPECT_FALSE(handle1.socket()); - EXPECT_FALSE(handle2.socket()); - EXPECT_FALSE(handle3.socket()); - EXPECT_TRUE(handle4.socket()->IsConnected()); -} - -// Tests that sockets will still connect in parallel if the -// EnableSSLConnectJobWaiting flag is not enabled. -TEST_P(SSLClientSocketPoolTest, SocketsConnectWithoutFlag) { - StaticSocketDataProvider data1; - StaticSocketDataProvider data2; - StaticSocketDataProvider data3; - socket_factory_.AddSocketDataProvider(&data1); - socket_factory_.AddSocketDataProvider(&data2); - socket_factory_.AddSocketDataProvider(&data3); - - SSLSocketDataProvider ssl(ASYNC, OK); - ssl.is_in_session_cache = false; - ssl.should_pause_on_connect = true; - SSLSocketDataProvider ssl2(ASYNC, OK); - ssl2.is_in_session_cache = false; - ssl2.should_pause_on_connect = true; - SSLSocketDataProvider ssl3(ASYNC, OK); - ssl3.is_in_session_cache = false; - ssl3.should_pause_on_connect = true; - socket_factory_.AddSSLSocketDataProvider(&ssl); - socket_factory_.AddSSLSocketDataProvider(&ssl2); - socket_factory_.AddSSLSocketDataProvider(&ssl3); - - CreatePool(true, false, false); - - scoped_refptr<SSLSocketParams> params1 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - scoped_refptr<SSLSocketParams> params2 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - scoped_refptr<SSLSocketParams> params3 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - ClientSocketHandle handle1; - ClientSocketHandle handle2; - ClientSocketHandle handle3; - TestCompletionCallback callback1; - TestCompletionCallback callback2; - TestCompletionCallback callback3; - - handle1.Init( - "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog()); - handle2.Init( - "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog()); - handle3.Init( - "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog()); - - base::RunLoop().RunUntilIdle(); - - std::vector<MockSSLClientSocket*> sockets = - socket_factory_.ssl_client_sockets(); - - // All sockets should have started their connections. - for (std::vector<MockSSLClientSocket*>::iterator it = sockets.begin(); - it != sockets.end(); - ++it) { - EXPECT_TRUE((*it)->reached_connect()); - } - - // Resume connecting all of the sockets. - for (std::vector<MockSSLClientSocket*>::iterator it = sockets.begin(); - it != sockets.end(); - ++it) { - (*it)->RestartPausedConnect(); - } - - callback1.WaitForResult(); - callback2.WaitForResult(); - callback3.WaitForResult(); - - EXPECT_TRUE(handle1.socket()->IsConnected()); - EXPECT_TRUE(handle2.socket()->IsConnected()); - EXPECT_TRUE(handle3.socket()->IsConnected()); -} - -// Tests that the pool deleting an SSLConnectJob will not cause a crash, -// or prevent pending sockets from connecting. -TEST_P(SSLClientSocketPoolTest, DeletedSSLConnectJob) { - StaticSocketDataProvider data1; - StaticSocketDataProvider data2; - StaticSocketDataProvider data3; - socket_factory_.AddSocketDataProvider(&data1); - socket_factory_.AddSocketDataProvider(&data2); - socket_factory_.AddSocketDataProvider(&data3); - - SSLSocketDataProvider ssl(ASYNC, OK); - ssl.is_in_session_cache = false; - ssl.should_pause_on_connect = true; - SSLSocketDataProvider ssl2(ASYNC, OK); - ssl2.is_in_session_cache = false; - SSLSocketDataProvider ssl3(ASYNC, OK); - ssl3.is_in_session_cache = false; - socket_factory_.AddSSLSocketDataProvider(&ssl); - socket_factory_.AddSSLSocketDataProvider(&ssl2); - socket_factory_.AddSSLSocketDataProvider(&ssl3); - - enable_ssl_connect_job_waiting_ = true; - CreatePool(true, false, false); - - scoped_refptr<SSLSocketParams> params1 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - scoped_refptr<SSLSocketParams> params2 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - scoped_refptr<SSLSocketParams> params3 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - ClientSocketHandle handle1; - ClientSocketHandle handle2; - ClientSocketHandle handle3; - TestCompletionCallback callback1; - TestCompletionCallback callback2; - TestCompletionCallback callback3; - - handle1.Init( - "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog()); - handle2.Init( - "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog()); - handle3.Init( - "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog()); - - // Allow the connections to proceed until the first socket has started - // connecting. - base::RunLoop().RunUntilIdle(); - - std::vector<MockSSLClientSocket*> sockets = - socket_factory_.ssl_client_sockets(); - - pool_->CancelRequest("b", &handle2); - - sockets[0]->RestartPausedConnect(); - - callback1.WaitForResult(); - callback3.WaitForResult(); - - EXPECT_TRUE(handle1.socket()->IsConnected()); - EXPECT_FALSE(handle2.socket()); - EXPECT_TRUE(handle3.socket()->IsConnected()); -} - -// Tests that all pending sockets still connect when the pool deletes a pending -// SSLConnectJob which immediately followed a failed leading connection. -TEST_P(SSLClientSocketPoolTest, DeletedSocketAfterFail) { - StaticSocketDataProvider data1; - StaticSocketDataProvider data2; - StaticSocketDataProvider data3; - StaticSocketDataProvider data4; - socket_factory_.AddSocketDataProvider(&data1); - socket_factory_.AddSocketDataProvider(&data2); - socket_factory_.AddSocketDataProvider(&data3); - socket_factory_.AddSocketDataProvider(&data4); - - SSLSocketDataProvider ssl(ASYNC, ERR_SSL_PROTOCOL_ERROR); - ssl.is_in_session_cache = false; - ssl.should_pause_on_connect = true; - SSLSocketDataProvider ssl2(ASYNC, OK); - ssl2.is_in_session_cache = false; - SSLSocketDataProvider ssl3(ASYNC, OK); - ssl3.is_in_session_cache = false; - SSLSocketDataProvider ssl4(ASYNC, OK); - ssl4.is_in_session_cache = false; - socket_factory_.AddSSLSocketDataProvider(&ssl); - socket_factory_.AddSSLSocketDataProvider(&ssl2); - socket_factory_.AddSSLSocketDataProvider(&ssl3); - socket_factory_.AddSSLSocketDataProvider(&ssl4); - - enable_ssl_connect_job_waiting_ = true; - CreatePool(true, false, false); - - scoped_refptr<SSLSocketParams> params1 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - scoped_refptr<SSLSocketParams> params2 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - scoped_refptr<SSLSocketParams> params3 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - ClientSocketHandle handle1; - ClientSocketHandle handle2; - ClientSocketHandle handle3; - TestCompletionCallback callback1; - TestCompletionCallback callback2; - TestCompletionCallback callback3; - - handle1.Init( - "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog()); - handle2.Init( - "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog()); - handle3.Init( - "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog()); - - // Allow the connections to proceed until the first socket has started - // connecting. - base::RunLoop().RunUntilIdle(); - - std::vector<MockSSLClientSocket*> sockets = - socket_factory_.ssl_client_sockets(); - - EXPECT_EQ(3u, sockets.size()); - EXPECT_TRUE(sockets[0]->reached_connect()); - EXPECT_FALSE(handle1.socket()); - - pool_->CancelRequest("b", &handle2); - - sockets[0]->RestartPausedConnect(); - - callback1.WaitForResult(); - callback3.WaitForResult(); - - EXPECT_FALSE(handle1.socket()); - EXPECT_FALSE(handle2.socket()); - EXPECT_TRUE(handle3.socket()->IsConnected()); -} - -// Make sure that sockets still connect after the leader socket's -// connection fails. -TEST_P(SSLClientSocketPoolTest, SimultaneousConnectJobsFail) { - StaticSocketDataProvider data1; - StaticSocketDataProvider data2; - StaticSocketDataProvider data3; - StaticSocketDataProvider data4; - StaticSocketDataProvider data5; - socket_factory_.AddSocketDataProvider(&data1); - socket_factory_.AddSocketDataProvider(&data2); - socket_factory_.AddSocketDataProvider(&data3); - socket_factory_.AddSocketDataProvider(&data4); - socket_factory_.AddSocketDataProvider(&data5); - SSLSocketDataProvider ssl(ASYNC, ERR_SSL_PROTOCOL_ERROR); - ssl.is_in_session_cache = false; - ssl.should_pause_on_connect = true; - SSLSocketDataProvider ssl2(ASYNC, OK); - ssl2.is_in_session_cache = false; - ssl2.should_pause_on_connect = true; - SSLSocketDataProvider ssl3(ASYNC, OK); - ssl3.is_in_session_cache = false; - SSLSocketDataProvider ssl4(ASYNC, OK); - ssl4.is_in_session_cache = false; - SSLSocketDataProvider ssl5(ASYNC, OK); - ssl5.is_in_session_cache = false; - - socket_factory_.AddSSLSocketDataProvider(&ssl); - socket_factory_.AddSSLSocketDataProvider(&ssl2); - socket_factory_.AddSSLSocketDataProvider(&ssl3); - socket_factory_.AddSSLSocketDataProvider(&ssl4); - socket_factory_.AddSSLSocketDataProvider(&ssl5); - - enable_ssl_connect_job_waiting_ = true; - CreatePool(true, false, false); - scoped_refptr<SSLSocketParams> params1 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - scoped_refptr<SSLSocketParams> params2 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - scoped_refptr<SSLSocketParams> params3 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - scoped_refptr<SSLSocketParams> params4 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - ClientSocketHandle handle1; - ClientSocketHandle handle2; - ClientSocketHandle handle3; - ClientSocketHandle handle4; - TestCompletionCallback callback1; - TestCompletionCallback callback2; - TestCompletionCallback callback3; - TestCompletionCallback callback4; - handle1.Init( - "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog()); - handle2.Init( - "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog()); - handle3.Init( - "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog()); - handle4.Init( - "b", params4, MEDIUM, callback4.callback(), pool_.get(), BoundNetLog()); - - base::RunLoop().RunUntilIdle(); - - std::vector<MockSSLClientSocket*> sockets = - socket_factory_.ssl_client_sockets(); - - std::vector<MockSSLClientSocket*>::const_iterator it = sockets.begin(); - - // The first socket should have had Connect called on it. - EXPECT_TRUE((*it)->reached_connect()); - ++it; - - // No other socket should have reached connect yet. - for (; it != sockets.end(); ++it) - EXPECT_FALSE((*it)->reached_connect()); - - // Allow the first socket to resume it's connection process. - sockets[0]->RestartPausedConnect(); - - base::RunLoop().RunUntilIdle(); - - // The second socket should have reached connect. - EXPECT_TRUE(sockets[1]->reached_connect()); - - // Allow the second socket to continue its connection. - sockets[1]->RestartPausedConnect(); - - base::RunLoop().RunUntilIdle(); - - EXPECT_FALSE(handle1.socket()); - EXPECT_TRUE(handle2.socket()->IsConnected()); - EXPECT_TRUE(handle3.socket()->IsConnected()); - EXPECT_TRUE(handle4.socket()->IsConnected()); -} - -// Make sure that no sockets connect before the "leader" socket, -// given that the leader has a successful connection. -TEST_P(SSLClientSocketPoolTest, SimultaneousConnectJobsSuccess) { - StaticSocketDataProvider data1; - StaticSocketDataProvider data2; - StaticSocketDataProvider data3; - socket_factory_.AddSocketDataProvider(&data1); - socket_factory_.AddSocketDataProvider(&data2); - socket_factory_.AddSocketDataProvider(&data3); - - SSLSocketDataProvider ssl(ASYNC, OK); - ssl.is_in_session_cache = false; - ssl.should_pause_on_connect = true; - SSLSocketDataProvider ssl2(ASYNC, OK); - ssl2.is_in_session_cache = false; - SSLSocketDataProvider ssl3(ASYNC, OK); - ssl3.is_in_session_cache = false; - socket_factory_.AddSSLSocketDataProvider(&ssl); - socket_factory_.AddSSLSocketDataProvider(&ssl2); - socket_factory_.AddSSLSocketDataProvider(&ssl3); - - enable_ssl_connect_job_waiting_ = true; - CreatePool(true, false, false); - - scoped_refptr<SSLSocketParams> params1 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - scoped_refptr<SSLSocketParams> params2 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - scoped_refptr<SSLSocketParams> params3 = - SSLParams(ProxyServer::SCHEME_DIRECT, false); - ClientSocketHandle handle1; - ClientSocketHandle handle2; - ClientSocketHandle handle3; - TestCompletionCallback callback1; - TestCompletionCallback callback2; - TestCompletionCallback callback3; - - handle1.Init( - "b", params1, MEDIUM, callback1.callback(), pool_.get(), BoundNetLog()); - handle2.Init( - "b", params2, MEDIUM, callback2.callback(), pool_.get(), BoundNetLog()); - handle3.Init( - "b", params3, MEDIUM, callback3.callback(), pool_.get(), BoundNetLog()); - - // Allow the connections to proceed until the first socket has finished - // connecting. - base::RunLoop().RunUntilIdle(); - - std::vector<MockSSLClientSocket*> sockets = - socket_factory_.ssl_client_sockets(); - - std::vector<MockSSLClientSocket*>::const_iterator it = sockets.begin(); - // The first socket should have reached connect. - EXPECT_TRUE((*it)->reached_connect()); - ++it; - // No other socket should have reached connect yet. - for (; it != sockets.end(); ++it) - EXPECT_FALSE((*it)->reached_connect()); - - sockets[0]->RestartPausedConnect(); - - callback1.WaitForResult(); - callback2.WaitForResult(); - callback3.WaitForResult(); - - EXPECT_TRUE(handle1.socket()->IsConnected()); - EXPECT_TRUE(handle2.socket()->IsConnected()); - EXPECT_TRUE(handle3.socket()->IsConnected()); -} +INSTANTIATE_TEST_CASE_P(NextProto, + SSLClientSocketPoolTest, + testing::Values(kProtoSPDY31, + kProtoSPDY4_14, + kProtoSPDY4)); TEST_P(SSLClientSocketPoolTest, TCPFail) { StaticSocketDataProvider data; @@ -714,6 +227,8 @@ TEST_P(SSLClientSocketPoolTest, TCPFail) { EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); EXPECT_FALSE(handle.is_ssl_error()); + ASSERT_EQ(1u, handle.connection_attempts().size()); + EXPECT_EQ(ERR_CONNECTION_FAILED, handle.connection_attempts()[0].result); } TEST_P(SSLClientSocketPoolTest, TCPFailAsync) { @@ -737,6 +252,8 @@ TEST_P(SSLClientSocketPoolTest, TCPFailAsync) { EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); EXPECT_FALSE(handle.is_ssl_error()); + ASSERT_EQ(1u, handle.connection_attempts().size()); + EXPECT_EQ(ERR_CONNECTION_FAILED, handle.connection_attempts()[0].result); } TEST_P(SSLClientSocketPoolTest, BasicDirect) { @@ -758,6 +275,7 @@ TEST_P(SSLClientSocketPoolTest, BasicDirect) { EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); TestLoadTimingInfo(handle); + EXPECT_EQ(0u, handle.connection_attempts().size()); } // Make sure that SSLConnectJob passes on its priority to its @@ -1127,7 +645,7 @@ TEST_P(SSLClientSocketPoolTest, HttpProxyBasic) { MockWrite writes[] = { MockWrite(SYNCHRONOUS, "CONNECT host:80 HTTP/1.1\r\n" - "Host: host\r\n" + "Host: host:80\r\n" "Proxy-Connection: keep-alive\r\n" "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), }; @@ -1162,7 +680,7 @@ TEST_P(SSLClientSocketPoolTest, SetTransportPriorityOnInitHTTP) { MockWrite writes[] = { MockWrite(SYNCHRONOUS, "CONNECT host:80 HTTP/1.1\r\n" - "Host: host\r\n" + "Host: host:80\r\n" "Proxy-Connection: keep-alive\r\n" "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), }; @@ -1190,10 +708,11 @@ TEST_P(SSLClientSocketPoolTest, SetTransportPriorityOnInitHTTP) { TEST_P(SSLClientSocketPoolTest, HttpProxyBasicAsync) { MockWrite writes[] = { - MockWrite("CONNECT host:80 HTTP/1.1\r\n" - "Host: host\r\n" - "Proxy-Connection: keep-alive\r\n" - "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), + MockWrite( + "CONNECT host:80 HTTP/1.1\r\n" + "Host: host:80\r\n" + "Proxy-Connection: keep-alive\r\n" + "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), }; MockRead reads[] = { MockRead("HTTP/1.1 200 Connection Established\r\n\r\n"), @@ -1225,9 +744,10 @@ TEST_P(SSLClientSocketPoolTest, HttpProxyBasicAsync) { TEST_P(SSLClientSocketPoolTest, NeedProxyAuth) { MockWrite writes[] = { - MockWrite("CONNECT host:80 HTTP/1.1\r\n" - "Host: host\r\n" - "Proxy-Connection: keep-alive\r\n\r\n"), + MockWrite( + "CONNECT host:80 HTTP/1.1\r\n" + "Host: host:80\r\n" + "Proxy-Connection: keep-alive\r\n\r\n"), }; MockRead reads[] = { MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"), diff --git a/chromium/net/socket/ssl_client_socket_unittest.cc b/chromium/net/socket/ssl_client_socket_unittest.cc index 16e03f7eb8b..0a7b7118e3c 100644 --- a/chromium/net/socket/ssl_client_socket_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_unittest.cc @@ -7,12 +7,11 @@ #include "base/callback_helpers.h" #include "base/memory/ref_counted.h" #include "base/run_loop.h" +#include "base/thread_task_runner_handle.h" #include "base/time/time.h" #include "net/base/address_list.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" -#include "net/base/net_log_unittest.h" #include "net/base/test_completion_callback.h" #include "net/base/test_data_directory.h" #include "net/cert/asn1_util.h" @@ -21,6 +20,10 @@ #include "net/cert/test_root_certs.h" #include "net/dns/host_resolver.h" #include "net/http/transport_security_state.h" +#include "net/log/net_log.h" +#include "net/log/test_net_log.h" +#include "net/log/test_net_log_entry.h" +#include "net/log/test_net_log_util.h" #include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_handle.h" #include "net/socket/socket_test_util.h" @@ -29,12 +32,27 @@ #include "net/ssl/default_channel_id_store.h" #include "net/ssl/ssl_cert_request_info.h" #include "net/ssl/ssl_config_service.h" +#include "net/ssl/ssl_connection_status_flags.h" +#include "net/ssl/ssl_info.h" #include "net/test/cert_test_util.h" #include "net/test/spawned_test_server/spawned_test_server.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" #include "testing/platform_test.h" +#if !defined(USE_OPENSSL) +#include <pk11pub.h> +#include "crypto/nss_util.h" + +#if !defined(CKM_AES_GCM) +#define CKM_AES_GCM 0x00001087 +#endif + +#if !defined(CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256) +#define CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256 (CKM_NSS + 24) +#endif +#endif + //----------------------------------------------------------------------------- using testing::_; @@ -45,8 +63,6 @@ namespace net { namespace { -const SSLConfig kDefaultSSLConfig; - // WrappedStreamSocket is a base class that wraps an existing StreamSocket, // forwarding the Socket and StreamSocket interfaces to the underlying // transport. @@ -92,6 +108,15 @@ class WrappedStreamSocket : public StreamSocket { bool GetSSLInfo(SSLInfo* ssl_info) override { return transport_->GetSSLInfo(ssl_info); } + void GetConnectionAttempts(ConnectionAttempts* out) const override { + transport_->GetConnectionAttempts(out); + } + void ClearConnectionAttempts() override { + transport_->ClearConnectionAttempts(); + } + void AddConnectionAttempts(const ConnectionAttempts& attempts) override { + transport_->AddConnectionAttempts(attempts); + } // Socket implementation: int Read(IOBuffer* buf, @@ -335,6 +360,9 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket { int buf_len, const CompletionCallback& callback) override; + int pending_read_result() const { return pending_read_result_; } + IOBuffer* pending_read_buf() const { return pending_read_buf_.get(); } + // Blocks read results on the socket. Reads will not complete until // UnblockReadResult() has been called and a result is ready from the // underlying transport. Note: if BlockReadResult() is called while there is a @@ -357,9 +385,6 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket { // Waits for the blocked Write() call to be scheduled. void WaitForWrite(); - // Returns the wrapped stream socket. - StreamSocket* transport() { return transport_.get(); } - private: // Handles completion from the underlying transport read. void OnReadCompleted(int result); @@ -367,6 +392,9 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket { // True if read callbacks are blocked. bool should_block_read_; + // The buffer for the pending read, or NULL if not consumed. + scoped_refptr<IOBuffer> pending_read_buf_; + // The user callback for the pending read call. CompletionCallback pending_read_callback_; @@ -404,6 +432,7 @@ FakeBlockingStreamSocket::FakeBlockingStreamSocket( int FakeBlockingStreamSocket::Read(IOBuffer* buf, int len, const CompletionCallback& callback) { + DCHECK(!pending_read_buf_); DCHECK(pending_read_callback_.is_null()); DCHECK_EQ(ERR_IO_PENDING, pending_read_result_); DCHECK(!callback.is_null()); @@ -412,9 +441,11 @@ int FakeBlockingStreamSocket::Read(IOBuffer* buf, &FakeBlockingStreamSocket::OnReadCompleted, base::Unretained(this))); if (rv == ERR_IO_PENDING) { // Save the callback to be called later. + pending_read_buf_ = buf; pending_read_callback_ = callback; } else if (should_block_read_) { // Save the callback and read result to be called later. + pending_read_buf_ = buf; pending_read_callback_ = callback; OnReadCompleted(rv); rv = ERR_IO_PENDING; @@ -461,6 +492,7 @@ void FakeBlockingStreamSocket::UnblockReadResult() { if (pending_read_result_ == ERR_IO_PENDING) return; int result = pending_read_result_; + pending_read_buf_ = nullptr; pending_read_result_ = ERR_IO_PENDING; base::ResetAndReturn(&pending_read_callback_).Run(result); } @@ -528,7 +560,8 @@ void FakeBlockingStreamSocket::OnReadCompleted(int result) { read_loop_->Quit(); } else { // Either the Read() was never blocked or UnblockReadResult() was called - // before the Read() completed. Either way, run the callback. + // before the Read() completed. Either way, return the result to the caller. + pending_read_buf_ = nullptr; base::ResetAndReturn(&pending_read_callback_).Run(result); } } @@ -667,15 +700,12 @@ class SSLClientSocketTest : public PlatformTest { SSLClientSocketTest() : socket_factory_(ClientSocketFactory::GetDefaultFactory()), cert_verifier_(new MockCertVerifier), - transport_security_state_(new TransportSecurityState), - ran_handshake_completion_callback_(false) { + transport_security_state_(new TransportSecurityState) { cert_verifier_->set_default_result(OK); context_.cert_verifier = cert_verifier_.get(); context_.transport_security_state = transport_security_state_.get(); } - void RecordCompletedHandshake() { ran_handshake_completion_callback_ = true; } - protected: // The address of the spawned test server, after calling StartTestServer(). const AddressList& addr() const { return addr_; } @@ -755,8 +785,7 @@ class SSLClientSocketTest : public PlatformTest { scoped_ptr<TransportSecurityState> transport_security_state_; SSLClientSocketContext context_; scoped_ptr<SSLClientSocket> sock_; - CapturingNetLog log_; - bool ran_handshake_completion_callback_; + TestNetLog log_; private: scoped_ptr<StreamSocket> transport_; @@ -782,7 +811,7 @@ class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest { return NULL; TestCompletionCallback callback; - CapturingNetLog log; + TestNetLog log; scoped_ptr<StreamSocket> transport( new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); @@ -791,7 +820,7 @@ class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest { EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); @@ -809,11 +838,6 @@ class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest { }; class SSLClientSocketFalseStartTest : public SSLClientSocketTest { - public: - SSLClientSocketFalseStartTest() - : monitor_handshake_callback_(false), - fail_handshake_after_false_start_(false) {} - protected: // Creates an SSLClientSocket with |client_config| attached to a // FakeBlockingStreamSocket, returning both in |*out_raw_transport| and @@ -835,11 +859,8 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest { scoped_ptr<SSLClientSocket>* out_sock) { CHECK(test_server()); - scoped_ptr<StreamSocket> real_transport(scoped_ptr<StreamSocket>( - new TCPClientSocket(addr(), NULL, NetLog::Source()))); - real_transport.reset( - new SynchronousErrorStreamSocket(real_transport.Pass())); - + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr(), NULL, NetLog::Source())); scoped_ptr<FakeBlockingStreamSocket> transport( new FakeBlockingStreamSocket(real_transport.Pass())); int rv = callback->GetResult(transport->Connect(callback->callback())); @@ -849,12 +870,6 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest { scoped_ptr<SSLClientSocket> sock = CreateSSLClientSocket( transport.Pass(), test_server()->host_port_pair(), client_config); - if (monitor_handshake_callback_) { - sock->SetHandshakeCompletionCallback( - base::Bind(&SSLClientSocketTest::RecordCompletedHandshake, - base::Unretained(this))); - } - // Connect. Stop before the client processes the first server leg // (ServerHello, etc.) raw_transport->BlockReadResult(); @@ -870,12 +885,6 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest { raw_transport->UnblockReadResult(); raw_transport->WaitForWrite(); - if (fail_handshake_after_false_start_) { - SynchronousErrorStreamSocket* error_socket = - static_cast<SynchronousErrorStreamSocket*>( - raw_transport->transport()); - error_socket->SetNextReadError(ERR_CONNECTION_RESET); - } // And, finally, release that and block the next server leg // (ChangeCipherSpec, Finished). raw_transport->BlockReadResult(); @@ -893,7 +902,6 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest { TestCompletionCallback callback; FakeBlockingStreamSocket* raw_transport = NULL; scoped_ptr<SSLClientSocket> sock; - ASSERT_NO_FATAL_FAILURE(CreateAndConnectUntilServerFinishedReceived( client_config, &callback, &raw_transport, &sock)); @@ -928,10 +936,7 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest { // After releasing reads, the connection proceeds. raw_transport->UnblockReadResult(); rv = callback.GetResult(rv); - if (fail_handshake_after_false_start_) - EXPECT_EQ(ERR_CONNECTION_RESET, rv); - else - EXPECT_LT(0, rv); + EXPECT_LT(0, rv); } else { // False Start is not enabled, so the handshake will not complete because // the server second leg is blocked. @@ -939,34 +944,25 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest { EXPECT_FALSE(callback.have_result()); } } - - // Indicates that the socket's handshake completion callback should - // be monitored. - bool monitor_handshake_callback_; - // Indicates that this test's handshake should fail after the client - // "finished" message is sent. - bool fail_handshake_after_false_start_; }; class SSLClientSocketChannelIDTest : public SSLClientSocketTest { protected: void EnableChannelID() { - channel_id_service_.reset( - new ChannelIDService(new DefaultChannelIDStore(NULL), - base::MessageLoopProxy::current())); + channel_id_service_.reset(new ChannelIDService( + new DefaultChannelIDStore(NULL), base::ThreadTaskRunnerHandle::Get())); context_.channel_id_service = channel_id_service_.get(); } void EnableFailingChannelID() { channel_id_service_.reset(new ChannelIDService( - new FailingChannelIDStore(), base::MessageLoopProxy::current())); + new FailingChannelIDStore(), base::ThreadTaskRunnerHandle::Get())); context_.channel_id_service = channel_id_service_.get(); } void EnableAsyncFailingChannelID() { channel_id_service_.reset(new ChannelIDService( - new AsyncFailingChannelIDStore(), - base::MessageLoopProxy::current())); + new AsyncFailingChannelIDStore(), base::ThreadTaskRunnerHandle::Get())); context_.channel_id_service = channel_id_service_.get(); } @@ -983,14 +979,23 @@ class SSLClientSocketChannelIDTest : public SSLClientSocketTest { // they'll give up waiting for application data and send the Finished after a // timeout. This means that an SSL connect end event may appear as a socket // write. -static bool LogContainsSSLConnectEndEvent( - const CapturingNetLog::CapturedEntryList& log, - int i) { +static bool LogContainsSSLConnectEndEvent(const TestNetLogEntry::List& log, + int i) { return LogContainsEndEvent(log, i, NetLog::TYPE_SSL_CONNECT) || LogContainsEvent( log, i, NetLog::TYPE_SOCKET_BYTES_SENT, NetLog::PHASE_NONE); } +bool SupportsAESGCM() { +#if defined(USE_OPENSSL) + return true; +#else + crypto::EnsureNSSInit(); + return PK11_TokenExists(CKM_AES_GCM) && + PK11_TokenExists(CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256); +#endif +} + } // namespace TEST_F(SSLClientSocketTest, Connect) { @@ -1003,7 +1008,7 @@ TEST_F(SSLClientSocketTest, Connect) { ASSERT_TRUE(test_server.GetAddressList(&addr)); TestCompletionCallback callback; - CapturingNetLog log; + TestNetLog log; scoped_ptr<StreamSocket> transport( new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); @@ -1012,13 +1017,13 @@ TEST_F(SSLClientSocketTest, Connect) { EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); if (rv == ERR_IO_PENDING) @@ -1045,7 +1050,7 @@ TEST_F(SSLClientSocketTest, ConnectExpired) { ASSERT_TRUE(test_server.GetAddressList(&addr)); TestCompletionCallback callback; - CapturingNetLog log; + TestNetLog log; scoped_ptr<StreamSocket> transport( new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); @@ -1054,13 +1059,13 @@ TEST_F(SSLClientSocketTest, ConnectExpired) { EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); if (rv == ERR_IO_PENDING) @@ -1089,7 +1094,7 @@ TEST_F(SSLClientSocketTest, ConnectMismatched) { ASSERT_TRUE(test_server.GetAddressList(&addr)); TestCompletionCallback callback; - CapturingNetLog log; + TestNetLog log; scoped_ptr<StreamSocket> transport( new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); @@ -1098,13 +1103,13 @@ TEST_F(SSLClientSocketTest, ConnectMismatched) { EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); if (rv == ERR_IO_PENDING) @@ -1133,7 +1138,7 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) { ASSERT_TRUE(test_server.GetAddressList(&addr)); TestCompletionCallback callback; - CapturingNetLog log; + TestNetLog log; scoped_ptr<StreamSocket> transport( new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); @@ -1142,13 +1147,13 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) { EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); if (rv == ERR_IO_PENDING) @@ -1192,7 +1197,7 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { ASSERT_TRUE(test_server.GetAddressList(&addr)); TestCompletionCallback callback; - CapturingNetLog log; + TestNetLog log; scoped_ptr<StreamSocket> transport( new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); @@ -1200,7 +1205,7 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); - SSLConfig ssl_config = kDefaultSSLConfig; + SSLConfig ssl_config; ssl_config.send_client_cert = true; ssl_config.client_cert = NULL; @@ -1213,7 +1218,7 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { // TODO(davidben): Add a test which requires them and verify the error. rv = sock->Connect(callback.callback()); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); if (rv == ERR_IO_PENDING) @@ -1240,6 +1245,8 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { // - Server closes the underlying TCP connection directly. // - Server sends data unexpectedly. +// Tests that the socket can be read from successfully. Also test that a peer's +// close_notify alert is successfully processed without error. TEST_F(SSLClientSocketTest, Read) { SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, SpawnedTestServer::kLocalhost, @@ -1258,7 +1265,7 @@ TEST_F(SSLClientSocketTest, Read) { EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); rv = sock->Connect(callback.callback()); if (rv == ERR_IO_PENDING) @@ -1291,6 +1298,9 @@ TEST_F(SSLClientSocketTest, Read) { if (rv <= 0) break; } + + // The peer should have cleanly closed the connection with a close_notify. + EXPECT_EQ(0, rv); } // Tests that SSLClientSocket properly handles when the underlying transport @@ -1543,7 +1553,7 @@ TEST_F(SSLClientSocketTest, Read_FullDuplex) { EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); rv = sock->Connect(callback.callback()); if (rv == ERR_IO_PENDING) @@ -1821,10 +1831,8 @@ TEST_F(SSLClientSocketTest, Connect_WithZeroReturn) { EXPECT_EQ(OK, rv); SynchronousErrorStreamSocket* raw_transport = transport.get(); - scoped_ptr<SSLClientSocket> sock( - CreateSSLClientSocket(transport.Pass(), - test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), SSLConfig())); raw_transport->SetNextReadError(0); @@ -1833,8 +1841,8 @@ TEST_F(SSLClientSocketTest, Connect_WithZeroReturn) { EXPECT_FALSE(sock->IsConnected()); } -// Tests that SSLClientSocket cleanly returns a Read of size 0 if the -// underlying socket is cleanly closed. +// Tests that SSLClientSocket returns a Read of size 0 if the underlying socket +// is cleanly closed, but the peer does not send close_notify. // This is a regression test for https://crbug.com/422246 TEST_F(SSLClientSocketTest, Read_WithZeroReturn) { SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, @@ -1921,6 +1929,28 @@ TEST_F(SSLClientSocketTest, Read_WithAsyncZeroReturn) { EXPECT_EQ(0, rv); } +// Tests that fatal alerts from the peer are processed. This is a regression +// test for https://crbug.com/466303. +TEST_F(SSLClientSocketTest, Read_WithFatalAlert) { + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.alert_after_handshake = true; + ASSERT_TRUE(StartTestServer(ssl_options)); + + SSLConfig ssl_config; + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr(), &log_, NetLog::Source())); + EXPECT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server()->host_port_pair(), ssl_config)); + EXPECT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + + // Receive the fatal alert. + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); + EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, callback.GetResult(sock->Read( + buf.get(), 4096, callback.callback()))); +} + TEST_F(SSLClientSocketTest, Read_SmallChunks) { SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, SpawnedTestServer::kLocalhost, @@ -1939,7 +1969,7 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) { EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); rv = sock->Connect(callback.callback()); if (rv == ERR_IO_PENDING) @@ -1993,7 +2023,7 @@ TEST_F(SSLClientSocketTest, Read_ManySmallRecords) { ASSERT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); rv = callback.GetResult(sock->Connect(callback.callback())); ASSERT_EQ(OK, rv); @@ -2043,7 +2073,7 @@ TEST_F(SSLClientSocketTest, Read_Interrupted) { EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); rv = sock->Connect(callback.callback()); if (rv == ERR_IO_PENDING) @@ -2084,8 +2114,8 @@ TEST_F(SSLClientSocketTest, Read_FullLogging) { ASSERT_TRUE(test_server.GetAddressList(&addr)); TestCompletionCallback callback; - CapturingNetLog log; - log.SetLogLevel(NetLog::LOG_ALL); + TestNetLog log; + log.SetCaptureMode(NetLogCaptureMode::IncludeSocketBytes()); scoped_ptr<StreamSocket> transport( new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); @@ -2094,7 +2124,7 @@ TEST_F(SSLClientSocketTest, Read_FullLogging) { EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); rv = sock->Connect(callback.callback()); if (rv == ERR_IO_PENDING) @@ -2115,7 +2145,7 @@ TEST_F(SSLClientSocketTest, Read_FullLogging) { rv = callback.WaitForResult(); EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); size_t last_index = ExpectLogContainsSomewhereAfter( entries, 5, NetLog::TYPE_SSL_SOCKET_BYTES_SENT, NetLog::PHASE_NONE); @@ -2180,7 +2210,7 @@ TEST_F(SSLClientSocketTest, PrematureApplicationData) { EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); rv = sock->Connect(callback.callback()); if (rv == ERR_IO_PENDING) @@ -2189,16 +2219,18 @@ TEST_F(SSLClientSocketTest, PrematureApplicationData) { } TEST_F(SSLClientSocketTest, CipherSuiteDisables) { - // Rather than exhaustively disabling every RC4 ciphersuite defined at - // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml, - // only disabling those cipher suites that the test server actually - // implements. - const uint16 kCiphersToDisable[] = {0x0005, // TLS_RSA_WITH_RC4_128_SHA + // Rather than exhaustively disabling every AES_128_CBC ciphersuite defined at + // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml, only + // disabling those cipher suites that the test server actually implements. + const uint16 kCiphersToDisable[] = { + 0x002f, // TLS_RSA_WITH_AES_128_CBC_SHA + 0x0033, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA + 0xc013, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA }; SpawnedTestServer::SSLOptions ssl_options; - // Enable only RC4 on the test server. - ssl_options.bulk_ciphers = SpawnedTestServer::SSLOptions::BULK_CIPHER_RC4; + // Enable only AES_128_CBC on the test server. + ssl_options.bulk_ciphers = SpawnedTestServer::SSLOptions::BULK_CIPHER_AES128; SpawnedTestServer test_server( SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); ASSERT_TRUE(test_server.Start()); @@ -2207,7 +2239,7 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) { ASSERT_TRUE(test_server.GetAddressList(&addr)); TestCompletionCallback callback; - CapturingNetLog log; + TestNetLog log; scoped_ptr<StreamSocket> transport( new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); @@ -2225,23 +2257,15 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) { EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); - // NSS has special handling that maps a handshake_failure alert received - // immediately after a client_hello to be a mismatched cipher suite error, - // leading to ERR_SSL_VERSION_OR_CIPHER_MISMATCH. When using OpenSSL or - // Secure Transport (OS X), the handshake_failure is bubbled up without any - // interpretation, leading to ERR_SSL_PROTOCOL_ERROR. Either way, a failure - // indicates that no cipher suite was negotiated with the test server. if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_TRUE(rv == ERR_SSL_VERSION_OR_CIPHER_MISMATCH || - rv == ERR_SSL_PROTOCOL_ERROR); - // The exact ordering differs between SSLClientSocketNSS (which issues an - // extra read) and SSLClientSocketMac (which does not). Just make sure the - // error appears somewhere in the log. + EXPECT_EQ(ERR_SSL_VERSION_OR_CIPHER_MISMATCH, rv); + // The exact ordering depends no whether an extra read is issued. Just check + // the error is somewhere in the log. log.GetEntries(&entries); ExpectLogContainsSomewhere( entries, 0, NetLog::TYPE_SSL_HANDSHAKE_ERROR, NetLog::PHASE_NONE); @@ -2285,11 +2309,9 @@ TEST_F(SSLClientSocketTest, ClientSocketHandleNotFromPool) { scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle()); socket_handle->SetSocket(transport.Pass()); - scoped_ptr<SSLClientSocket> sock( - socket_factory_->CreateSSLClientSocket(socket_handle.Pass(), - test_server.host_port_pair(), - kDefaultSSLConfig, - context_)); + scoped_ptr<SSLClientSocket> sock(socket_factory_->CreateSSLClientSocket( + socket_handle.Pass(), test_server.host_port_pair(), SSLConfig(), + context_)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); @@ -2319,7 +2341,7 @@ TEST_F(SSLClientSocketTest, ExportKeyingMaterial) { EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); rv = sock->Connect(callback.callback()); if (rv == ERR_IO_PENDING) @@ -2328,19 +2350,33 @@ TEST_F(SSLClientSocketTest, ExportKeyingMaterial) { EXPECT_TRUE(sock->IsConnected()); const int kKeyingMaterialSize = 32; - const char* kKeyingLabel1 = "client-socket-test-1"; - const char* kKeyingContext = ""; + const char kKeyingLabel1[] = "client-socket-test-1"; + const char kKeyingContext1[] = ""; unsigned char client_out1[kKeyingMaterialSize]; memset(client_out1, 0, sizeof(client_out1)); - rv = sock->ExportKeyingMaterial( - kKeyingLabel1, false, kKeyingContext, client_out1, sizeof(client_out1)); + rv = sock->ExportKeyingMaterial(kKeyingLabel1, false, kKeyingContext1, + client_out1, sizeof(client_out1)); EXPECT_EQ(rv, OK); - const char* kKeyingLabel2 = "client-socket-test-2"; + const char kKeyingLabel2[] = "client-socket-test-2"; unsigned char client_out2[kKeyingMaterialSize]; memset(client_out2, 0, sizeof(client_out2)); - rv = sock->ExportKeyingMaterial( - kKeyingLabel2, false, kKeyingContext, client_out2, sizeof(client_out2)); + rv = sock->ExportKeyingMaterial(kKeyingLabel2, false, kKeyingContext1, + client_out2, sizeof(client_out2)); + EXPECT_EQ(rv, OK); + EXPECT_NE(memcmp(client_out1, client_out2, kKeyingMaterialSize), 0); + + const char kKeyingContext2[] = "context"; + rv = sock->ExportKeyingMaterial(kKeyingLabel1, true, kKeyingContext2, + client_out2, sizeof(client_out2)); + EXPECT_EQ(rv, OK); + EXPECT_NE(memcmp(client_out1, client_out2, kKeyingMaterialSize), 0); + + // Using an empty context should give different key material from not using a + // context at all. + memset(client_out2, 0, sizeof(client_out2)); + rv = sock->ExportKeyingMaterial(kKeyingLabel1, true, kKeyingContext1, + client_out2, sizeof(client_out2)); EXPECT_EQ(rv, OK); EXPECT_NE(memcmp(client_out1, client_out2, kKeyingMaterialSize), 0); } @@ -2351,6 +2387,33 @@ TEST(SSLClientSocket, ClearSessionCache) { SSLClientSocket::ClearSessionCache(); } +TEST(SSLClientSocket, SerializeNextProtos) { + NextProtoVector next_protos; + next_protos.push_back(kProtoHTTP11); + next_protos.push_back(kProtoSPDY31); + static std::vector<uint8_t> serialized = + SSLClientSocket::SerializeNextProtos(next_protos, true); + ASSERT_EQ(18u, serialized.size()); + EXPECT_EQ(8, serialized[0]); // length("http/1.1") + EXPECT_EQ('h', serialized[1]); + EXPECT_EQ('t', serialized[2]); + EXPECT_EQ('t', serialized[3]); + EXPECT_EQ('p', serialized[4]); + EXPECT_EQ('/', serialized[5]); + EXPECT_EQ('1', serialized[6]); + EXPECT_EQ('.', serialized[7]); + EXPECT_EQ('1', serialized[8]); + EXPECT_EQ(8, serialized[9]); // length("spdy/3.1") + EXPECT_EQ('s', serialized[10]); + EXPECT_EQ('p', serialized[11]); + EXPECT_EQ('d', serialized[12]); + EXPECT_EQ('y', serialized[13]); + EXPECT_EQ('/', serialized[14]); + EXPECT_EQ('3', serialized[15]); + EXPECT_EQ('.', serialized[16]); + EXPECT_EQ('1', serialized[17]); +} + // Test that the server certificates are properly retrieved from the underlying // SSL stack. TEST_F(SSLClientSocketTest, VerifyServerChainProperlyOrdered) { @@ -2377,7 +2440,7 @@ TEST_F(SSLClientSocketTest, VerifyServerChainProperlyOrdered) { EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); rv = callback.GetResult(rv); @@ -2475,7 +2538,7 @@ TEST_F(SSLClientSocketTest, VerifyReturnChainProperlyOrdered) { ASSERT_TRUE(test_server.GetAddressList(&addr)); TestCompletionCallback callback; - CapturingNetLog log; + TestNetLog log; scoped_ptr<StreamSocket> transport( new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); @@ -2484,11 +2547,11 @@ TEST_F(SSLClientSocketTest, VerifyReturnChainProperlyOrdered) { EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - CapturingNetLog::CapturedEntryList entries; + TestNetLogEntry::List entries; log.GetEntries(&entries); EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); if (rv == ERR_IO_PENDING) @@ -2732,7 +2795,7 @@ TEST_F(SSLClientSocketTest, ReuseStates) { EXPECT_EQ(OK, rv); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); + transport.Pass(), test_server.host_port_pair(), SSLConfig())); rv = sock->Connect(callback.callback()); if (rv == ERR_IO_PENDING) @@ -2769,9 +2832,10 @@ TEST_F(SSLClientSocketTest, ReuseStates) { // attempt to read one byte extra. } -#if defined(USE_OPENSSL) - -TEST_F(SSLClientSocketTest, HandshakeCallbackIsRun_WithFailure) { +// Tests that IsConnectedAndIdle treats a socket as idle even if a Write hasn't +// been flushed completely out of SSLClientSocket's internal buffers. This is a +// regression test for https://crbug.com/466147. +TEST_F(SSLClientSocketTest, ReusableAfterWrite) { SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, SpawnedTestServer::kLocalhost, base::FilePath()); @@ -2783,178 +2847,388 @@ TEST_F(SSLClientSocketTest, HandshakeCallbackIsRun_WithFailure) { TestCompletionCallback callback; scoped_ptr<StreamSocket> real_transport( new TCPClientSocket(addr, NULL, NetLog::Source())); - scoped_ptr<SynchronousErrorStreamSocket> transport( - new SynchronousErrorStreamSocket(real_transport.Pass())); - int rv = callback.GetResult(transport->Connect(callback.callback())); - EXPECT_EQ(OK, rv); + scoped_ptr<FakeBlockingStreamSocket> transport( + new FakeBlockingStreamSocket(real_transport.Pass())); + FakeBlockingStreamSocket* raw_transport = transport.get(); + ASSERT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); - // Disable TLS False Start to avoid handshake non-determinism. - SSLConfig ssl_config; - ssl_config.false_start_enabled = false; + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), SSLConfig())); + ASSERT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); - SynchronousErrorStreamSocket* raw_transport = transport.get(); + // Block any application data from reaching the network. + raw_transport->BlockWrite(); + + // Write a partial HTTP request. + const char kRequestText[] = "GET / HTTP/1.0"; + const size_t kRequestLen = arraysize(kRequestText) - 1; + scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestLen)); + memcpy(request_buffer->data(), kRequestText, kRequestLen); + + // Although transport writes are blocked, both SSLClientSocketOpenSSL and + // SSLClientSocketNSS complete the outer Write operation. + EXPECT_EQ(static_cast<int>(kRequestLen), + callback.GetResult(sock->Write(request_buffer.get(), kRequestLen, + callback.callback()))); + + // The Write operation is complete, so the socket should be treated as + // reusable, in case the server returns an HTTP response before completely + // consuming the request body. In this case, we assume the server will + // properly drain the request body before trying to read the next request. + EXPECT_TRUE(sock->IsConnectedAndIdle()); +} + +// Tests that basic session resumption works. +TEST_F(SSLClientSocketTest, SessionResumption) { + SpawnedTestServer::SSLOptions ssl_options; + ASSERT_TRUE(StartTestServer(ssl_options)); + + // First, perform a full handshake. + SSLConfig ssl_config; + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr(), &log_, NetLog::Source())); + ASSERT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), ssl_config)); + transport.Pass(), test_server()->host_port_pair(), ssl_config)); + ASSERT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + SSLInfo ssl_info; + ASSERT_TRUE(sock->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_FULL, ssl_info.handshake_type); - sock->SetHandshakeCompletionCallback(base::Bind( - &SSLClientSocketTest::RecordCompletedHandshake, base::Unretained(this))); + // The next connection should resume. + transport.reset(new TCPClientSocket(addr(), &log_, NetLog::Source())); + ASSERT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + sock = CreateSSLClientSocket(transport.Pass(), + test_server()->host_port_pair(), ssl_config); + ASSERT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + ASSERT_TRUE(sock->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_RESUME, ssl_info.handshake_type); - raw_transport->SetNextWriteError(ERR_CONNECTION_RESET); + // Using a different HostPortPair uses a different session cache key. + transport.reset(new TCPClientSocket(addr(), &log_, NetLog::Source())); + ASSERT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + sock = CreateSSLClientSocket(transport.Pass(), + HostPortPair("example.com", 443), ssl_config); + ASSERT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + ASSERT_TRUE(sock->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_FULL, ssl_info.handshake_type); - rv = callback.GetResult(sock->Connect(callback.callback())); - EXPECT_EQ(ERR_CONNECTION_RESET, rv); - EXPECT_FALSE(sock->IsConnected()); + SSLClientSocket::ClearSessionCache(); - EXPECT_TRUE(ran_handshake_completion_callback_); + // After clearing the session cache, the next handshake doesn't resume. + transport.reset(new TCPClientSocket(addr(), &log_, NetLog::Source())); + ASSERT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + sock = CreateSSLClientSocket(transport.Pass(), + test_server()->host_port_pair(), ssl_config); + ASSERT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + ASSERT_TRUE(sock->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_FULL, ssl_info.handshake_type); } -// Tests that the completion callback is run when an SSL connection -// completes successfully. -TEST_F(SSLClientSocketTest, HandshakeCallbackIsRun_WithSuccess) { - SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, - SpawnedTestServer::kLocalhost, - base::FilePath()); - ASSERT_TRUE(test_server.Start()); +// Tests that connections with certificate errors do not add entries to the +// session cache. +TEST_F(SSLClientSocketTest, CertificateErrorNoResume) { + SpawnedTestServer::SSLOptions ssl_options; + ASSERT_TRUE(StartTestServer(ssl_options)); - AddressList addr; - ASSERT_TRUE(test_server.GetAddressList(&addr)); + cert_verifier_->set_default_result(ERR_CERT_COMMON_NAME_INVALID); + SSLConfig ssl_config; + TestCompletionCallback callback; scoped_ptr<StreamSocket> transport( - new TCPClientSocket(addr, NULL, NetLog::Source())); + new TCPClientSocket(addr(), &log_, NetLog::Source())); + ASSERT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server()->host_port_pair(), ssl_config)); + EXPECT_EQ(ERR_CERT_COMMON_NAME_INVALID, + callback.GetResult(sock->Connect(callback.callback()))); + + cert_verifier_->set_default_result(OK); + + // The next connection should perform a full handshake. + transport.reset(new TCPClientSocket(addr(), &log_, NetLog::Source())); + ASSERT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + sock = CreateSSLClientSocket(transport.Pass(), + test_server()->host_port_pair(), ssl_config); + ASSERT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + SSLInfo ssl_info; + ASSERT_TRUE(sock->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_FULL, ssl_info.handshake_type); +} - TestCompletionCallback callback; - int rv = transport->Connect(callback.callback()); - if (rv == ERR_IO_PENDING) - rv = callback.WaitForResult(); - EXPECT_EQ(OK, rv); +// Tests that session caches are sharded by max_version. +TEST_F(SSLClientSocketTest, FallbackShardSessionCache) { + SpawnedTestServer::SSLOptions ssl_options; + ASSERT_TRUE(StartTestServer(ssl_options)); - SSLConfig ssl_config = kDefaultSSLConfig; - ssl_config.false_start_enabled = false; + // Prepare a normal and fallback SSL config. + SSLConfig ssl_config; + SSLConfig fallback_ssl_config; + fallback_ssl_config.version_max = SSL_PROTOCOL_VERSION_TLS1; + fallback_ssl_config.version_fallback = true; + // Connect with a fallback config from the test server to add an entry to the + // session cache. + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr(), &log_, NetLog::Source())); + EXPECT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), ssl_config)); + transport.Pass(), test_server()->host_port_pair(), fallback_ssl_config)); + EXPECT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + SSLInfo ssl_info; + EXPECT_TRUE(sock->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_FULL, ssl_info.handshake_type); + EXPECT_EQ(SSL_CONNECTION_VERSION_TLS1, + SSLConnectionStatusToVersion(ssl_info.connection_status)); + + // A non-fallback connection needs a full handshake. + transport.reset(new TCPClientSocket(addr(), &log_, NetLog::Source())); + EXPECT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + sock = CreateSSLClientSocket(transport.Pass(), + test_server()->host_port_pair(), ssl_config); + EXPECT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + EXPECT_TRUE(sock->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_FULL, ssl_info.handshake_type); + // This does not check for equality because TLS 1.2 support is conditional on + // system NSS features. + EXPECT_LT(SSL_CONNECTION_VERSION_TLS1, + SSLConnectionStatusToVersion(ssl_info.connection_status)); + + // Note: if the server (correctly) declines to resume a TLS 1.0 session at TLS + // 1.2, the above test would not be sufficient to prove the session caches are + // sharded. Implementations vary here, so, to avoid being sensitive to this, + // attempt to resume with two more connections. + + // The non-fallback connection added a > TLS 1.0 entry to the session cache. + transport.reset(new TCPClientSocket(addr(), &log_, NetLog::Source())); + EXPECT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + sock = CreateSSLClientSocket(transport.Pass(), + test_server()->host_port_pair(), ssl_config); + EXPECT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + EXPECT_TRUE(sock->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_RESUME, ssl_info.handshake_type); + // This does not check for equality because TLS 1.2 support is conditional on + // system NSS features. + EXPECT_LT(SSL_CONNECTION_VERSION_TLS1, + SSLConnectionStatusToVersion(ssl_info.connection_status)); + + // The fallback connection still resumes from its session cache. It cannot + // offer the > TLS 1.0 session, so this must have been the session from the + // first fallback connection. + transport.reset(new TCPClientSocket(addr(), &log_, NetLog::Source())); + EXPECT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + sock = CreateSSLClientSocket( + transport.Pass(), test_server()->host_port_pair(), fallback_ssl_config); + EXPECT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + EXPECT_TRUE(sock->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_RESUME, ssl_info.handshake_type); + EXPECT_EQ(SSL_CONNECTION_VERSION_TLS1, + SSLConnectionStatusToVersion(ssl_info.connection_status)); +} - sock->SetHandshakeCompletionCallback(base::Bind( - &SSLClientSocketTest::RecordCompletedHandshake, base::Unretained(this))); +// Test that RC4 is only enabled if enable_deprecated_cipher_suites is set. +TEST_F(SSLClientSocketTest, DeprecatedRC4) { + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.bulk_ciphers = SpawnedTestServer::SSLOptions::BULK_CIPHER_RC4; + ASSERT_TRUE(StartTestServer(ssl_options)); - rv = callback.GetResult(sock->Connect(callback.callback())); + // Normal handshakes with RC4 do not work. + SSLConfig ssl_config; + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr(), &log_, NetLog::Source())); + ASSERT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server()->host_port_pair(), ssl_config)); + ASSERT_EQ(ERR_SSL_VERSION_OR_CIPHER_MISMATCH, + callback.GetResult(sock->Connect(callback.callback()))); - EXPECT_EQ(OK, rv); - EXPECT_TRUE(sock->IsConnected()); - EXPECT_TRUE(ran_handshake_completion_callback_); + // Enabling deprecated ciphers works fine. + ssl_config.enable_deprecated_cipher_suites = true; + transport.reset(new TCPClientSocket(addr(), &log_, NetLog::Source())); + ASSERT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + sock = CreateSSLClientSocket(transport.Pass(), + test_server()->host_port_pair(), ssl_config); + ASSERT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); } -// Tests that the completion callback is run with a server that doesn't cache -// sessions. -TEST_F(SSLClientSocketTest, HandshakeCallbackIsRun_WithDisabledSessionCache) { +// Tests that enabling deprecated ciphers shards the session cache. +TEST_F(SSLClientSocketTest, DeprecatedShardSessionCache) { SpawnedTestServer::SSLOptions ssl_options; - ssl_options.disable_session_cache = true; - SpawnedTestServer test_server( - SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); - ASSERT_TRUE(test_server.Start()); + ASSERT_TRUE(StartTestServer(ssl_options)); - AddressList addr; - ASSERT_TRUE(test_server.GetAddressList(&addr)); + // Prepare a normal and deprecated SSL config. + SSLConfig ssl_config; + SSLConfig deprecated_ssl_config; + deprecated_ssl_config.enable_deprecated_cipher_suites = true; + // Connect with deprecated ciphers enabled to warm the session cache cache. + TestCompletionCallback callback; scoped_ptr<StreamSocket> transport( - new TCPClientSocket(addr, NULL, NetLog::Source())); + new TCPClientSocket(addr(), &log_, NetLog::Source())); + EXPECT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.Pass(), test_server()->host_port_pair(), + deprecated_ssl_config)); + EXPECT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + SSLInfo ssl_info; + EXPECT_TRUE(sock->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_FULL, ssl_info.handshake_type); - TestCompletionCallback callback; - int rv = transport->Connect(callback.callback()); - if (rv == ERR_IO_PENDING) - rv = callback.WaitForResult(); - EXPECT_EQ(OK, rv); + // Test that re-connecting with deprecated ciphers enabled still resumes. + transport.reset(new TCPClientSocket(addr(), &log_, NetLog::Source())); + EXPECT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + sock = CreateSSLClientSocket( + transport.Pass(), test_server()->host_port_pair(), deprecated_ssl_config); + EXPECT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + EXPECT_TRUE(sock->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_RESUME, ssl_info.handshake_type); - SSLConfig ssl_config = kDefaultSSLConfig; - ssl_config.false_start_enabled = false; + // However, a normal connection needs a full handshake. + transport.reset(new TCPClientSocket(addr(), &log_, NetLog::Source())); + EXPECT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + sock = CreateSSLClientSocket(transport.Pass(), + test_server()->host_port_pair(), ssl_config); + EXPECT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + EXPECT_TRUE(sock->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_FULL, ssl_info.handshake_type); - scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( - transport.Pass(), test_server.host_port_pair(), ssl_config)); + // Clear the session cache for the inverse test. + SSLClientSocket::ClearSessionCache(); - sock->SetHandshakeCompletionCallback(base::Bind( - &SSLClientSocketTest::RecordCompletedHandshake, base::Unretained(this))); + // Now make a normal connection to prime the session cache. + transport.reset(new TCPClientSocket(addr(), &log_, NetLog::Source())); + EXPECT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + sock = CreateSSLClientSocket(transport.Pass(), + test_server()->host_port_pair(), ssl_config); + EXPECT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + EXPECT_TRUE(sock->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_FULL, ssl_info.handshake_type); - rv = callback.GetResult(sock->Connect(callback.callback())); + // A normal connection should be able to resume. + transport.reset(new TCPClientSocket(addr(), &log_, NetLog::Source())); + EXPECT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + sock = CreateSSLClientSocket(transport.Pass(), + test_server()->host_port_pair(), ssl_config); + EXPECT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + EXPECT_TRUE(sock->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_RESUME, ssl_info.handshake_type); - EXPECT_EQ(OK, rv); - EXPECT_TRUE(sock->IsConnected()); - EXPECT_TRUE(ran_handshake_completion_callback_); + // However, enabling deprecated ciphers connects fresh. + transport.reset(new TCPClientSocket(addr(), &log_, NetLog::Source())); + EXPECT_EQ(OK, callback.GetResult(transport->Connect(callback.callback()))); + sock = CreateSSLClientSocket( + transport.Pass(), test_server()->host_port_pair(), deprecated_ssl_config); + EXPECT_EQ(OK, callback.GetResult(sock->Connect(callback.callback()))); + EXPECT_TRUE(sock->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_FULL, ssl_info.handshake_type); } -TEST_F(SSLClientSocketFalseStartTest, - HandshakeCallbackIsRun_WithFalseStartFailure) { - // False Start requires NPN and a forward-secret cipher suite. +TEST_F(SSLClientSocketFalseStartTest, FalseStartEnabled) { + if (!SupportsAESGCM()) { + LOG(WARNING) << "Skipping test because AES-GCM is not supported."; + return; + } + + // False Start requires NPN/ALPN, ECDHE, and an AEAD. SpawnedTestServer::SSLOptions server_options; server_options.key_exchanges = - SpawnedTestServer::SSLOptions::KEY_EXCHANGE_DHE_RSA; + SpawnedTestServer::SSLOptions::KEY_EXCHANGE_ECDHE_RSA; + server_options.bulk_ciphers = + SpawnedTestServer::SSLOptions::BULK_CIPHER_AES128GCM; server_options.enable_npn = true; SSLConfig client_config; - client_config.next_protos.push_back("http/1.1"); - monitor_handshake_callback_ = true; - fail_handshake_after_false_start_ = true; - ASSERT_NO_FATAL_FAILURE(TestFalseStart(server_options, client_config, true)); - ASSERT_TRUE(ran_handshake_completion_callback_); + client_config.next_protos.push_back(kProtoHTTP11); + ASSERT_NO_FATAL_FAILURE( + TestFalseStart(server_options, client_config, true)); } -TEST_F(SSLClientSocketFalseStartTest, - HandshakeCallbackIsRun_WithFalseStartSuccess) { - // False Start requires NPN and a forward-secret cipher suite. +// Test that False Start is disabled without NPN. +TEST_F(SSLClientSocketFalseStartTest, NoNPN) { + if (!SupportsAESGCM()) { + LOG(WARNING) << "Skipping test because AES-GCM is not supported."; + return; + } + SpawnedTestServer::SSLOptions server_options; server_options.key_exchanges = - SpawnedTestServer::SSLOptions::KEY_EXCHANGE_DHE_RSA; - server_options.enable_npn = true; + SpawnedTestServer::SSLOptions::KEY_EXCHANGE_ECDHE_RSA; + server_options.bulk_ciphers = + SpawnedTestServer::SSLOptions::BULK_CIPHER_AES128GCM; SSLConfig client_config; - client_config.next_protos.push_back("http/1.1"); - monitor_handshake_callback_ = true; - ASSERT_NO_FATAL_FAILURE(TestFalseStart(server_options, client_config, true)); - ASSERT_TRUE(ran_handshake_completion_callback_); + client_config.next_protos.clear(); + ASSERT_NO_FATAL_FAILURE( + TestFalseStart(server_options, client_config, false)); } -#endif // defined(USE_OPENSSL) -TEST_F(SSLClientSocketFalseStartTest, FalseStartEnabled) { - // False Start requires NPN and a forward-secret cipher suite. +// Test that False Start is disabled with plain RSA ciphers. +TEST_F(SSLClientSocketFalseStartTest, RSA) { + if (!SupportsAESGCM()) { + LOG(WARNING) << "Skipping test because AES-GCM is not supported."; + return; + } + SpawnedTestServer::SSLOptions server_options; server_options.key_exchanges = - SpawnedTestServer::SSLOptions::KEY_EXCHANGE_DHE_RSA; + SpawnedTestServer::SSLOptions::KEY_EXCHANGE_RSA; + server_options.bulk_ciphers = + SpawnedTestServer::SSLOptions::BULK_CIPHER_AES128GCM; server_options.enable_npn = true; SSLConfig client_config; - client_config.next_protos.push_back("http/1.1"); + client_config.next_protos.push_back(kProtoHTTP11); ASSERT_NO_FATAL_FAILURE( - TestFalseStart(server_options, client_config, true)); + TestFalseStart(server_options, client_config, false)); } -// Test that False Start is disabled without NPN. -TEST_F(SSLClientSocketFalseStartTest, NoNPN) { +// Test that False Start is disabled with DHE_RSA ciphers. +TEST_F(SSLClientSocketFalseStartTest, DHE_RSA) { + if (!SupportsAESGCM()) { + LOG(WARNING) << "Skipping test because AES-GCM is not supported."; + return; + } + SpawnedTestServer::SSLOptions server_options; server_options.key_exchanges = SpawnedTestServer::SSLOptions::KEY_EXCHANGE_DHE_RSA; + server_options.bulk_ciphers = + SpawnedTestServer::SSLOptions::BULK_CIPHER_AES128GCM; + server_options.enable_npn = true; SSLConfig client_config; - client_config.next_protos.clear(); - ASSERT_NO_FATAL_FAILURE( - TestFalseStart(server_options, client_config, false)); + client_config.next_protos.push_back(kProtoHTTP11); + ASSERT_NO_FATAL_FAILURE(TestFalseStart(server_options, client_config, false)); } -// Test that False Start is disabled without a forward-secret cipher suite. -TEST_F(SSLClientSocketFalseStartTest, NoForwardSecrecy) { +// Test that False Start is disabled without an AEAD. +TEST_F(SSLClientSocketFalseStartTest, NoAEAD) { SpawnedTestServer::SSLOptions server_options; server_options.key_exchanges = - SpawnedTestServer::SSLOptions::KEY_EXCHANGE_RSA; + SpawnedTestServer::SSLOptions::KEY_EXCHANGE_ECDHE_RSA; + server_options.bulk_ciphers = + SpawnedTestServer::SSLOptions::BULK_CIPHER_AES128; server_options.enable_npn = true; SSLConfig client_config; - client_config.next_protos.push_back("http/1.1"); - ASSERT_NO_FATAL_FAILURE( - TestFalseStart(server_options, client_config, false)); + client_config.next_protos.push_back(kProtoHTTP11); + ASSERT_NO_FATAL_FAILURE(TestFalseStart(server_options, client_config, false)); } // Test that sessions are resumable after receiving the server Finished message. TEST_F(SSLClientSocketFalseStartTest, SessionResumption) { + if (!SupportsAESGCM()) { + LOG(WARNING) << "Skipping test because AES-GCM is not supported."; + return; + } + // Start a server. SpawnedTestServer::SSLOptions server_options; server_options.key_exchanges = - SpawnedTestServer::SSLOptions::KEY_EXCHANGE_DHE_RSA; + SpawnedTestServer::SSLOptions::KEY_EXCHANGE_ECDHE_RSA; + server_options.bulk_ciphers = + SpawnedTestServer::SSLOptions::BULK_CIPHER_AES128GCM; server_options.enable_npn = true; SSLConfig client_config; - client_config.next_protos.push_back("http/1.1"); + client_config.next_protos.push_back(kProtoHTTP11); // Let a full handshake complete with False Start. ASSERT_NO_FATAL_FAILURE( @@ -2976,22 +3250,29 @@ TEST_F(SSLClientSocketFalseStartTest, SessionResumption) { EXPECT_EQ(SSLInfo::HANDSHAKE_RESUME, ssl_info.handshake_type); } -// Test that sessions are not resumable before receiving the server Finished -// message. -TEST_F(SSLClientSocketFalseStartTest, NoSessionResumptionBeforeFinish) { +// Test that False Started sessions are not resumable before receiving the +// server Finished message. +TEST_F(SSLClientSocketFalseStartTest, NoSessionResumptionBeforeFinished) { + if (!SupportsAESGCM()) { + LOG(WARNING) << "Skipping test because AES-GCM is not supported."; + return; + } + // Start a server. SpawnedTestServer::SSLOptions server_options; server_options.key_exchanges = - SpawnedTestServer::SSLOptions::KEY_EXCHANGE_DHE_RSA; + SpawnedTestServer::SSLOptions::KEY_EXCHANGE_ECDHE_RSA; + server_options.bulk_ciphers = + SpawnedTestServer::SSLOptions::BULK_CIPHER_AES128GCM; server_options.enable_npn = true; ASSERT_TRUE(StartTestServer(server_options)); SSLConfig client_config; - client_config.next_protos.push_back("http/1.1"); + client_config.next_protos.push_back(kProtoHTTP11); // Start a handshake up to the server Finished message. TestCompletionCallback callback; - FakeBlockingStreamSocket* raw_transport1; + FakeBlockingStreamSocket* raw_transport1 = NULL; scoped_ptr<SSLClientSocket> sock1; ASSERT_NO_FATAL_FAILURE(CreateAndConnectUntilServerFinishedReceived( client_config, &callback, &raw_transport1, &sock1)); @@ -2999,6 +3280,92 @@ TEST_F(SSLClientSocketFalseStartTest, NoSessionResumptionBeforeFinish) { // still completes. EXPECT_EQ(OK, callback.WaitForResult()); + // Continue to block the client (|sock1|) from processing the Finished + // message, but allow it to arrive on the socket. This ensures that, from the + // server's point of view, it has completed the handshake and added the + // session to its session cache. + // + // The actual read on |sock1| will not complete until the Finished message is + // processed; however, pump the underlying transport so that it is read from + // the socket. NOTE: This may flakily pass if the server's final flight + // doesn't come in one Read. + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); + int rv = sock1->Read(buf.get(), 4096, callback.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + raw_transport1->WaitForReadResult(); + + // Drop the old socket. This is needed because the Python test server can't + // service two sockets in parallel. + sock1.reset(); + + // Start a second connection. + scoped_ptr<StreamSocket> transport2( + new TCPClientSocket(addr(), &log_, NetLog::Source())); + EXPECT_EQ(OK, callback.GetResult(transport2->Connect(callback.callback()))); + scoped_ptr<SSLClientSocket> sock2 = CreateSSLClientSocket( + transport2.Pass(), test_server()->host_port_pair(), client_config); + EXPECT_EQ(OK, callback.GetResult(sock2->Connect(callback.callback()))); + + // No session resumption because the first connection never received a server + // Finished message. + SSLInfo ssl_info; + EXPECT_TRUE(sock2->GetSSLInfo(&ssl_info)); + EXPECT_EQ(SSLInfo::HANDSHAKE_FULL, ssl_info.handshake_type); +} + +// Test that False Started sessions are not resumable if the server Finished +// message was bad. +TEST_F(SSLClientSocketFalseStartTest, NoSessionResumptionBadFinished) { + if (!SupportsAESGCM()) { + LOG(WARNING) << "Skipping test because AES-GCM is not supported."; + return; + } + + // Start a server. + SpawnedTestServer::SSLOptions server_options; + server_options.key_exchanges = + SpawnedTestServer::SSLOptions::KEY_EXCHANGE_ECDHE_RSA; + server_options.bulk_ciphers = + SpawnedTestServer::SSLOptions::BULK_CIPHER_AES128GCM; + server_options.enable_npn = true; + ASSERT_TRUE(StartTestServer(server_options)); + + SSLConfig client_config; + client_config.next_protos.push_back(kProtoHTTP11); + + // Start a handshake up to the server Finished message. + TestCompletionCallback callback; + FakeBlockingStreamSocket* raw_transport1 = NULL; + scoped_ptr<SSLClientSocket> sock1; + ASSERT_NO_FATAL_FAILURE(CreateAndConnectUntilServerFinishedReceived( + client_config, &callback, &raw_transport1, &sock1)); + // Although raw_transport1 has the server Finished blocked, the handshake + // still completes. + EXPECT_EQ(OK, callback.WaitForResult()); + + // Continue to block the client (|sock1|) from processing the Finished + // message, but allow it to arrive on the socket. This ensures that, from the + // server's point of view, it has completed the handshake and added the + // session to its session cache. + // + // The actual read on |sock1| will not complete until the Finished message is + // processed; however, pump the underlying transport so that it is read from + // the socket. + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); + int rv = sock1->Read(buf.get(), 4096, callback.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + raw_transport1->WaitForReadResult(); + + // The server's second leg, or part of it, is now received but not yet sent to + // |sock1|. Before doing so, break the server's second leg. + int bytes_read = raw_transport1->pending_read_result(); + ASSERT_LT(0, bytes_read); + raw_transport1->pending_read_buf()->data()[bytes_read - 1]++; + + // Unblock the Finished message. |sock1->Read| should now fail. + raw_transport1->UnblockReadResult(); + EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, callback.GetResult(rv)); + // Drop the old socket. This is needed because the Python test server can't // service two sockets in parallel. sock1.reset(); @@ -3025,7 +3392,7 @@ TEST_F(SSLClientSocketChannelIDTest, SendChannelID) { ASSERT_TRUE(ConnectToTestServer(ssl_options)); EnableChannelID(); - SSLConfig ssl_config = kDefaultSSLConfig; + SSLConfig ssl_config; ssl_config.channel_id_enabled = true; int rv; @@ -3033,7 +3400,9 @@ TEST_F(SSLClientSocketChannelIDTest, SendChannelID) { EXPECT_EQ(OK, rv); EXPECT_TRUE(sock_->IsConnected()); - EXPECT_TRUE(sock_->WasChannelIDSent()); + SSLInfo ssl_info; + ASSERT_TRUE(sock_->GetSSLInfo(&ssl_info)); + EXPECT_TRUE(ssl_info.channel_id_sent); sock_->Disconnect(); EXPECT_FALSE(sock_->IsConnected()); @@ -3047,7 +3416,7 @@ TEST_F(SSLClientSocketChannelIDTest, FailingChannelID) { ASSERT_TRUE(ConnectToTestServer(ssl_options)); EnableFailingChannelID(); - SSLConfig ssl_config = kDefaultSSLConfig; + SSLConfig ssl_config; ssl_config.channel_id_enabled = true; int rv; @@ -3069,7 +3438,7 @@ TEST_F(SSLClientSocketChannelIDTest, FailingChannelIDAsync) { ASSERT_TRUE(ConnectToTestServer(ssl_options)); EnableAsyncFailingChannelID(); - SSLConfig ssl_config = kDefaultSSLConfig; + SSLConfig ssl_config; ssl_config.channel_id_enabled = true; int rv; diff --git a/chromium/net/socket/ssl_server_socket_nss.cc b/chromium/net/socket/ssl_server_socket_nss.cc index 7fa5835b430..6f505affdd7 100644 --- a/chromium/net/socket/ssl_server_socket_nss.cc +++ b/chromium/net/socket/ssl_server_socket_nss.cc @@ -32,11 +32,11 @@ #include "base/callback_helpers.h" #include "base/lazy_instance.h" #include "base/memory/ref_counted.h" -#include "crypto/rsa_private_key.h" #include "crypto/nss_util_internal.h" +#include "crypto/rsa_private_key.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" +#include "net/log/net_log.h" #include "net/socket/nss_ssl_util.h" // SSL plaintext fragments are shorter than 16KB. Although the record layer @@ -307,6 +307,10 @@ bool SSLServerSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { return false; } +void SSLServerSocketNSS::GetConnectionAttempts(ConnectionAttempts* out) const { + out->clear(); +} + int SSLServerSocketNSS::InitializeSSLOptions() { // Transport connected, now hook it up to nss nss_fd_ = memio_CreateIOLayer(kRecvBufferSize, kSendBufferSize); @@ -349,7 +353,7 @@ int SSLServerSocketNSS::InitializeSSLOptions() { return ERR_NO_SSL_VERSIONS_ENABLED; } - if (ssl_config_.require_forward_secrecy) { + if (ssl_config_.require_ecdhe) { const PRUint16* const ssl_ciphers = SSL_GetImplementedCiphers(); const PRUint16 num_ciphers = SSL_GetNumImplementedCiphers(); diff --git a/chromium/net/socket/ssl_server_socket_nss.h b/chromium/net/socket/ssl_server_socket_nss.h index d40b096577c..d1bcec69067 100644 --- a/chromium/net/socket/ssl_server_socket_nss.h +++ b/chromium/net/socket/ssl_server_socket_nss.h @@ -13,8 +13,8 @@ #include "base/memory/scoped_ptr.h" #include "net/base/completion_callback.h" #include "net/base/host_port_pair.h" -#include "net/base/net_log.h" #include "net/base/nss_memio.h" +#include "net/log/net_log.h" #include "net/socket/ssl_server_socket.h" #include "net/ssl/ssl_config_service.h" @@ -66,6 +66,9 @@ class SSLServerSocketNSS : public SSLServerSocket { bool WasNpnNegotiated() const override; NextProto GetNegotiatedProtocol() const override; bool GetSSLInfo(SSLInfo* ssl_info) override; + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} private: enum State { diff --git a/chromium/net/socket/ssl_server_socket_openssl.cc b/chromium/net/socket/ssl_server_socket_openssl.cc index 29d1ffa0508..5a61eeb3063 100644 --- a/chromium/net/socket/ssl_server_socket_openssl.cc +++ b/chromium/net/socket/ssl_server_socket_openssl.cc @@ -9,11 +9,13 @@ #include "base/callback_helpers.h" #include "base/logging.h" +#include "base/strings/string_util.h" #include "crypto/openssl_util.h" #include "crypto/rsa_private_key.h" #include "crypto/scoped_openssl_types.h" #include "net/base/net_errors.h" #include "net/ssl/openssl_ssl_util.h" +#include "net/ssl/scoped_openssl_types.h" #define GotoState(s) next_handshake_state_ = s @@ -246,6 +248,11 @@ bool SSLServerSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { return false; } +void SSLServerSocketOpenSSL::GetConnectionAttempts( + ConnectionAttempts* out) const { + out->clear(); +} + void SSLServerSocketOpenSSL::OnSendComplete(int result) { if (next_handshake_state_ == STATE_HANDSHAKE) { // In handshake phase. @@ -606,9 +613,7 @@ int SSLServerSocketOpenSSL::Init() { crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); - crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free>::Type ssl_ctx( - // It support SSLv2, SSLv3, and TLSv1. - SSL_CTX_new(SSLv23_server_method())); + ScopedSSL_CTX ssl_ctx(SSL_CTX_new(SSLv23_server_method())); ssl_ = SSL_new(ssl_ctx.get()); if (!ssl_) return ERR_UNEXPECTED; @@ -638,8 +643,7 @@ int SSLServerSocketOpenSSL::Init() { const unsigned char* der_string_array = reinterpret_cast<const unsigned char*>(der_string.data()); - crypto::ScopedOpenSSL<X509, X509_free>::Type x509( - d2i_X509(NULL, &der_string_array, der_string.length())); + ScopedX509 x509(d2i_X509(NULL, &der_string_array, der_string.length())); if (!x509.get()) return ERR_UNEXPECTED; @@ -656,24 +660,14 @@ int SSLServerSocketOpenSSL::Init() { return ERR_UNEXPECTED; } + DCHECK_LT(SSL3_VERSION, ssl_config_.version_min); + DCHECK_LT(SSL3_VERSION, ssl_config_.version_max); + SSL_set_min_version(ssl_, ssl_config_.version_min); + SSL_set_max_version(ssl_, ssl_config_.version_max); + // OpenSSL defaults some options to on, others to off. To avoid ambiguity, // set everything we care about to an absolute value. SslSetClearMask options; - options.ConfigureFlag(SSL_OP_NO_SSLv2, true); - bool ssl3_enabled = (ssl_config_.version_min == SSL_PROTOCOL_VERSION_SSL3); - options.ConfigureFlag(SSL_OP_NO_SSLv3, !ssl3_enabled); - bool tls1_enabled = (ssl_config_.version_min <= SSL_PROTOCOL_VERSION_TLS1 && - ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1); - options.ConfigureFlag(SSL_OP_NO_TLSv1, !tls1_enabled); - bool tls1_1_enabled = - (ssl_config_.version_min <= SSL_PROTOCOL_VERSION_TLS1_1 && - ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1_1); - options.ConfigureFlag(SSL_OP_NO_TLSv1_1, !tls1_1_enabled); - bool tls1_2_enabled = - (ssl_config_.version_min <= SSL_PROTOCOL_VERSION_TLS1_2 && - ssl_config_.version_max >= SSL_PROTOCOL_VERSION_TLS1_2); - options.ConfigureFlag(SSL_OP_NO_TLSv1_2, !tls1_2_enabled); - options.ConfigureFlag(SSL_OP_NO_COMPRESSION, true); SSL_set_options(ssl_, options.set_mask); @@ -687,6 +681,48 @@ int SSLServerSocketOpenSSL::Init() { SSL_set_mode(ssl_, mode.set_mask); SSL_clear_mode(ssl_, mode.clear_mask); + // Removing ciphers by ID from OpenSSL is a bit involved as we must use the + // textual name with SSL_set_cipher_list because there is no public API to + // directly remove a cipher by ID. + STACK_OF(SSL_CIPHER)* ciphers = SSL_get_ciphers(ssl_); + DCHECK(ciphers); + // See SSLConfig::disabled_cipher_suites for description of the suites + // disabled by default. Note that !SHA256 and !SHA384 only remove HMAC-SHA256 + // and HMAC-SHA384 cipher suites, not GCM cipher suites with SHA256 or SHA384 + // as the handshake hash. + std::string command("DEFAULT:!SHA256:!SHA384:!AESGCM+AES256:!aPSK"); + // Walk through all the installed ciphers, seeing if any need to be + // appended to the cipher removal |command|. + for (size_t i = 0; i < sk_SSL_CIPHER_num(ciphers); ++i) { + const SSL_CIPHER* cipher = sk_SSL_CIPHER_value(ciphers, i); + const uint16_t id = static_cast<uint16_t>(SSL_CIPHER_get_id(cipher)); + + bool disable = false; + if (ssl_config_.require_ecdhe) { + base::StringPiece kx_name(SSL_CIPHER_get_kx_name(cipher)); + disable = kx_name != "ECDHE_RSA" && kx_name != "ECDHE_ECDSA"; + } + if (!disable) { + disable = std::find(ssl_config_.disabled_cipher_suites.begin(), + ssl_config_.disabled_cipher_suites.end(), + id) != ssl_config_.disabled_cipher_suites.end(); + } + if (disable) { + const char* name = SSL_CIPHER_get_name(cipher); + DVLOG(3) << "Found cipher to remove: '" << name << "', ID: " << id + << " strength: " << SSL_CIPHER_get_bits(cipher, NULL); + command.append(":!"); + command.append(name); + } + } + + int rv = SSL_set_cipher_list(ssl_, command.c_str()); + // If this fails (rv = 0) it means there are no ciphers enabled on this SSL. + // This will almost certainly result in the socket failing to complete the + // handshake at which point the appropriate error is bubbled up to the client. + LOG_IF(WARNING, rv != 1) << "SSL_set_cipher_list('" << command + << "') returned " << rv; + return OK; } diff --git a/chromium/net/socket/ssl_server_socket_openssl.h b/chromium/net/socket/ssl_server_socket_openssl.h index c58bd569352..34e8bb001fc 100644 --- a/chromium/net/socket/ssl_server_socket_openssl.h +++ b/chromium/net/socket/ssl_server_socket_openssl.h @@ -8,7 +8,7 @@ #include "base/memory/scoped_ptr.h" #include "net/base/completion_callback.h" #include "net/base/io_buffer.h" -#include "net/base/net_log.h" +#include "net/log/net_log.h" #include "net/socket/ssl_server_socket.h" #include "net/ssl/ssl_config_service.h" @@ -68,6 +68,9 @@ class SSLServerSocketOpenSSL : public SSLServerSocket { bool WasNpnNegotiated() const override; NextProto GetNegotiatedProtocol() const override; bool GetSSLInfo(SSLInfo* ssl_info) override; + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} private: enum State { diff --git a/chromium/net/socket/ssl_server_socket_unittest.cc b/chromium/net/socket/ssl_server_socket_unittest.cc index 5fabdd2e90e..84663eb790f 100644 --- a/chromium/net/socket/ssl_server_socket_unittest.cc +++ b/chromium/net/socket/ssl_server_socket_unittest.cc @@ -31,22 +31,35 @@ #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" #include "net/base/test_data_directory.h" #include "net/cert/cert_status_flags.h" #include "net/cert/mock_cert_verifier.h" #include "net/cert/x509_certificate.h" #include "net/http/transport_security_state.h" +#include "net/log/net_log.h" #include "net/socket/client_socket_factory.h" #include "net/socket/socket_test_util.h" #include "net/socket/ssl_client_socket.h" #include "net/socket/stream_socket.h" +#include "net/ssl/ssl_cipher_suite_names.h" #include "net/ssl/ssl_config_service.h" +#include "net/ssl/ssl_connection_status_flags.h" #include "net/ssl/ssl_info.h" #include "net/test/cert_test_util.h" #include "testing/gtest/include/gtest/gtest.h" #include "testing/platform_test.h" +#if !defined(USE_OPENSSL) +#include <pk11pub.h> + +#if !defined(CKM_AES_GCM) +#define CKM_AES_GCM 0x00001087 +#endif +#if !defined(CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256) +#define CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256 (CKM_NSS + 24) +#endif +#endif + namespace net { namespace { @@ -226,6 +239,14 @@ class FakeSocket : public StreamSocket { bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } + void GetConnectionAttempts(ConnectionAttempts* out) const override { + out->clear(); + } + + void ClearConnectionAttempts() override {} + + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} + private: BoundNetLog net_log_; FakeDataChannel* incoming_; @@ -313,30 +334,30 @@ class SSLServerSocketTest : public PlatformTest { scoped_ptr<crypto::RSAPrivateKey> private_key( crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); - SSLConfig ssl_config; - ssl_config.false_start_enabled = false; - ssl_config.channel_id_enabled = false; + client_ssl_config_.false_start_enabled = false; + client_ssl_config_.channel_id_enabled = false; // Certificate provided by the host doesn't need authority. SSLConfig::CertAndStatus cert_and_status; cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID; cert_and_status.der_cert = cert_der; - ssl_config.allowed_bad_certs.push_back(cert_and_status); + client_ssl_config_.allowed_bad_certs.push_back(cert_and_status); HostPortPair host_and_pair("unittest", 0); SSLClientSocketContext context; context.cert_verifier = cert_verifier_.get(); context.transport_security_state = transport_security_state_.get(); - client_socket_ = - socket_factory_->CreateSSLClientSocket( - client_connection.Pass(), host_and_pair, ssl_config, context); - server_socket_ = CreateSSLServerSocket( - server_socket.Pass(), - cert.get(), private_key.get(), SSLConfig()); + client_socket_ = socket_factory_->CreateSSLClientSocket( + client_connection.Pass(), host_and_pair, client_ssl_config_, context); + server_socket_ = + CreateSSLServerSocket(server_socket.Pass(), cert.get(), + private_key.get(), server_ssl_config_); } FakeDataChannel channel_1_; FakeDataChannel channel_2_; + SSLConfig client_ssl_config_; + SSLConfig server_ssl_config_; scoped_ptr<SSLClientSocket> client_socket_; scoped_ptr<SSLServerSocket> server_socket_; ClientSocketFactory* socket_factory_; @@ -375,8 +396,27 @@ TEST_F(SSLServerSocketTest, Handshake) { // Make sure the cert status is expected. SSLInfo ssl_info; - client_socket_->GetSSLInfo(&ssl_info); + ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info)); EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status); + + // The default cipher suite should be ECDHE and, unless on NSS and the + // platform doesn't support it, an AEAD. + uint16_t cipher_suite = + SSLConnectionStatusToCipherSuite(ssl_info.connection_status); + const char* key_exchange; + const char* cipher; + const char* mac; + bool is_aead; + SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, cipher_suite); + EXPECT_STREQ("ECDHE_RSA", key_exchange); +#if defined(USE_OPENSSL) + bool supports_aead = true; +#else + bool supports_aead = + PK11_TokenExists(CKM_AES_GCM) && + PK11_TokenExists(CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256); +#endif + EXPECT_TRUE(!supports_aead || is_aead); } TEST_F(SSLServerSocketTest, DataTransfer) { @@ -535,8 +575,8 @@ TEST_F(SSLServerSocketTest, ExportKeyingMaterial) { } const int kKeyingMaterialSize = 32; - const char* kKeyingLabel = "EXPERIMENTAL-server-socket-test"; - const char* kKeyingContext = ""; + const char kKeyingLabel[] = "EXPERIMENTAL-server-socket-test"; + const char kKeyingContext[] = ""; unsigned char server_out[kKeyingMaterialSize]; int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel, false, kKeyingContext, @@ -550,7 +590,7 @@ TEST_F(SSLServerSocketTest, ExportKeyingMaterial) { ASSERT_EQ(OK, rv); EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out))); - const char* kKeyingLabelBad = "EXPERIMENTAL-server-socket-test-bad"; + const char kKeyingLabelBad[] = "EXPERIMENTAL-server-socket-test-bad"; unsigned char client_bad[kKeyingMaterialSize]; rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad, false, kKeyingContext, @@ -559,4 +599,40 @@ TEST_F(SSLServerSocketTest, ExportKeyingMaterial) { EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out))); } +// Verifies that SSLConfig::require_ecdhe flags works properly. +TEST_F(SSLServerSocketTest, RequireEcdheFlag) { + // Disable all ECDHE suites on the client side. + uint16_t kEcdheCiphers[] = { + 0xc007, // ECDHE_ECDSA_WITH_RC4_128_SHA + 0xc009, // ECDHE_ECDSA_WITH_AES_128_CBC_SHA + 0xc00a, // ECDHE_ECDSA_WITH_AES_256_CBC_SHA + 0xc011, // ECDHE_RSA_WITH_RC4_128_SHA + 0xc013, // ECDHE_RSA_WITH_AES_128_CBC_SHA + 0xc014, // ECDHE_RSA_WITH_AES_256_CBC_SHA + 0xc02b, // ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 + 0xc02f, // ECDHE_RSA_WITH_AES_128_GCM_SHA256 + 0xcc13, // ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + 0xcc14, // ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 + }; + client_ssl_config_.disabled_cipher_suites.assign( + kEcdheCiphers, kEcdheCiphers + arraysize(kEcdheCiphers)); + + // Require ECDHE on the server. + server_ssl_config_.require_ecdhe = true; + + Initialize(); + + TestCompletionCallback connect_callback; + TestCompletionCallback handshake_callback; + + int client_ret = client_socket_->Connect(connect_callback.callback()); + int server_ret = server_socket_->Handshake(handshake_callback.callback()); + + client_ret = connect_callback.GetResult(client_ret); + server_ret = handshake_callback.GetResult(server_ret); + + ASSERT_EQ(ERR_SSL_VERSION_OR_CIPHER_MISMATCH, client_ret); + ASSERT_EQ(ERR_SSL_VERSION_OR_CIPHER_MISMATCH, server_ret); +} + } // namespace net diff --git a/chromium/net/socket/ssl_session_cache_openssl.cc b/chromium/net/socket/ssl_session_cache_openssl.cc deleted file mode 100644 index 92ae44b9ac5..00000000000 --- a/chromium/net/socket/ssl_session_cache_openssl.cc +++ /dev/null @@ -1,530 +0,0 @@ -// Copyright 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/socket/ssl_session_cache_openssl.h" - -#include <list> -#include <map> - -#include <openssl/rand.h> -#include <openssl/ssl.h> - -#include "base/containers/hash_tables.h" -#include "base/lazy_instance.h" -#include "base/logging.h" -#include "base/synchronization/lock.h" - -namespace net { - -namespace { - -// A helper class to lazily create a new EX_DATA index to map SSL_CTX handles -// to their corresponding SSLSessionCacheOpenSSLImpl object. -class SSLContextExIndex { -public: - SSLContextExIndex() { - context_index_ = SSL_CTX_get_ex_new_index(0, NULL, NULL, NULL, NULL); - DCHECK_NE(-1, context_index_); - session_index_ = SSL_SESSION_get_ex_new_index(0, NULL, NULL, NULL, NULL); - DCHECK_NE(-1, session_index_); - } - - int context_index() const { return context_index_; } - int session_index() const { return session_index_; } - - private: - int context_index_; - int session_index_; -}; - -// static -base::LazyInstance<SSLContextExIndex>::Leaky s_ssl_context_ex_instance = - LAZY_INSTANCE_INITIALIZER; - -// Retrieve the global EX_DATA index, created lazily on first call, to -// be used with SSL_CTX_set_ex_data() and SSL_CTX_get_ex_data(). -static int GetSSLContextExIndex() { - return s_ssl_context_ex_instance.Get().context_index(); -} - -// Retrieve the global EX_DATA index, created lazily on first call, to -// be used with SSL_SESSION_set_ex_data() and SSL_SESSION_get_ex_data(). -static int GetSSLSessionExIndex() { - return s_ssl_context_ex_instance.Get().session_index(); -} - -// Helper struct used to store session IDs in a SessionIdIndex container -// (see definition below). To save memory each entry only holds a pointer -// to the session ID buffer, which must outlive the entry itself. On the -// other hand, a hash is included to minimize the number of hashing -// computations during cache operations. -struct SessionId { - SessionId(const unsigned char* a_id, unsigned a_id_len) - : id(a_id), id_len(a_id_len), hash(ComputeHash(a_id, a_id_len)) {} - - explicit SessionId(const SessionId& other) - : id(other.id), id_len(other.id_len), hash(other.hash) {} - - explicit SessionId(SSL_SESSION* session) - : id(session->session_id), - id_len(session->session_id_length), - hash(ComputeHash(session->session_id, session->session_id_length)) {} - - bool operator==(const SessionId& other) const { - return hash == other.hash && id_len == other.id_len && - !memcmp(id, other.id, id_len); - } - - const unsigned char* id; - unsigned id_len; - size_t hash; - - private: - // Session ID are random strings of bytes. This happens to compute the same - // value as std::hash<std::string> without the extra string copy. See - // base/containers/hash_tables.h. Other hashing computations are possible, - // this one is just simple enough to do the job. - size_t ComputeHash(const unsigned char* id, unsigned id_len) { - size_t result = 0; - for (unsigned n = 0; n < id_len; ++n) { - result = (result * 131) + id[n]; - } - return result; - } -}; - -} // namespace - -} // namespace net - -namespace BASE_HASH_NAMESPACE { - -template <> -struct hash<net::SessionId> { - std::size_t operator()(const net::SessionId& entry) const { - return entry.hash; - } -}; - -} // namespace BASE_HASH_NAMESPACE - -namespace net { - -// Implementation of the real SSLSessionCache. -// -// The implementation is inspired by base::MRUCache, except that the deletor -// also needs to remove the entry from other containers. In a nutshell, this -// uses several basic containers: -// -// |ordering_| is a doubly-linked list of SSL_SESSION handles, ordered in -// MRU order. -// -// |key_index_| is a hash table mapping unique cache keys (e.g. host/port -// values) to a single iterator of |ordering_|. It is used to efficiently -// find the cached session associated with a given key. -// -// |id_index_| is a hash table mapping SessionId values to iterators -// of |key_index_|. If is used to efficiently remove sessions from the cache, -// as well as check for the existence of a session ID value in the cache. -// -// SSL_SESSION objects are reference-counted, and owned by the cache. This -// means that their reference count is incremented when they are added, and -// decremented when they are removed. -// -// Assuming an average key size of 100 characters, each node requires the -// following memory usage on 32-bit Android, when linked against STLport: -// -// 12 (ordering_ node, including SSL_SESSION handle) -// 100 (key characters) -// + 24 (std::string header/minimum size) -// + 8 (key_index_ node, excluding the 2 lines above for the key). -// + 20 (id_index_ node) -// -------- -// 164 bytes/node -// -// Hence, 41 KiB for a full cache with a maximum of 1024 entries, excluding -// the size of SSL_SESSION objects and heap fragmentation. -// - -class SSLSessionCacheOpenSSLImpl { - public: - // Construct new instance. This registers various hooks into the SSL_CTX - // context |ctx|. OpenSSL will call back during SSL connection - // operations. |key_func| is used to map a SSL handle to a unique cache - // string, according to the client's preferences. - SSLSessionCacheOpenSSLImpl(SSL_CTX* ctx, - const SSLSessionCacheOpenSSL::Config& config) - : ctx_(ctx), config_(config), expiration_check_(0) { - DCHECK(ctx); - - // NO_INTERNAL_STORE disables OpenSSL's builtin cache, and - // NO_AUTO_CLEAR disables the call to SSL_CTX_flush_sessions - // every 256 connections (this number is hard-coded in the library - // and can't be changed). - SSL_CTX_set_session_cache_mode(ctx_, - SSL_SESS_CACHE_CLIENT | - SSL_SESS_CACHE_NO_INTERNAL_STORE | - SSL_SESS_CACHE_NO_AUTO_CLEAR); - - SSL_CTX_sess_set_new_cb(ctx_, NewSessionCallbackStatic); - SSL_CTX_sess_set_remove_cb(ctx_, RemoveSessionCallbackStatic); - SSL_CTX_set_generate_session_id(ctx_, GenerateSessionIdStatic); - SSL_CTX_set_timeout(ctx_, config_.timeout_seconds); - - SSL_CTX_set_ex_data(ctx_, GetSSLContextExIndex(), this); - } - - // Destroy this instance. Must happen before |ctx_| is destroyed. - ~SSLSessionCacheOpenSSLImpl() { - Flush(); - SSL_CTX_set_ex_data(ctx_, GetSSLContextExIndex(), NULL); - SSL_CTX_sess_set_new_cb(ctx_, NULL); - SSL_CTX_sess_set_remove_cb(ctx_, NULL); - SSL_CTX_set_generate_session_id(ctx_, NULL); - } - - // Return the number of items in this cache. - size_t size() const { return key_index_.size(); } - - // Retrieve the cache key from |ssl| and look for a corresponding - // cached session ID. If one is found, call SSL_set_session() to associate - // it with the |ssl| connection. - // - // Will also check for expired sessions every |expiration_check_count| - // calls. - // - // Return true if a cached session ID was found, false otherwise. - bool SetSSLSession(SSL* ssl) { - std::string cache_key = config_.key_func(ssl); - if (cache_key.empty()) - return false; - - return SetSSLSessionWithKey(ssl, cache_key); - } - - // Variant of SetSSLSession to be used when the client already has computed - // the cache key. Avoid a call to the configuration's |key_func| function. - bool SetSSLSessionWithKey(SSL* ssl, const std::string& cache_key) { - base::AutoLock locked(lock_); - - DCHECK_EQ(config_.key_func(ssl), cache_key); - - if (++expiration_check_ >= config_.expiration_check_count) { - expiration_check_ = 0; - FlushExpiredSessionsLocked(); - } - - KeyIndex::iterator it = key_index_.find(cache_key); - if (it == key_index_.end()) - return false; - - SSL_SESSION* session = *it->second; - DCHECK(session); - - DVLOG(2) << "Lookup session: " << session << " for " << cache_key; - - void* session_is_good = - SSL_SESSION_get_ex_data(session, GetSSLSessionExIndex()); - if (!session_is_good) - return false; // Session has not yet been marked good. Treat as a miss. - - // Move to front of MRU list. - ordering_.push_front(session); - ordering_.erase(it->second); - it->second = ordering_.begin(); - - return SSL_set_session(ssl, session) == 1; - } - - // Return true iff a cached session was associated with the given |cache_key|. - bool SSLSessionIsInCache(const std::string& cache_key) const { - base::AutoLock locked(lock_); - KeyIndex::const_iterator it = key_index_.find(cache_key); - if (it == key_index_.end()) - return false; - - SSL_SESSION* session = *it->second; - DCHECK(session); - - void* session_is_good = - SSL_SESSION_get_ex_data(session, GetSSLSessionExIndex()); - - return session_is_good != NULL; - } - - void MarkSSLSessionAsGood(SSL* ssl) { - SSL_SESSION* session = SSL_get_session(ssl); - CHECK(session); - - // Mark the session as good, allowing it to be used for future connections. - SSL_SESSION_set_ex_data( - session, GetSSLSessionExIndex(), reinterpret_cast<void*>(1)); - } - - // Flush all entries from the cache. - void Flush() { - base::AutoLock lock(lock_); - id_index_.clear(); - key_index_.clear(); - while (!ordering_.empty()) { - SSL_SESSION* session = ordering_.front(); - ordering_.pop_front(); - SSL_SESSION_free(session); - } - } - - private: - // Type for list of SSL_SESSION handles, ordered in MRU order. - typedef std::list<SSL_SESSION*> MRUSessionList; - // Type for a dictionary from unique cache keys to session list nodes. - typedef base::hash_map<std::string, MRUSessionList::iterator> KeyIndex; - // Type for a dictionary from SessionId values to key index nodes. - typedef base::hash_map<SessionId, KeyIndex::iterator> SessionIdIndex; - - // Return the key associated with a given session, or the empty string if - // none exist. This shall only be used for debugging. - std::string SessionKey(SSL_SESSION* session) { - if (!session) - return std::string("<null-session>"); - - if (session->session_id_length == 0) - return std::string("<empty-session-id>"); - - SessionIdIndex::iterator it = id_index_.find(SessionId(session)); - if (it == id_index_.end()) - return std::string("<unknown-session>"); - - return it->second->first; - } - - // Remove a given |session| from the cache. Lock must be held. - void RemoveSessionLocked(SSL_SESSION* session) { - lock_.AssertAcquired(); - DCHECK(session); - DCHECK_GT(session->session_id_length, 0U); - SessionId session_id(session); - SessionIdIndex::iterator id_it = id_index_.find(session_id); - if (id_it == id_index_.end()) { - LOG(ERROR) << "Trying to remove unknown session from cache: " << session; - return; - } - KeyIndex::iterator key_it = id_it->second; - DCHECK(key_it != key_index_.end()); - DCHECK_EQ(session, *key_it->second); - - id_index_.erase(session_id); - ordering_.erase(key_it->second); - key_index_.erase(key_it); - - SSL_SESSION_free(session); - - DCHECK_EQ(key_index_.size(), id_index_.size()); - } - - // Used internally to flush expired sessions. Lock must be held. - void FlushExpiredSessionsLocked() { - lock_.AssertAcquired(); - - // Unfortunately, OpenSSL initializes |session->time| with a time() - // timestamps, which makes mocking / unit testing difficult. - long timeout_secs = static_cast<long>(::time(NULL)); - MRUSessionList::iterator it = ordering_.begin(); - while (it != ordering_.end()) { - SSL_SESSION* session = *it++; - - // Important, use <= instead of < here to allow unit testing to - // work properly. That's because unit tests that check the expiration - // behaviour will use a session timeout of 0 seconds. - if (session->time + session->timeout <= timeout_secs) { - DVLOG(2) << "Expiring session " << session << " for " - << SessionKey(session); - RemoveSessionLocked(session); - } - } - } - - // Retrieve the cache associated with a given SSL context |ctx|. - static SSLSessionCacheOpenSSLImpl* GetCache(SSL_CTX* ctx) { - DCHECK(ctx); - void* result = SSL_CTX_get_ex_data(ctx, GetSSLContextExIndex()); - DCHECK(result); - return reinterpret_cast<SSLSessionCacheOpenSSLImpl*>(result); - } - - // Called by OpenSSL when a new |session| was created and added to a given - // |ssl| connection. Note that the session's reference count was already - // incremented before the function is entered. The function must return 1 - // to indicate that it took ownership of the session, i.e. that the caller - // should not decrement its reference count after completion. - static int NewSessionCallbackStatic(SSL* ssl, SSL_SESSION* session) { - SSLSessionCacheOpenSSLImpl* cache = GetCache(ssl->ctx); - cache->OnSessionAdded(ssl, session); - return 1; - } - - // Called by OpenSSL to indicate that a session must be removed from the - // cache. This happens when SSL_CTX is destroyed. - static void RemoveSessionCallbackStatic(SSL_CTX* ctx, SSL_SESSION* session) { - GetCache(ctx)->OnSessionRemoved(session); - } - - // Called by OpenSSL to generate a new session ID. This happens during a - // SSL connection operation, when the SSL object doesn't have a session yet. - // - // A session ID is a random string of bytes used to uniquely identify the - // session between a client and a server. - // - // |ssl| is a SSL connection handle. Ignored here. - // |id| is the target buffer where the ID must be generated. - // |*id_len| is, on input, the size of the desired ID. It will be 16 for - // SSLv2, and 32 for anything else. OpenSSL allows an implementation - // to change it on output, but this will not happen here. - // - // The function must ensure the generated ID is really unique, i.e. that - // another session in the cache doesn't already use the same value. It must - // return 1 to indicate success, or 0 for failure. - static int GenerateSessionIdStatic(const SSL* ssl, - unsigned char* id, - unsigned* id_len) { - if (!GetCache(ssl->ctx)->OnGenerateSessionId(id, *id_len)) - return 0; - - return 1; - } - - // Add |session| to the cache in association with |cache_key|. If a session - // already exists, it is replaced with the new one. This assumes that the - // caller already incremented the session's reference count. - void OnSessionAdded(SSL* ssl, SSL_SESSION* session) { - base::AutoLock locked(lock_); - DCHECK(ssl); - DCHECK_GT(session->session_id_length, 0U); - std::string cache_key = config_.key_func(ssl); - KeyIndex::iterator it = key_index_.find(cache_key); - if (it == key_index_.end()) { - DVLOG(2) << "Add session " << session << " for " << cache_key; - // This is a new session. Add it to the cache. - ordering_.push_front(session); - std::pair<KeyIndex::iterator, bool> ret = - key_index_.insert(std::make_pair(cache_key, ordering_.begin())); - DCHECK(ret.second); - it = ret.first; - DCHECK(it != key_index_.end()); - } else { - // An existing session exists for this key, so replace it if needed. - DVLOG(2) << "Replace session " << *it->second << " with " << session - << " for " << cache_key; - SSL_SESSION* old_session = *it->second; - if (old_session != session) { - id_index_.erase(SessionId(old_session)); - SSL_SESSION_free(old_session); - } - ordering_.erase(it->second); - ordering_.push_front(session); - it->second = ordering_.begin(); - } - - id_index_[SessionId(session)] = it; - - if (key_index_.size() > config_.max_entries) - ShrinkCacheLocked(); - - DCHECK_EQ(key_index_.size(), id_index_.size()); - DCHECK_LE(key_index_.size(), config_.max_entries); - } - - // Shrink the cache to ensure no more than config_.max_entries entries, - // starting with older entries first. Lock must be acquired. - void ShrinkCacheLocked() { - lock_.AssertAcquired(); - DCHECK_EQ(key_index_.size(), ordering_.size()); - DCHECK_EQ(key_index_.size(), id_index_.size()); - - while (key_index_.size() > config_.max_entries) { - MRUSessionList::reverse_iterator it = ordering_.rbegin(); - DCHECK(it != ordering_.rend()); - - SSL_SESSION* session = *it; - DCHECK(session); - DVLOG(2) << "Evicting session " << session << " for " - << SessionKey(session); - RemoveSessionLocked(session); - } - } - - // Remove |session| from the cache. - void OnSessionRemoved(SSL_SESSION* session) { - base::AutoLock locked(lock_); - DVLOG(2) << "Remove session " << session << " for " << SessionKey(session); - RemoveSessionLocked(session); - } - - // See GenerateSessionIdStatic for a description of what this function does. - bool OnGenerateSessionId(unsigned char* id, unsigned id_len) { - base::AutoLock locked(lock_); - // This mimics def_generate_session_id() in openssl/ssl/ssl_sess.cc, - // I.e. try to generate a pseudo-random bit string, and check that no - // other entry in the cache has the same value. - const size_t kMaxTries = 10; - for (size_t tries = 0; tries < kMaxTries; ++tries) { - if (RAND_pseudo_bytes(id, id_len) <= 0) { - DLOG(ERROR) << "Couldn't generate " << id_len - << " pseudo random bytes?"; - return false; - } - if (id_index_.find(SessionId(id, id_len)) == id_index_.end()) - return true; - } - DLOG(ERROR) << "Couldn't generate unique session ID of " << id_len - << "bytes after " << kMaxTries << " tries."; - return false; - } - - SSL_CTX* ctx_; - SSLSessionCacheOpenSSL::Config config_; - - // method to get the index which can later be used with SSL_CTX_get_ex_data() - // or SSL_CTX_set_ex_data(). - mutable base::Lock lock_; // Protects access to containers below. - - MRUSessionList ordering_; - KeyIndex key_index_; - SessionIdIndex id_index_; - - size_t expiration_check_; -}; - -SSLSessionCacheOpenSSL::~SSLSessionCacheOpenSSL() { delete impl_; } - -size_t SSLSessionCacheOpenSSL::size() const { return impl_->size(); } - -void SSLSessionCacheOpenSSL::Reset(SSL_CTX* ctx, const Config& config) { - if (impl_) - delete impl_; - - impl_ = new SSLSessionCacheOpenSSLImpl(ctx, config); -} - -bool SSLSessionCacheOpenSSL::SetSSLSession(SSL* ssl) { - return impl_->SetSSLSession(ssl); -} - -bool SSLSessionCacheOpenSSL::SetSSLSessionWithKey( - SSL* ssl, - const std::string& cache_key) { - return impl_->SetSSLSessionWithKey(ssl, cache_key); -} - -bool SSLSessionCacheOpenSSL::SSLSessionIsInCache( - const std::string& cache_key) const { - return impl_->SSLSessionIsInCache(cache_key); -} - -void SSLSessionCacheOpenSSL::MarkSSLSessionAsGood(SSL* ssl) { - return impl_->MarkSSLSessionAsGood(ssl); -} - -void SSLSessionCacheOpenSSL::Flush() { impl_->Flush(); } - -} // namespace net diff --git a/chromium/net/socket/ssl_session_cache_openssl.h b/chromium/net/socket/ssl_session_cache_openssl.h deleted file mode 100644 index abf5eab78cb..00000000000 --- a/chromium/net/socket/ssl_session_cache_openssl.h +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_SOCKET_SSL_SESSION_CACHE_OPENSSL_H -#define NET_SOCKET_SSL_SESSION_CACHE_OPENSSL_H - -#include <string> - -#include "base/basictypes.h" -#include "net/base/net_export.h" - -// Avoid including OpenSSL headers here. -typedef struct ssl_ctx_st SSL_CTX; -typedef struct ssl_st SSL; - -namespace net { - -class SSLSessionCacheOpenSSLImpl; - -// A class used to implement a custom cache of SSL_SESSION objects. -// Usage is as follows: -// -// - Client creates a new cache instance with appropriate configuration, -// associating it with a given SSL_CTX object. -// -// The configuration must include a pointer to a client-provided function -// that can retrieve a unique cache key from an existing SSL handle. -// -// - When creating a new SSL connection, call SetSSLSession() with the newly -// created SSL handle, and a cache key for the current host/port. If a -// session is already in the cache, it will be added to the connection -// through SSL_set_session(). -// -// - Otherwise, OpenSSL will create a new SSL_SESSION object during the -// connection, and will pass it to the cache's internal functions, -// transparently to the client. -// -// - Each session has a timeout in seconds, which are checked every N-th call -// to SetSSLSession(), where N is the current configuration's -// |check_expiration_count|. Expired sessions are removed automatically -// from the cache. -// -// - Clients can call Flush() to remove all sessions from the cache, this is -// useful when the system's certificate store has changed. -// -// This class is thread-safe. There shouldn't be any issue with multiple -// SSL connections being performed in parallel in multiple threads. -class NET_EXPORT SSLSessionCacheOpenSSL { - public: - // Type of a function that takes a SSL handle and returns a unique cache - // key string to identify it. - typedef std::string GetSessionKeyFunction(const SSL* ssl); - - // A small structure used to configure a cache on creation. - // |key_func| is a function used at runtime to retrieve the unique cache key - // from a given SSL connection handle. - // |max_entries| is the maximum number of entries in the cache. - // |expiration_check_count| is the number of calls to SetSSLSession() that - // will trigger a check for expired sessions. - // |timeout_seconds| is the timeout of new cached sessions in seconds. - struct Config { - GetSessionKeyFunction* key_func; - size_t max_entries; - size_t expiration_check_count; - int timeout_seconds; - }; - - SSLSessionCacheOpenSSL() : impl_(NULL) {} - - // Construct a new cache instance. - // |ctx| is a SSL_CTX context handle that will be associated with this cache. - // |key_func| is a function that will be used at runtime to retrieve the - // unique cache key from a SSL connection handle. - // |max_entries| is the maximum number of entries in the cache. - // |timeout_seconds| is the timeout of new cached sessions in seconds. - // |expiration_check_count| is the number of calls to SetSSLSession() that - // will trigger a check for expired sessions. - SSLSessionCacheOpenSSL(SSL_CTX* ctx, const Config& config) : impl_(NULL) { - Reset(ctx, config); - } - - // Destroy this instance. This must be called before the SSL_CTX handle - // is destroyed. - ~SSLSessionCacheOpenSSL(); - - // Reset the cache configuration. This flushes any existing entries. - void Reset(SSL_CTX* ctx, const Config& config); - - size_t size() const; - - // Lookup the unique cache key associated with |ssl| connection handle, - // and find a cached session for it in the cache. If one is found, associate - // it with the |ssl| connection through SSL_set_session(). Consider using - // SetSSLSessionWithKey() if you already have the key. - // - // Every |check_expiration_count| call to either SetSSLSession() or - // SetSSLSessionWithKey() triggers a check for, and removal of, expired - // sessions. - // - // Return true iff a cached session was associated with the |ssl| connection. - bool SetSSLSession(SSL* ssl); - - // A more efficient variant of SetSSLSession() that can be used if the caller - // already has the cache key for the session of interest. The caller must - // ensure that the value of |cache_key| matches the result of calling the - // configuration's |key_func| function with the |ssl| as parameter. - // - // Every |check_expiration_count| call to either SetSSLSession() or - // SetSSLSessionWithKey() triggers a check for, and removal of, expired - // sessions. - // - // Return true iff a cached session was associated with the |ssl| connection. - bool SetSSLSessionWithKey(SSL* ssl, const std::string& cache_key); - - // Return true iff a cached session was associated with the given |cache_key|. - bool SSLSessionIsInCache(const std::string& cache_key) const; - - // Indicates that the SSL session associated with |ssl| is "good" - that is, - // that all associated cryptographic parameters that were negotiated, - // including the peer's certificate, were successfully validated. Because - // OpenSSL does not provide an asynchronous certificate verification - // callback, it's necessary to manually manage the sessions to ensure that - // only validated sessions are resumed. - void MarkSSLSessionAsGood(SSL* ssl); - - // Flush removes all entries from the cache. This is typically called when - // the system's certificate store has changed. - void Flush(); - - // TODO(digit): Move to client code. - static const int kDefaultTimeoutSeconds = 60 * 60; - static const size_t kMaxEntries = 1024; - static const size_t kMaxExpirationChecks = 256; - - private: - DISALLOW_COPY_AND_ASSIGN(SSLSessionCacheOpenSSL); - - SSLSessionCacheOpenSSLImpl* impl_; -}; - -} // namespace net - -#endif // NET_SOCKET_SSL_SESSION_CACHE_OPENSSL_H diff --git a/chromium/net/socket/ssl_session_cache_openssl_unittest.cc b/chromium/net/socket/ssl_session_cache_openssl_unittest.cc deleted file mode 100644 index 78bac63ccb2..00000000000 --- a/chromium/net/socket/ssl_session_cache_openssl_unittest.cc +++ /dev/null @@ -1,380 +0,0 @@ -// Copyright 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/socket/ssl_session_cache_openssl.h" - -#include <openssl/ssl.h> - -#include "base/lazy_instance.h" -#include "base/logging.h" -#include "base/strings/stringprintf.h" -#include "crypto/openssl_util.h" -#include "crypto/scoped_openssl_types.h" - -#include "testing/gtest/include/gtest/gtest.h" - -// This is an internal OpenSSL function that can be used to create a new -// session for an existing SSL object. This shall force a call to the -// 'generate_session_id' callback from the SSL's session context. -// |s| is the target SSL connection handle. -// |session| is non-0 to ask for the creation of a new session. If 0, -// this will set an empty session with no ID instead. -extern "C" OPENSSL_EXPORT int ssl_get_new_session(SSL* s, int session); - -// This is an internal OpenSSL function which is used internally to add -// a new session to the cache. It is normally triggered by a succesful -// connection. However, this unit test does not use the network at all. -extern "C" OPENSSL_EXPORT void ssl_update_cache(SSL* s, int mode); - -namespace net { - -namespace { - -typedef crypto::ScopedOpenSSL<SSL, SSL_free>::Type ScopedSSL; -typedef crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free>::Type ScopedSSL_CTX; - -// Helper class used to associate arbitrary std::string keys with SSL objects. -class SSLKeyHelper { - public: - // Return the string associated with a given SSL handle |ssl|, or the - // empty string if none exists. - static std::string Get(const SSL* ssl) { - return GetInstance()->GetValue(ssl); - } - - // Associate a string with a given SSL handle |ssl|. - static void Set(SSL* ssl, const std::string& value) { - GetInstance()->SetValue(ssl, value); - } - - static SSLKeyHelper* GetInstance() { - static base::LazyInstance<SSLKeyHelper>::Leaky s_instance = - LAZY_INSTANCE_INITIALIZER; - return s_instance.Pointer(); - } - - SSLKeyHelper() { - ex_index_ = SSL_get_ex_new_index(0, NULL, NULL, KeyDup, KeyFree); - CHECK_NE(-1, ex_index_); - } - - std::string GetValue(const SSL* ssl) { - std::string* value = - reinterpret_cast<std::string*>(SSL_get_ex_data(ssl, ex_index_)); - if (!value) - return std::string(); - return *value; - } - - void SetValue(SSL* ssl, const std::string& value) { - int ret = SSL_set_ex_data(ssl, ex_index_, new std::string(value)); - CHECK_EQ(1, ret); - } - - // Called when an SSL object is copied through SSL_dup(). This needs to copy - // the value as well. - static int KeyDup(CRYPTO_EX_DATA* to, - const CRYPTO_EX_DATA* from, - void** from_fd, - int idx, - long argl, - void* argp) { - // |from_fd| is really the address of a temporary pointer. On input, it - // points to the value from the original SSL object. The function must - // update it to the address of a copy. - std::string** ptr = reinterpret_cast<std::string**>(from_fd); - std::string* old_string = *ptr; - std::string* new_string = new std::string(*old_string); - *ptr = new_string; - return 0; // Ignored by the implementation. - } - - // Called to destroy the value associated with an SSL object. - static void KeyFree(void* parent, - void* ptr, - CRYPTO_EX_DATA* ad, - int index, - long argl, - void* argp) { - std::string* value = reinterpret_cast<std::string*>(ptr); - delete value; - } - - int ex_index_; -}; - -} // namespace - -class SSLSessionCacheOpenSSLTest : public testing::Test { - public: - SSLSessionCacheOpenSSLTest() { - crypto::EnsureOpenSSLInit(); - ctx_.reset(SSL_CTX_new(SSLv23_client_method())); - cache_.Reset(ctx_.get(), kDefaultConfig); - } - - // Reset cache configuration. - void ResetConfig(const SSLSessionCacheOpenSSL::Config& config) { - cache_.Reset(ctx_.get(), config); - } - - // Helper function to create a new SSL connection object associated with - // a given unique |cache_key|. This does _not_ add the session to the cache. - // Caller must free the object with SSL_free(). - SSL* NewSSL(const std::string& cache_key) { - SSL* ssl = SSL_new(ctx_.get()); - if (!ssl) - return NULL; - - SSLKeyHelper::Set(ssl, cache_key); // associate cache key. - ResetSessionID(ssl); // create new unique session ID. - return ssl; - } - - // Reset the session ID of a given SSL object. This creates a new session - // with a new unique random ID. Does not add it to the cache. - static void ResetSessionID(SSL* ssl) { ssl_get_new_session(ssl, 1); } - - // Add a given SSL object and its session to the cache. - void AddToCache(SSL* ssl) { - ssl_update_cache(ssl, ctx_.get()->session_cache_mode); - } - - static const SSLSessionCacheOpenSSL::Config kDefaultConfig; - - protected: - ScopedSSL_CTX ctx_; - // |cache_| must be destroyed before |ctx_| and thus appears after it. - SSLSessionCacheOpenSSL cache_; -}; - -// static -const SSLSessionCacheOpenSSL::Config - SSLSessionCacheOpenSSLTest::kDefaultConfig = { - &SSLKeyHelper::Get, // key_func - 1024, // max_entries - 256, // expiration_check_count - 60 * 60, // timeout_seconds -}; - -TEST_F(SSLSessionCacheOpenSSLTest, EmptyCacheCreation) { - EXPECT_EQ(0U, cache_.size()); -} - -TEST_F(SSLSessionCacheOpenSSLTest, CacheOneSession) { - ScopedSSL ssl(NewSSL("hello")); - - EXPECT_EQ(0U, cache_.size()); - AddToCache(ssl.get()); - EXPECT_EQ(1U, cache_.size()); - ssl.reset(NULL); - EXPECT_EQ(1U, cache_.size()); -} - -TEST_F(SSLSessionCacheOpenSSLTest, CacheMultipleSessions) { - const size_t kNumItems = 100; - int local_id = 1; - - // Add kNumItems to the cache. - for (size_t n = 0; n < kNumItems; ++n) { - std::string local_id_string = base::StringPrintf("%d", local_id++); - ScopedSSL ssl(NewSSL(local_id_string)); - AddToCache(ssl.get()); - EXPECT_EQ(n + 1, cache_.size()); - } -} - -TEST_F(SSLSessionCacheOpenSSLTest, Flush) { - const size_t kNumItems = 100; - int local_id = 1; - - // Add kNumItems to the cache. - for (size_t n = 0; n < kNumItems; ++n) { - std::string local_id_string = base::StringPrintf("%d", local_id++); - ScopedSSL ssl(NewSSL(local_id_string)); - AddToCache(ssl.get()); - } - EXPECT_EQ(kNumItems, cache_.size()); - - cache_.Flush(); - EXPECT_EQ(0U, cache_.size()); -} - -TEST_F(SSLSessionCacheOpenSSLTest, SetSSLSession) { - const std::string key("hello"); - ScopedSSL ssl(NewSSL(key)); - - // First call should fail because the session is not in the cache. - EXPECT_FALSE(cache_.SetSSLSession(ssl.get())); - SSL_SESSION* session = ssl.get()->session; - EXPECT_TRUE(session); - EXPECT_EQ(1, session->references); - - AddToCache(ssl.get()); - EXPECT_EQ(2, session->references); - - // Mark the session as good, so that it is re-used for the second connection. - cache_.MarkSSLSessionAsGood(ssl.get()); - - ssl.reset(NULL); - EXPECT_EQ(1, session->references); - - // Second call should find the session ID and associate it with |ssl2|. - ScopedSSL ssl2(NewSSL(key)); - EXPECT_TRUE(cache_.SetSSLSession(ssl2.get())); - - EXPECT_EQ(session, ssl2.get()->session); - EXPECT_EQ(2, session->references); -} - -TEST_F(SSLSessionCacheOpenSSLTest, SetSSLSessionWithKey) { - const std::string key("hello"); - ScopedSSL ssl(NewSSL(key)); - AddToCache(ssl.get()); - cache_.MarkSSLSessionAsGood(ssl.get()); - ssl.reset(NULL); - - ScopedSSL ssl2(NewSSL(key)); - EXPECT_TRUE(cache_.SetSSLSessionWithKey(ssl2.get(), key)); -} - -TEST_F(SSLSessionCacheOpenSSLTest, CheckSessionReplacement) { - // Check that if two SSL connections have the same key, only one - // corresponding session can be stored in the cache. - const std::string common_key("common-key"); - ScopedSSL ssl1(NewSSL(common_key)); - ScopedSSL ssl2(NewSSL(common_key)); - - AddToCache(ssl1.get()); - EXPECT_EQ(1U, cache_.size()); - EXPECT_EQ(2, ssl1.get()->session->references); - - // This ends up calling OnSessionAdded which will discover that there is - // already one session ID associated with the key, and will replace it. - AddToCache(ssl2.get()); - EXPECT_EQ(1U, cache_.size()); - EXPECT_EQ(1, ssl1.get()->session->references); - EXPECT_EQ(2, ssl2.get()->session->references); -} - -// Check that when two connections have the same key, a new session is created -// if the existing session has not yet been marked "good". Further, after the -// first session completes, if the second session has replaced it in the cache, -// new sessions should continue to fail until the currently cached session -// succeeds. -TEST_F(SSLSessionCacheOpenSSLTest, CheckSessionReplacementWhenNotGood) { - const std::string key("hello"); - ScopedSSL ssl(NewSSL(key)); - - // First call should fail because the session is not in the cache. - EXPECT_FALSE(cache_.SetSSLSession(ssl.get())); - SSL_SESSION* session = ssl.get()->session; - ASSERT_TRUE(session); - EXPECT_EQ(1, session->references); - - AddToCache(ssl.get()); - EXPECT_EQ(2, session->references); - - // Second call should find the session ID, but because it is not yet good, - // fail to associate it with |ssl2|. - ScopedSSL ssl2(NewSSL(key)); - EXPECT_FALSE(cache_.SetSSLSession(ssl2.get())); - SSL_SESSION* session2 = ssl2.get()->session; - ASSERT_TRUE(session2); - EXPECT_EQ(1, session2->references); - - EXPECT_NE(session, session2); - - // Add the second connection to the cache. It should replace the first - // session, and the cache should hold on to the second session. - AddToCache(ssl2.get()); - EXPECT_EQ(1, session->references); - EXPECT_EQ(2, session2->references); - - // Mark the first session as good, simulating it completing. - cache_.MarkSSLSessionAsGood(ssl.get()); - - // Third call should find the session ID, but because the second session (the - // current cache entry) is not yet good, fail to associate it with |ssl3|. - ScopedSSL ssl3(NewSSL(key)); - EXPECT_FALSE(cache_.SetSSLSession(ssl3.get())); - EXPECT_NE(session, ssl3.get()->session); - EXPECT_NE(session2, ssl3.get()->session); - EXPECT_EQ(1, ssl3.get()->session->references); -} - -TEST_F(SSLSessionCacheOpenSSLTest, CheckEviction) { - const size_t kMaxItems = 20; - int local_id = 1; - - SSLSessionCacheOpenSSL::Config config = kDefaultConfig; - config.max_entries = kMaxItems; - ResetConfig(config); - - // Add kMaxItems to the cache. - for (size_t n = 0; n < kMaxItems; ++n) { - std::string local_id_string = base::StringPrintf("%d", local_id++); - ScopedSSL ssl(NewSSL(local_id_string)); - - AddToCache(ssl.get()); - EXPECT_EQ(n + 1, cache_.size()); - } - - // Continue adding new items to the cache, check that old ones are - // evicted. - for (size_t n = 0; n < kMaxItems; ++n) { - std::string local_id_string = base::StringPrintf("%d", local_id++); - ScopedSSL ssl(NewSSL(local_id_string)); - - AddToCache(ssl.get()); - EXPECT_EQ(kMaxItems, cache_.size()); - } -} - -// Check that session expiration works properly. -TEST_F(SSLSessionCacheOpenSSLTest, CheckExpiration) { - const size_t kMaxCheckCount = 10; - const size_t kNumEntries = 20; - - SSLSessionCacheOpenSSL::Config config = kDefaultConfig; - config.expiration_check_count = kMaxCheckCount; - config.timeout_seconds = 1000; - ResetConfig(config); - - // Add |kNumItems - 1| session entries with crafted time values. - for (size_t n = 0; n < kNumEntries - 1U; ++n) { - std::string key = base::StringPrintf("%d", static_cast<int>(n)); - ScopedSSL ssl(NewSSL(key)); - // Cheat a little: Force the session |time| value, this guarantees that they - // are expired, given that ::time() will always return a value that is - // past the first 100 seconds after the Unix epoch. - ssl.get()->session->time = static_cast<long>(n); - AddToCache(ssl.get()); - } - EXPECT_EQ(kNumEntries - 1U, cache_.size()); - - // Add nother session which will get the current time, and thus not be - // expirable until 1000 seconds have passed. - ScopedSSL good_ssl(NewSSL("good-key")); - AddToCache(good_ssl.get()); - good_ssl.reset(NULL); - EXPECT_EQ(kNumEntries, cache_.size()); - - // Call SetSSLSession() |kMaxCheckCount - 1| times, this shall not expire - // any session - for (size_t n = 0; n < kMaxCheckCount - 1U; ++n) { - ScopedSSL ssl(NewSSL("unknown-key")); - cache_.SetSSLSession(ssl.get()); - EXPECT_EQ(kNumEntries, cache_.size()); - } - - // Call SetSSLSession another time, this shall expire all sessions except - // the last one. - ScopedSSL bad_ssl(NewSSL("unknown-key")); - cache_.SetSSLSession(bad_ssl.get()); - bad_ssl.reset(NULL); - EXPECT_EQ(1U, cache_.size()); -} - -} // namespace net diff --git a/chromium/net/socket/stream_listen_socket.cc b/chromium/net/socket/stream_listen_socket.cc index abb5fbc6b52..fd164a556d5 100644 --- a/chromium/net/socket/stream_listen_socket.cc +++ b/chromium/net/socket/stream_listen_socket.cc @@ -21,7 +21,6 @@ #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/posix/eintr_wrapper.h" -#include "base/profiler/scoped_tracker.h" #include "base/sys_byteorder.h" #include "base/threading/platform_thread.h" #include "build/build_config.h" @@ -247,10 +246,6 @@ void StreamListenSocket::UnwatchSocket() { #if defined(OS_WIN) // MessageLoop watcher callback. void StreamListenSocket::OnObjectSignaled(HANDLE object) { - // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION("StreamListenSocket_OnObjectSignaled")); - WSANETWORKEVENTS ev; if (kSocketError == WSAEnumNetworkEvents(socket_, socket_event_, &ev)) { // TODO diff --git a/chromium/net/socket/stream_listen_socket.h b/chromium/net/socket/stream_listen_socket.h index f8f9419484d..c83657bccbc 100644 --- a/chromium/net/socket/stream_listen_socket.h +++ b/chromium/net/socket/stream_listen_socket.h @@ -111,7 +111,7 @@ class NET_EXPORT StreamListenSocket #if defined(OS_WIN) // ObjectWatcher delegate. - virtual void OnObjectSignaled(HANDLE object); + void OnObjectSignaled(HANDLE object) override; base::win::ObjectWatcher watcher_; HANDLE socket_event_; #elif defined(OS_POSIX) diff --git a/chromium/net/socket/stream_socket.h b/chromium/net/socket/stream_socket.h index b41fed8b51b..5669ea3e9a0 100644 --- a/chromium/net/socket/stream_socket.h +++ b/chromium/net/socket/stream_socket.h @@ -5,7 +5,8 @@ #ifndef NET_SOCKET_STREAM_SOCKET_H_ #define NET_SOCKET_STREAM_SOCKET_H_ -#include "net/base/net_log.h" +#include "net/log/net_log.h" +#include "net/socket/connection_attempts.h" #include "net/socket/next_proto.h" #include "net/socket/socket.h" @@ -95,6 +96,16 @@ class NET_EXPORT_PRIVATE StreamSocket : public Socket { // SSL was not used by this socket. virtual bool GetSSLInfo(SSLInfo* ssl_info) = 0; + // Overwrites |out| with the connection attempts made in the process of + // connecting this socket. + virtual void GetConnectionAttempts(ConnectionAttempts* out) const = 0; + + // Clears the socket's list of connection attempts. + virtual void ClearConnectionAttempts() = 0; + + // Adds |attempts| to the socket's list of connection attempts. + virtual void AddConnectionAttempts(const ConnectionAttempts& attempts) = 0; + protected: // The following class is only used to gather statistics about the history of // a socket. It is only instantiated and used in basic sockets, such as diff --git a/chromium/net/socket/tcp_client_socket.cc b/chromium/net/socket/tcp_client_socket.cc index dcf124bfd6b..a1c1fadb85e 100644 --- a/chromium/net/socket/tcp_client_socket.cc +++ b/chromium/net/socket/tcp_client_socket.cc @@ -117,25 +117,32 @@ int TCPClientSocket::DoConnect() { const IPEndPoint& endpoint = addresses_[current_address_index_]; - if (previously_disconnected_) { - use_history_.Reset(); - previously_disconnected_ = false; - } - - next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE; + { + // TODO(ricea): Remove ScopedTracker below once crbug.com/436634 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION("436634 TCPClientSocket::DoConnect")); + + if (previously_disconnected_) { + use_history_.Reset(); + connection_attempts_.clear(); + previously_disconnected_ = false; + } - if (socket_->IsValid()) { - DCHECK(bind_address_); - } else { - int result = OpenSocket(endpoint.GetFamily()); - if (result != OK) - return result; + next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE; - if (bind_address_) { - result = socket_->Bind(*bind_address_); - if (result != OK) { - socket_->Close(); + if (socket_->IsValid()) { + DCHECK(bind_address_); + } else { + int result = OpenSocket(endpoint.GetFamily()); + if (result != OK) return result; + + if (bind_address_) { + result = socket_->Bind(*bind_address_); + if (result != OK) { + socket_->Close(); + return result; + } } } } @@ -153,6 +160,9 @@ int TCPClientSocket::DoConnectComplete(int result) { return OK; // Done! } + connection_attempts_.push_back( + ConnectionAttempt(addresses_[current_address_index_], result)); + // Close whatever partially connected socket we currently have. DoDisconnect(); @@ -290,6 +300,20 @@ bool TCPClientSocket::SetNoDelay(bool no_delay) { return socket_->SetNoDelay(no_delay); } +void TCPClientSocket::GetConnectionAttempts(ConnectionAttempts* out) const { + *out = connection_attempts_; +} + +void TCPClientSocket::ClearConnectionAttempts() { + connection_attempts_.clear(); +} + +void TCPClientSocket::AddConnectionAttempts( + const ConnectionAttempts& attempts) { + connection_attempts_.insert(connection_attempts_.begin(), attempts.begin(), + attempts.end()); +} + void TCPClientSocket::DidCompleteConnect(int result) { DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE); DCHECK_NE(result, ERR_IO_PENDING); @@ -307,10 +331,10 @@ void TCPClientSocket::DidCompleteReadWrite(const CompletionCallback& callback, if (result > 0) use_history_.set_was_used_to_convey_data(); - // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. + // TODO(pkasting): Remove ScopedTracker below once crbug.com/462780 is fixed. tracked_objects::ScopedTracker tracking_profile( FROM_HERE_WITH_EXPLICIT_FUNCTION( - "TCPClientSocket::DidCompleteReadWrite")); + "462780 TCPClientSocket::DidCompleteReadWrite")); callback.Run(result); } diff --git a/chromium/net/socket/tcp_client_socket.h b/chromium/net/socket/tcp_client_socket.h index 0deec2a0c9f..0b8062bfd09 100644 --- a/chromium/net/socket/tcp_client_socket.h +++ b/chromium/net/socket/tcp_client_socket.h @@ -11,7 +11,8 @@ #include "net/base/address_list.h" #include "net/base/completion_callback.h" #include "net/base/net_export.h" -#include "net/base/net_log.h" +#include "net/log/net_log.h" +#include "net/socket/connection_attempts.h" #include "net/socket/stream_socket.h" #include "net/socket/tcp_socket.h" @@ -69,6 +70,10 @@ class NET_EXPORT TCPClientSocket : public StreamSocket { virtual bool SetKeepAlive(bool enable, int delay); virtual bool SetNoDelay(bool no_delay); + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override; + void AddConnectionAttempts(const ConnectionAttempts& attempts) override; + private: // State machine for connecting the socket. enum ConnectState { @@ -116,6 +121,9 @@ class NET_EXPORT TCPClientSocket : public StreamSocket { // histograms. UseHistory use_history_; + // Failed connection attempts made while trying to connect this socket. + ConnectionAttempts connection_attempts_; + DISALLOW_COPY_AND_ASSIGN(TCPClientSocket); }; diff --git a/chromium/net/socket/tcp_listen_socket.cc b/chromium/net/socket/tcp_listen_socket.cc index 585c41292de..11b23908ab1 100644 --- a/chromium/net/socket/tcp_listen_socket.cc +++ b/chromium/net/socket/tcp_listen_socket.cc @@ -31,7 +31,9 @@ namespace net { // static scoped_ptr<TCPListenSocket> TCPListenSocket::CreateAndListen( - const string& ip, int port, StreamListenSocket::Delegate* del) { + const string& ip, + uint16 port, + StreamListenSocket::Delegate* del) { SocketDescriptor s = CreateAndBind(ip, port); if (s == kInvalidSocket) return scoped_ptr<TCPListenSocket>(); @@ -47,7 +49,7 @@ TCPListenSocket::TCPListenSocket(SocketDescriptor s, TCPListenSocket::~TCPListenSocket() {} -SocketDescriptor TCPListenSocket::CreateAndBind(const string& ip, int port) { +SocketDescriptor TCPListenSocket::CreateAndBind(const string& ip, uint16 port) { SocketDescriptor s = CreatePlatformSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP); if (s != kInvalidSocket) { #if defined(OS_POSIX) @@ -74,7 +76,7 @@ SocketDescriptor TCPListenSocket::CreateAndBind(const string& ip, int port) { } SocketDescriptor TCPListenSocket::CreateAndBindAnyPort(const string& ip, - int* port) { + uint16* port) { SocketDescriptor s = CreateAndBind(ip, 0); if (s == kInvalidSocket) return kInvalidSocket; @@ -110,16 +112,4 @@ void TCPListenSocket::Accept() { socket_delegate_->DidAccept(this, sock.Pass()); } -TCPListenSocketFactory::TCPListenSocketFactory(const string& ip, int port) - : ip_(ip), - port_(port) { -} - -TCPListenSocketFactory::~TCPListenSocketFactory() {} - -scoped_ptr<StreamListenSocket> TCPListenSocketFactory::CreateAndListen( - StreamListenSocket::Delegate* delegate) const { - return TCPListenSocket::CreateAndListen(ip_, port_, delegate); -} - } // namespace net diff --git a/chromium/net/socket/tcp_listen_socket.h b/chromium/net/socket/tcp_listen_socket.h index 1702e50e8ed..d87fd49bbd0 100644 --- a/chromium/net/socket/tcp_listen_socket.h +++ b/chromium/net/socket/tcp_listen_socket.h @@ -12,23 +12,22 @@ #include "net/socket/socket_descriptor.h" #include "net/socket/stream_listen_socket.h" +namespace nacl { +class NaClProcessHost; +} + namespace net { -// Implements a TCP socket. +namespace test_server { +class EmbeddedTestServer; +} + +// Implements a TCP socket. This class is deprecated and will be removed +// once crbug.com/472766 is fixed. There should not be any new consumer of this +// class. class NET_EXPORT TCPListenSocket : public StreamListenSocket { public: ~TCPListenSocket() override; - // Listen on port for the specified IP address. Use 127.0.0.1 to only - // accept local connections. - static scoped_ptr<TCPListenSocket> CreateAndListen( - const std::string& ip, int port, StreamListenSocket::Delegate* del); - - // Get raw TCP socket descriptor bound to ip:port. - static SocketDescriptor CreateAndBind(const std::string& ip, int port); - - // Get raw TCP socket descriptor bound to ip and return port it is bound to. - static SocketDescriptor CreateAndBindAnyPort(const std::string& ip, - int* port); protected: TCPListenSocket(SocketDescriptor s, StreamListenSocket::Delegate* del); @@ -37,24 +36,27 @@ class NET_EXPORT TCPListenSocket : public StreamListenSocket { void Accept() override; private: - DISALLOW_COPY_AND_ASSIGN(TCPListenSocket); -}; + // Note that friend classes are temporary until crbug.com/472766 is fixed. + friend class test_server::EmbeddedTestServer; + friend class TCPListenSocketTester; + friend class TransportClientSocketTest; + friend class nacl::NaClProcessHost; -// Factory that can be used to instantiate TCPListenSocket. -class NET_EXPORT TCPListenSocketFactory : public StreamListenSocketFactory { - public: - TCPListenSocketFactory(const std::string& ip, int port); - ~TCPListenSocketFactory() override; + // Listen on port for the specified IP address. Use 127.0.0.1 to only + // accept local connections. + static scoped_ptr<TCPListenSocket> CreateAndListen( + const std::string& ip, + uint16 port, + StreamListenSocket::Delegate* del); - // StreamListenSocketFactory overrides. - scoped_ptr<StreamListenSocket> CreateAndListen( - StreamListenSocket::Delegate* delegate) const override; + // Get raw TCP socket descriptor bound to ip:port. + static SocketDescriptor CreateAndBind(const std::string& ip, uint16 port); - private: - const std::string ip_; - const int port_; + // Get raw TCP socket descriptor bound to ip and return port it is bound to. + static SocketDescriptor CreateAndBindAnyPort(const std::string& ip, + uint16* port); - DISALLOW_COPY_AND_ASSIGN(TCPListenSocketFactory); + DISALLOW_COPY_AND_ASSIGN(TCPListenSocket); }; } // namespace net diff --git a/chromium/net/socket/tcp_listen_socket_unittest.h b/chromium/net/socket/tcp_listen_socket_unittest.h index 984442afdc0..cd19a1caefc 100644 --- a/chromium/net/socket/tcp_listen_socket_unittest.h +++ b/chromium/net/socket/tcp_listen_socket_unittest.h @@ -48,7 +48,7 @@ class TCPListenSocketTestAction { : action_(action), data_(data) {} - const std::string data() const { return data_; } + const std::string& data() const { return data_; } ActionType type() const { return action_; } private: diff --git a/chromium/net/socket/tcp_server_socket.h b/chromium/net/socket/tcp_server_socket.h index a3919e6845a..b161be65723 100644 --- a/chromium/net/socket/tcp_server_socket.h +++ b/chromium/net/socket/tcp_server_socket.h @@ -10,7 +10,7 @@ #include "base/memory/scoped_ptr.h" #include "net/base/ip_endpoint.h" #include "net/base/net_export.h" -#include "net/base/net_log.h" +#include "net/log/net_log.h" #include "net/socket/server_socket.h" #include "net/socket/tcp_socket.h" diff --git a/chromium/net/socket/tcp_server_socket_unittest.cc b/chromium/net/socket/tcp_server_socket_unittest.cc index 01bae9ff188..2f6491bf8cf 100644 --- a/chromium/net/socket/tcp_server_socket_unittest.cc +++ b/chromium/net/socket/tcp_server_socket_unittest.cc @@ -50,7 +50,7 @@ class TCPServerSocketTest : public PlatformTest { *success = true; } - void ParseAddress(std::string ip_str, int port, IPEndPoint* address) { + void ParseAddress(std::string ip_str, uint16 port, IPEndPoint* address) { IPAddressNumber ip_number; bool rv = ParseIPLiteralToNumber(ip_str, &ip_number); if (!rv) diff --git a/chromium/net/socket/tcp_socket_libevent.cc b/chromium/net/socket/tcp_socket_libevent.cc index cc2376590f5..56c19b203a4 100644 --- a/chromium/net/socket/tcp_socket_libevent.cc +++ b/chromium/net/socket/tcp_socket_libevent.cc @@ -9,9 +9,10 @@ #include <sys/socket.h> #include "base/bind.h" +#include "base/files/file_path.h" +#include "base/files/file_util.h" #include "base/logging.h" #include "base/metrics/histogram.h" -#include "base/metrics/stats_counters.h" #include "base/posix/eintr_wrapper.h" #include "base/task_runner_util.h" #include "base/threading/worker_pool.h" @@ -21,6 +22,7 @@ #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/net_util.h" +#include "net/base/network_activity_monitor.h" #include "net/base/network_change_notifier.h" #include "net/socket/socket_libevent.h" #include "net/socket/socket_net_log_params.h" @@ -54,6 +56,7 @@ bool SetTCPNoDelay(int fd, bool no_delay) { // SetTCPKeepAlive sets SO_KEEPALIVE. bool SetTCPKeepAlive(int fd, bool enable, int delay) { + // Enabling TCP keepalives is the same on all platforms. int on = enable ? 1 : 0; if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &on, sizeof(on))) { PLOG(ERROR) << "Failed to set SO_KEEPALIVE on fd: " << fd; @@ -65,6 +68,8 @@ bool SetTCPKeepAlive(int fd, bool enable, int delay) { return true; #if defined(OS_LINUX) || defined(OS_ANDROID) + // Setting the keepalive interval varies by platform. + // Set seconds until first TCP keep alive. if (setsockopt(fd, SOL_TCP, TCP_KEEPIDLE, &delay, sizeof(delay))) { PLOG(ERROR) << "Failed to set TCP_KEEPIDLE on fd: " << fd; @@ -75,6 +80,11 @@ bool SetTCPKeepAlive(int fd, bool enable, int delay) { PLOG(ERROR) << "Failed to set TCP_KEEPINTVL on fd: " << fd; return false; } +#elif defined(OS_MACOSX) || defined(OS_IOS) + if (setsockopt(fd, IPPROTO_TCP, TCP_KEEPALIVE, &delay, sizeof(delay))) { + PLOG(ERROR) << "Failed to set TCP_KEEPALIVE on fd: " << fd; + return false; + } #endif return true; } @@ -537,9 +547,6 @@ int TCPSocketLibevent::HandleConnectCompleted(int rv) const { } void TCPSocketLibevent::LogConnectBegin(const AddressList& addresses) const { - base::StatsCounter connects("tcp.connect"); - connects.Increment(); - net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT, addresses.CreateNetLogCallback()); } @@ -596,11 +603,10 @@ int TCPSocketLibevent::HandleReadCompleted(IOBuffer* buf, int rv) { CreateNetLogSocketErrorCallback(rv, errno)); return rv; } - - base::StatsCounter read_bytes("tcp.read_bytes"); - read_bytes.Add(rv); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, rv, buf->data()); + NetworkActivityMonitor::GetInstance()->IncrementBytesReceived(rv); + return rv; } @@ -628,11 +634,9 @@ int TCPSocketLibevent::HandleWriteCompleted(IOBuffer* buf, int rv) { CreateNetLogSocketErrorCallback(rv, errno)); return rv; } - - base::StatsCounter write_bytes("tcp.write_bytes"); - write_bytes.Add(rv); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, rv, buf->data()); + NetworkActivityMonitor::GetInstance()->IncrementBytesSent(rv); return rv; } diff --git a/chromium/net/socket/tcp_socket_libevent.h b/chromium/net/socket/tcp_socket_libevent.h index 0958b6d25d0..78b035d2d8d 100644 --- a/chromium/net/socket/tcp_socket_libevent.h +++ b/chromium/net/socket/tcp_socket_libevent.h @@ -12,7 +12,7 @@ #include "net/base/address_family.h" #include "net/base/completion_callback.h" #include "net/base/net_export.h" -#include "net/base/net_log.h" +#include "net/log/net_log.h" namespace net { diff --git a/chromium/net/socket/tcp_socket_unittest.cc b/chromium/net/socket/tcp_socket_unittest.cc index 198138860fc..4bfc1384fbd 100644 --- a/chromium/net/socket/tcp_socket_unittest.cc +++ b/chromium/net/socket/tcp_socket_unittest.cc @@ -9,6 +9,7 @@ #include <string> #include <vector> +#include "base/basictypes.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "net/base/address_list.h" @@ -56,7 +57,9 @@ class TCPSocketTest : public PlatformTest { *success = true; } - void ParseAddress(const std::string& ip_str, int port, IPEndPoint* address) { + void ParseAddress(const std::string& ip_str, + uint16 port, + IPEndPoint* address) { IPAddressNumber ip_number; bool rv = ParseIPLiteralToNumber(ip_str, &ip_number); if (!rv) diff --git a/chromium/net/socket/tcp_socket_win.cc b/chromium/net/socket/tcp_socket_win.cc index d5565ad669c..2620ebaf701 100644 --- a/chromium/net/socket/tcp_socket_win.cc +++ b/chromium/net/socket/tcp_socket_win.cc @@ -9,7 +9,6 @@ #include "base/callback_helpers.h" #include "base/logging.h" -#include "base/metrics/stats_counters.h" #include "base/profiler/scoped_tracker.h" #include "base/win/windows_version.h" #include "net/base/address_list.h" @@ -18,6 +17,7 @@ #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/net_util.h" +#include "net/base/network_activity_monitor.h" #include "net/base/network_change_notifier.h" #include "net/base/winsock_init.h" #include "net/base/winsock_util.h" @@ -80,11 +80,11 @@ bool DisableNagle(SOCKET socket, bool disable) { // Enable TCP Keep-Alive to prevent NAT routers from timing out TCP // connections. See http://crbug.com/27400 for details. bool SetTCPKeepAlive(SOCKET socket, BOOL enable, int delay_secs) { - int delay = delay_secs * 1000; + unsigned delay = delay_secs * 1000; struct tcp_keepalive keepalive_vals = { - enable ? 1 : 0, // TCP keep-alive on. - delay, // Delay seconds before sending first TCP keep-alive packet. - delay, // Delay seconds between sending TCP keep-alive packets. + enable ? 1u : 0u, // TCP keep-alive on. + delay, // Delay seconds before sending first TCP keep-alive packet. + delay, // Delay seconds between sending TCP keep-alive packets. }; DWORD bytes_returned = 0xABAB; int rv = WSAIoctl(socket, SIO_KEEPALIVE_VALS, &keepalive_vals, @@ -167,10 +167,10 @@ class TCPSocketWin::Core : public base::RefCounted<Core> { class ReadDelegate : public base::win::ObjectWatcher::Delegate { public: explicit ReadDelegate(Core* core) : core_(core) {} - virtual ~ReadDelegate() {} + ~ReadDelegate() override {} // base::ObjectWatcher::Delegate methods: - virtual void OnObjectSignaled(HANDLE object); + void OnObjectSignaled(HANDLE object) override; private: Core* const core_; @@ -179,10 +179,10 @@ class TCPSocketWin::Core : public base::RefCounted<Core> { class WriteDelegate : public base::win::ObjectWatcher::Delegate { public: explicit WriteDelegate(Core* core) : core_(core) {} - virtual ~WriteDelegate() {} + ~WriteDelegate() override {} // base::ObjectWatcher::Delegate methods: - virtual void OnObjectSignaled(HANDLE object); + void OnObjectSignaled(HANDLE object) override; private: Core* const core_; @@ -246,11 +246,6 @@ void TCPSocketWin::Core::WatchForWrite() { } void TCPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) { - // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION( - "TCPSocketWin_Core_ReadDelegate_OnObjectSignaled")); - DCHECK_EQ(object, core_->read_overlapped_.hEvent); if (core_->socket_) { if (core_->socket_->waiting_connect_) @@ -264,11 +259,6 @@ void TCPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) { void TCPSocketWin::Core::WriteDelegate::OnObjectSignaled( HANDLE object) { - // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION( - "TCPSocketWin_Core_WriteDelegate_OnObjectSignaled")); - DCHECK_EQ(object, core_->write_overlapped_.hEvent); if (core_->socket_) core_->socket_->DidCompleteWrite(); @@ -324,7 +314,7 @@ int TCPSocketWin::AdoptConnectedSocket(SOCKET socket, const IPEndPoint& peer_address) { DCHECK(CalledOnValidThread()); DCHECK_EQ(socket_, INVALID_SOCKET); - DCHECK(!core_); + DCHECK(!core_.get()); socket_ = socket; @@ -436,7 +426,7 @@ int TCPSocketWin::Connect(const IPEndPoint& address, // again after a connection attempt failed on Windows, it results in // unspecified behavior according to POSIX. Therefore, we make it behave in // the same way as TCPSocketLibevent. - DCHECK(!peer_address_ && !core_); + DCHECK(!peer_address_ && !core_.get()); if (!logging_multiple_connect_attempts_) LogConnectBegin(AddressList(address)); @@ -504,7 +494,7 @@ int TCPSocketWin::Read(IOBuffer* buf, DCHECK_NE(socket_, INVALID_SOCKET); DCHECK(!waiting_read_); CHECK(read_callback_.is_null()); - DCHECK(!core_->read_iobuffer_); + DCHECK(!core_->read_iobuffer_.get()); return DoRead(buf, buf_len, callback); } @@ -517,10 +507,7 @@ int TCPSocketWin::Write(IOBuffer* buf, DCHECK(!waiting_write_); CHECK(write_callback_.is_null()); DCHECK_GT(buf_len, 0); - DCHECK(!core_->write_iobuffer_); - - base::StatsCounter writes("tcp.writes"); - writes.Increment(); + DCHECK(!core_->write_iobuffer_.get()); WSABUF write_buffer; write_buffer.len = buf_len; @@ -541,10 +528,9 @@ int TCPSocketWin::Write(IOBuffer* buf, << " bytes, but " << rv << " bytes reported."; return ERR_WINSOCK_UNEXPECTED_WRITTEN_BYTES; } - base::StatsCounter write_bytes("tcp.write_bytes"); - write_bytes.Add(rv); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, rv, buf->data()); + NetworkActivityMonitor::GetInstance()->IncrementBytesSent(rv); return rv; } } else { @@ -689,7 +675,7 @@ void TCPSocketWin::Close() { accept_event_ = WSA_INVALID_EVENT; } - if (core_) { + if (core_.get()) { if (waiting_connect_) { // We closed the socket, so this notification will never come. // From MSDN' WSAEventSelect documentation: @@ -766,10 +752,6 @@ int TCPSocketWin::AcceptInternal(scoped_ptr<TCPSocketWin>* socket, } void TCPSocketWin::OnObjectSignaled(HANDLE object) { - // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION("TCPSocketWin_OnObjectSignaled")); - WSANETWORKEVENTS ev; if (WSAEnumNetworkEvents(socket_, accept_event_, &ev) == SOCKET_ERROR) { PLOG(ERROR) << "WSAEnumNetworkEvents()"; @@ -796,12 +778,13 @@ void TCPSocketWin::OnObjectSignaled(HANDLE object) { int TCPSocketWin::DoConnect() { DCHECK_EQ(connect_os_error_, 0); - DCHECK(!core_); + DCHECK(!core_.get()); net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, CreateNetLogIPEndPointCallback(peer_address_.get())); core_ = new Core(this); + // WSAEventSelect sets the socket to non-blocking mode as a side effect. // Our connect() and recv() calls require that the socket be non-blocking. WSAEventSelect(socket_, core_->read_overlapped_.hEvent, FD_CONNECT); @@ -809,7 +792,16 @@ int TCPSocketWin::DoConnect() { SockaddrStorage storage; if (!peer_address_->ToSockAddr(storage.addr, &storage.addr_len)) return ERR_ADDRESS_INVALID; - if (!connect(socket_, storage.addr, storage.addr_len)) { + + int result; + { + // TODO(ricea): Remove ScopedTracker below once crbug.com/436634 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION("436634 connect()")); + result = connect(socket_, storage.addr, storage.addr_len); + } + + if (!result) { // Connected without waiting! // // The MSDN page for connect says: @@ -835,6 +827,10 @@ int TCPSocketWin::DoConnect() { } } + // TODO(ricea): Remove ScopedTracker below once crbug.com/436634 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION("436634 WatchForRead()")); + core_->WatchForRead(); return ERR_IO_PENDING; } @@ -855,9 +851,6 @@ void TCPSocketWin::DoConnectComplete(int result) { } void TCPSocketWin::LogConnectBegin(const AddressList& addresses) { - base::StatsCounter connects("tcp.connect"); - connects.Increment(); - net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT, addresses.CreateNetLogCallback()); } @@ -908,11 +901,9 @@ int TCPSocketWin::DoRead(IOBuffer* buf, int buf_len, return net_error; } } else { - base::StatsCounter read_bytes("tcp.read_bytes"); - if (rv > 0) - read_bytes.Add(rv); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, rv, buf->data()); + NetworkActivityMonitor::GetInstance()->IncrementBytesReceived(rv); return rv; } @@ -930,8 +921,15 @@ void TCPSocketWin::DidCompleteConnect() { int result; WSANETWORKEVENTS events; - int rv = WSAEnumNetworkEvents(socket_, core_->read_overlapped_.hEvent, - &events); + int rv; + { + // TODO(pkasting): Remove ScopedTracker below once crbug.com/462784 is + // fixed. + tracked_objects::ScopedTracker tracking_profile1( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "462784 TCPSocketWin::DidCompleteConnect -> WSAEnumNetworkEvents")); + rv = WSAEnumNetworkEvents(socket_, core_->read_overlapped_.hEvent, &events); + } int os_error = 0; if (rv == SOCKET_ERROR) { NOTREACHED(); @@ -949,6 +947,10 @@ void TCPSocketWin::DidCompleteConnect() { DoConnectComplete(result); waiting_connect_ = false; + // TODO(pkasting): Remove ScopedTracker below once crbug.com/462784 is fixed. + tracked_objects::ScopedTracker tracking_profile4( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "462784 TCPSocketWin::DidCompleteConnect -> read_callback_")); DCHECK_NE(result, ERR_IO_PENDING); base::ResetAndReturn(&read_callback_).Run(result); } @@ -978,10 +980,9 @@ void TCPSocketWin::DidCompleteWrite() { << " bytes reported."; rv = ERR_WINSOCK_UNEXPECTED_WRITTEN_BYTES; } else { - base::StatsCounter write_bytes("tcp.write_bytes"); - write_bytes.Add(num_bytes); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, num_bytes, core_->write_iobuffer_->data()); + NetworkActivityMonitor::GetInstance()->IncrementBytesSent(num_bytes); } } @@ -1003,6 +1004,11 @@ void TCPSocketWin::DidSignalRead() { os_error = WSAGetLastError(); rv = MapSystemError(os_error); } else if (network_events.lNetworkEvents) { + // TODO(pkasting): Remove ScopedTracker below once crbug.com/462778 is + // fixed. + tracked_objects::ScopedTracker tracking_profile2( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "462778 TCPSocketWin::DidSignalRead -> DoRead")); DCHECK_EQ(network_events.lNetworkEvents & ~(FD_READ | FD_CLOSE), 0); // If network_events.lNetworkEvents is FD_CLOSE and // network_events.iErrorCode[FD_CLOSE_BIT] is 0, it is a graceful @@ -1018,7 +1024,7 @@ void TCPSocketWin::DidSignalRead() { // DoRead() because recv() reports a more accurate error code // (WSAECONNRESET vs. WSAECONNABORTED) when the connection was // reset. - rv = DoRead(core_->read_iobuffer_, core_->read_buffer_length_, + rv = DoRead(core_->read_iobuffer_.get(), core_->read_buffer_length_, read_callback_); if (rv == ERR_IO_PENDING) return; @@ -1034,9 +1040,6 @@ void TCPSocketWin::DidSignalRead() { core_->read_buffer_length_ = 0; DCHECK_NE(rv, ERR_IO_PENDING); - // TODO(vadimt): Remove ScopedTracker below once crbug.com/418183 is fixed. - tracked_objects::ScopedTracker tracking_profile( - FROM_HERE_WITH_EXPLICIT_FUNCTION("TCPSocketWin::DidSignalRead")); base::ResetAndReturn(&read_callback_).Run(rv); } diff --git a/chromium/net/socket/tcp_socket_win.h b/chromium/net/socket/tcp_socket_win.h index 80174adea9c..1012259806c 100644 --- a/chromium/net/socket/tcp_socket_win.h +++ b/chromium/net/socket/tcp_socket_win.h @@ -16,7 +16,7 @@ #include "net/base/address_family.h" #include "net/base/completion_callback.h" #include "net/base/net_export.h" -#include "net/base/net_log.h" +#include "net/log/net_log.h" namespace net { @@ -28,7 +28,7 @@ class NET_EXPORT TCPSocketWin : NON_EXPORTED_BASE(public base::NonThreadSafe), public base::win::ObjectWatcher::Delegate { public: TCPSocketWin(NetLog* net_log, const NetLog::Source& source); - virtual ~TCPSocketWin(); + ~TCPSocketWin() override; int Open(AddressFamily family); @@ -101,7 +101,7 @@ class NET_EXPORT TCPSocketWin : NON_EXPORTED_BASE(public base::NonThreadSafe), class Core; // base::ObjectWatcher::Delegate implementation. - virtual void OnObjectSignaled(HANDLE object) override; + void OnObjectSignaled(HANDLE object) override; int AcceptInternal(scoped_ptr<TCPSocketWin>* socket, IPEndPoint* address); diff --git a/chromium/net/socket/transport_client_socket_pool.cc b/chromium/net/socket/transport_client_socket_pool.cc index 06202f13d81..a5f8adbed24 100644 --- a/chromium/net/socket/transport_client_socket_pool.cc +++ b/chromium/net/socket/transport_client_socket_pool.cc @@ -11,13 +11,14 @@ #include "base/logging.h" #include "base/message_loop/message_loop.h" #include "base/metrics/histogram.h" +#include "base/profiler/scoped_tracker.h" #include "base/strings/string_util.h" #include "base/synchronization/lock.h" #include "base/time/time.h" #include "base/values.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" +#include "net/log/net_log.h" #include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_pool_base.h" @@ -200,10 +201,14 @@ TransportConnectJob::TransportConnectJob( HostResolver* host_resolver, Delegate* delegate, NetLog* net_log) - : ConnectJob(group_name, timeout_duration, priority, delegate, + : ConnectJob(group_name, + timeout_duration, + priority, + delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), helper_(params, client_socket_factory, host_resolver, &connect_timing_), - interval_between_connects_(CONNECT_INTERVAL_GT_20MS) { + interval_between_connects_(CONNECT_INTERVAL_GT_20MS), + resolve_result_(OK) { helper_.SetOnIOComplete(this); } @@ -227,6 +232,21 @@ LoadState TransportConnectJob::GetLoadState() const { return LOAD_STATE_IDLE; } +void TransportConnectJob::GetAdditionalErrorState(ClientSocketHandle* handle) { + // If hostname resolution failed, record an empty endpoint and the result. + // Also record any attempts made on either of the sockets. + ConnectionAttempts attempts; + if (resolve_result_ != OK) { + DCHECK_EQ(0u, helper_.addresses().size()); + attempts.push_back(ConnectionAttempt(IPEndPoint(), resolve_result_)); + } + attempts.insert(attempts.begin(), connection_attempts_.begin(), + connection_attempts_.end()); + attempts.insert(attempts.begin(), fallback_connection_attempts_.begin(), + fallback_connection_attempts_.end()); + handle->set_connection_attempts(attempts); +} + // static void TransportConnectJob::MakeAddressListStartWithIPv4(AddressList* list) { for (AddressList::iterator i = list->begin(); i != list->end(); ++i) { @@ -238,10 +258,16 @@ void TransportConnectJob::MakeAddressListStartWithIPv4(AddressList* list) { } int TransportConnectJob::DoResolveHost() { + // TODO(ricea): Remove ScopedTracker below once crbug.com/436634 is fixed. + tracked_objects::ScopedTracker tracking_profile( + FROM_HERE_WITH_EXPLICIT_FUNCTION( + "436634 TransportConnectJob::DoResolveHost")); + return helper_.DoResolveHost(priority(), net_log()); } int TransportConnectJob::DoResolveHostComplete(int result) { + resolve_result_ = result; return helper_.DoResolveHostComplete(result, net_log()); } @@ -301,6 +327,16 @@ int TransportConnectJob::DoTransportConnect() { int TransportConnectJob::DoTransportConnectComplete(int result) { if (result == OK) { + // Success will be returned via the main socket, so also include connection + // attempts made on the fallback socket up to this point. (Unfortunately, + // the only simple way to return information in the success case is through + // the successfully-connected socket.) + if (fallback_transport_socket_) { + ConnectionAttempts fallback_attempts; + fallback_transport_socket_->GetConnectionAttempts(&fallback_attempts); + transport_socket_->AddConnectionAttempts(fallback_attempts); + } + bool is_ipv4 = helper_.addresses().front().GetFamily() == ADDRESS_FAMILY_IPV4; TransportConnectJobHelper::ConnectionLatencyHistogram race_result = @@ -349,11 +385,18 @@ int TransportConnectJob::DoTransportConnectComplete(int result) { SetSocket(transport_socket_.Pass()); fallback_timer_.Stop(); } else { + // Failure will be returned via |GetAdditionalErrorState|, so save + // connection attempts from both sockets for use there. + CopyConnectionAttemptsFromSockets(); + // Be a bit paranoid and kill off the fallback members to prevent reuse. fallback_transport_socket_.reset(); fallback_addresses_.reset(); } + // N.B.: The owner of the ConnectJob will delete it after the callback is + // called, so the fallback socket, if any, won't stick around for long. + return result; } @@ -397,6 +440,17 @@ void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) { if (result == OK) { DCHECK(!fallback_connect_start_time_.is_null()); + + // Success will be returned via the fallback socket, so also include + // connection attempts made on the main socket up to this point. + // (Unfortunately, the only simple way to return information in the success + // case is through the successfully-connected socket.) + if (transport_socket_) { + ConnectionAttempts attempts; + transport_socket_->GetConnectionAttempts(&attempts); + fallback_transport_socket_->AddConnectionAttempts(attempts); + } + connect_timing_.connect_start = fallback_connect_start_time_; helper_.HistogramDuration( TransportConnectJobHelper::CONNECTION_LATENCY_IPV4_WINS_RACE); @@ -404,10 +458,18 @@ void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) { helper_.set_next_state(TransportConnectJobHelper::STATE_NONE); transport_socket_.reset(); } else { + // Failure will be returned via |GetAdditionalErrorState|, so save + // connection attempts from both sockets for use there. + CopyConnectionAttemptsFromSockets(); + // Be a bit paranoid and kill off the fallback members to prevent reuse. fallback_transport_socket_.reset(); fallback_addresses_.reset(); } + + // N.B.: The owner of the ConnectJob will delete it after the callback is + // called, so the main socket, if any, won't stick around for long. + NotifyDelegateOfCompletion(result); // Deletes |this| } @@ -415,6 +477,15 @@ int TransportConnectJob::ConnectInternal() { return helper_.DoConnectInternal(this); } +void TransportConnectJob::CopyConnectionAttemptsFromSockets() { + if (transport_socket_) + transport_socket_->GetConnectionAttempts(&connection_attempts_); + if (fallback_transport_socket_) { + fallback_transport_socket_->GetConnectionAttempts( + &fallback_connection_attempts_); + } +} + scoped_ptr<ConnectJob> TransportClientSocketPool::TransportConnectJobFactory::NewConnectJob( const std::string& group_name, @@ -440,15 +511,17 @@ base::TimeDelta TransportClientSocketPool::TransportClientSocketPool( int max_sockets, int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, HostResolver* host_resolver, ClientSocketFactory* client_socket_factory, NetLog* net_log) - : base_(NULL, max_sockets, max_sockets_per_group, histograms, + : base_(NULL, + max_sockets, + max_sockets_per_group, ClientSocketPool::unused_idle_socket_timeout(), ClientSocketPool::used_idle_socket_timeout(), new TransportConnectJobFactory(client_socket_factory, - host_resolver, net_log)) { + host_resolver, + net_log)) { base_.EnableConnectBackupJobs(); } @@ -473,7 +546,7 @@ int TransportClientSocketPool::RequestSocket( void TransportClientSocketPool::NetLogTcpClientSocketPoolRequestedSocket( const BoundNetLog& net_log, const scoped_refptr<TransportSocketParams>* casted_params) { - if (net_log.IsLogging()) { + if (net_log.IsCapturing()) { // TODO(eroman): Split out the host and port parameters. net_log.AddEvent( NetLog::TYPE_TCP_CLIENT_SOCKET_POOL_REQUESTED_SOCKET, @@ -490,7 +563,7 @@ void TransportClientSocketPool::RequestSockets( const scoped_refptr<TransportSocketParams>* casted_params = static_cast<const scoped_refptr<TransportSocketParams>*>(params); - if (net_log.IsLogging()) { + if (net_log.IsCapturing()) { // TODO(eroman): Split out the host and port parameters. net_log.AddEvent( NetLog::TYPE_TCP_CLIENT_SOCKET_POOL_REQUESTED_SOCKETS, @@ -547,10 +620,6 @@ base::TimeDelta TransportClientSocketPool::ConnectionTimeout() const { return base_.ConnectionTimeout(); } -ClientSocketPoolHistograms* TransportClientSocketPool::histograms() const { - return base_.histograms(); -} - bool TransportClientSocketPool::IsStalled() const { return base_.IsStalled(); } diff --git a/chromium/net/socket/transport_client_socket_pool.h b/chromium/net/socket/transport_client_socket_pool.h index 15cef5c02ac..c2c8ab2a8a6 100644 --- a/chromium/net/socket/transport_client_socket_pool.h +++ b/chromium/net/socket/transport_client_socket_pool.h @@ -17,7 +17,7 @@ #include "net/dns/single_request_host_resolver.h" #include "net/socket/client_socket_pool.h" #include "net/socket/client_socket_pool_base.h" -#include "net/socket/client_socket_pool_histograms.h" +#include "net/socket/connection_attempts.h" namespace net { @@ -168,6 +168,7 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { // ConnectJob methods. LoadState GetLoadState() const override; + void GetAdditionalErrorState(ClientSocketHandle* handle) override; // Rolls |addrlist| forward until the first IPv4 address, if any. // WARNING: this method should only be used to implement the prefer-IPv4 hack. @@ -196,6 +197,8 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { // Otherwise, it returns a net error code. int ConnectInternal() override; + void CopyConnectionAttemptsFromSockets(); + TransportConnectJobHelper helper_; scoped_ptr<StreamSocket> transport_socket_; @@ -208,6 +211,16 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { // Track the interval between this connect and previous connect. ConnectInterval interval_between_connects_; + int resolve_result_; + + // Used in the failure case to save connection attempts made on the main and + // fallback sockets and pass them on in |GetAdditionalErrorState|. (In the + // success case, connection attempts are passed through the returned socket; + // attempts are copied from the other socket, if one exists, into it before + // it is returned.) + ConnectionAttempts connection_attempts_; + ConnectionAttempts fallback_connection_attempts_; + DISALLOW_COPY_AND_ASSIGN(TransportConnectJob); }; @@ -218,7 +231,6 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { TransportClientSocketPool( int max_sockets, int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, HostResolver* host_resolver, ClientSocketFactory* client_socket_factory, NetLog* net_log); @@ -252,7 +264,6 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { const std::string& type, bool include_nested_pools) const override; base::TimeDelta ConnectionTimeout() const override; - ClientSocketPoolHistograms* histograms() const override; // HigherLayeredPool implementation. bool IsStalled() const override; diff --git a/chromium/net/socket/transport_client_socket_pool_test_util.cc b/chromium/net/socket/transport_client_socket_pool_test_util.cc index 82ed8e6a78e..352fd878fe5 100644 --- a/chromium/net/socket/transport_client_socket_pool_test_util.cc +++ b/chromium/net/socket/transport_client_socket_pool_test_util.cc @@ -69,6 +69,11 @@ class MockConnectClientSocket : public StreamSocket { bool WasNpnNegotiated() const override { return false; } NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } + void GetConnectionAttempts(ConnectionAttempts* out) const override { + out->clear(); + } + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} // Socket implementation. int Read(IOBuffer* buf, @@ -125,6 +130,13 @@ class MockFailingClientSocket : public StreamSocket { bool WasNpnNegotiated() const override { return false; } NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } + void GetConnectionAttempts(ConnectionAttempts* out) const override { + out->clear(); + for (const auto& addr : addrlist_) + out->push_back(ConnectionAttempt(addr, ERR_CONNECTION_FAILED)); + } + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} // Socket implementation. int Read(IOBuffer* buf, @@ -196,9 +208,16 @@ class MockTriggerableClientSocket : public StreamSocket { static scoped_ptr<StreamSocket> MakeMockStalledClientSocket( const AddressList& addrlist, - net::NetLog* net_log) { + net::NetLog* net_log, + bool failing) { scoped_ptr<MockTriggerableClientSocket> socket( new MockTriggerableClientSocket(addrlist, true, net_log)); + if (failing) { + DCHECK_LE(1u, addrlist.size()); + ConnectionAttempts attempts; + attempts.push_back(ConnectionAttempt(addrlist[0], ERR_CONNECTION_FAILED)); + socket->AddConnectionAttempts(attempts); + } return socket.Pass(); } @@ -236,6 +255,14 @@ class MockTriggerableClientSocket : public StreamSocket { bool WasNpnNegotiated() const override { return false; } NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } + void GetConnectionAttempts(ConnectionAttempts* out) const override { + *out = connection_attempts_; + } + void ClearConnectionAttempts() override { connection_attempts_.clear(); } + void AddConnectionAttempts(const ConnectionAttempts& attempts) override { + connection_attempts_.insert(connection_attempts_.begin(), attempts.begin(), + attempts.end()); + } // Socket implementation. int Read(IOBuffer* buf, @@ -264,6 +291,7 @@ class MockTriggerableClientSocket : public StreamSocket { BoundNetLog net_log_; CompletionCallback callback_; bool use_tcp_fastopen_; + ConnectionAttempts connection_attempts_; base::WeakPtrFactory<MockTriggerableClientSocket> weak_factory_; @@ -364,8 +392,11 @@ MockTransportClientSocketFactory::CreateTransportClientSocket( return MockTriggerableClientSocket::MakeMockDelayedClientSocket( addresses, false, delay_, net_log_); case MOCK_STALLED_CLIENT_SOCKET: - return MockTriggerableClientSocket::MakeMockStalledClientSocket(addresses, - net_log_); + return MockTriggerableClientSocket::MakeMockStalledClientSocket( + addresses, net_log_, false); + case MOCK_STALLED_FAILING_CLIENT_SOCKET: + return MockTriggerableClientSocket::MakeMockStalledClientSocket( + addresses, net_log_, true); case MOCK_TRIGGERABLE_CLIENT_SOCKET: { scoped_ptr<MockTriggerableClientSocket> rv( new MockTriggerableClientSocket(addresses, true, net_log_)); diff --git a/chromium/net/socket/transport_client_socket_pool_test_util.h b/chromium/net/socket/transport_client_socket_pool_test_util.h index b375353f06f..6e38af66613 100644 --- a/chromium/net/socket/transport_client_socket_pool_test_util.h +++ b/chromium/net/socket/transport_client_socket_pool_test_util.h @@ -17,7 +17,7 @@ #include "base/memory/scoped_ptr.h" #include "base/time/time.h" #include "net/base/address_list.h" -#include "net/base/net_log.h" +#include "net/log/net_log.h" #include "net/socket/client_socket_factory.h" #include "net/socket/client_socket_handle.h" #include "net/socket/stream_socket.h" @@ -62,6 +62,9 @@ class MockTransportClientSocketFactory : public ClientSocketFactory { MOCK_DELAYED_FAILING_CLIENT_SOCKET, // A stalled socket that never connects at all. MOCK_STALLED_CLIENT_SOCKET, + // A stalled socket that never connects at all, but returns a failing + // ConnectionAttempt in |GetConnectionAttempts|. + MOCK_STALLED_FAILING_CLIENT_SOCKET, // A socket that can be triggered to connect explicitly, asynchronously. MOCK_TRIGGERABLE_CLIENT_SOCKET, }; diff --git a/chromium/net/socket/transport_client_socket_pool_unittest.cc b/chromium/net/socket/transport_client_socket_pool_unittest.cc index c0687ef5a43..1a00d46a47a 100644 --- a/chromium/net/socket/transport_client_socket_pool_unittest.cc +++ b/chromium/net/socket/transport_client_socket_pool_unittest.cc @@ -9,7 +9,6 @@ #include "base/callback.h" #include "base/message_loop/message_loop.h" #include "base/threading/platform_thread.h" -#include "net/base/capturing_net_log.h" #include "net/base/ip_endpoint.h" #include "net/base/load_timing_info.h" #include "net/base/load_timing_info_test_util.h" @@ -17,8 +16,8 @@ #include "net/base/net_util.h" #include "net/base/test_completion_callback.h" #include "net/dns/mock_host_resolver.h" +#include "net/log/test_net_log.h" #include "net/socket/client_socket_handle.h" -#include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket_test_util.h" #include "net/socket/stream_socket.h" #include "net/socket/transport_client_socket_pool_test_util.h" @@ -46,12 +45,10 @@ class TransportClientSocketPoolTest : public testing::Test { false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), - histograms_(new ClientSocketPoolHistograms("TCPUnitTest")), host_resolver_(new MockHostResolver), client_socket_factory_(&net_log_), pool_(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL) { @@ -93,9 +90,8 @@ class TransportClientSocketPoolTest : public testing::Test { size_t completion_count() const { return test_base_.completion_count(); } bool connect_backup_jobs_enabled_; - CapturingNetLog net_log_; + TestNetLog net_log_; scoped_refptr<TransportSocketParams> params_; - scoped_ptr<ClientSocketPoolHistograms> histograms_; scoped_ptr<MockHostResolver> host_resolver_; MockTransportClientSocketFactory client_socket_factory_; TransportClientSocketPool pool_; @@ -189,6 +185,7 @@ TEST_F(TransportClientSocketPoolTest, Basic) { EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); TestLoadTimingInfoConnectedNotReused(handle); + EXPECT_EQ(0u, handle.connection_attempts().size()); } // Make sure that TransportConnectJob passes on its priority to its @@ -217,6 +214,9 @@ TEST_F(TransportClientSocketPoolTest, InitHostResolutionFailure) { handle.Init("a", dest, kDefaultPriority, callback.callback(), &pool_, BoundNetLog())); EXPECT_EQ(ERR_NAME_NOT_RESOLVED, callback.WaitForResult()); + ASSERT_EQ(1u, handle.connection_attempts().size()); + EXPECT_TRUE(handle.connection_attempts()[0].endpoint.address().empty()); + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, handle.connection_attempts()[0].result); } TEST_F(TransportClientSocketPoolTest, InitConnectionFailure) { @@ -228,12 +228,20 @@ TEST_F(TransportClientSocketPoolTest, InitConnectionFailure) { handle.Init("a", params_, kDefaultPriority, callback.callback(), &pool_, BoundNetLog())); EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); + ASSERT_EQ(1u, handle.connection_attempts().size()); + EXPECT_EQ("127.0.0.1:80", + handle.connection_attempts()[0].endpoint.ToString()); + EXPECT_EQ(ERR_CONNECTION_FAILED, handle.connection_attempts()[0].result); // Make the host resolutions complete synchronously this time. host_resolver_->set_synchronous_mode(true); EXPECT_EQ(ERR_CONNECTION_FAILED, handle.Init("a", params_, kDefaultPriority, callback.callback(), &pool_, BoundNetLog())); + ASSERT_EQ(1u, handle.connection_attempts().size()); + EXPECT_EQ("127.0.0.1:80", + handle.connection_attempts()[0].endpoint.ToString()); + EXPECT_EQ(ERR_CONNECTION_FAILED, handle.connection_attempts()[0].result); } TEST_F(TransportClientSocketPoolTest, PendingRequests) { @@ -768,6 +776,8 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterStall) { EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); + ASSERT_EQ(1u, handle.connection_attempts().size()); + EXPECT_EQ(ERR_CONNECTION_FAILED, handle.connection_attempts()[0].result); EXPECT_EQ(0, pool_.IdleSocketCount()); handle.Reset(); @@ -816,6 +826,8 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterDelay) { EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); + ASSERT_EQ(1u, handle.connection_attempts().size()); + EXPECT_EQ(ERR_CONNECTION_FAILED, handle.connection_attempts()[0].result); handle.Reset(); // Reset for the next case. @@ -829,17 +841,16 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv4FinishesFirst) { ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); TransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); MockTransportClientSocketFactory::ClientSocketType case_types[] = { - // This is the IPv6 socket. - MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, - // This is the IPv4 socket. - MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET - }; + // This is the IPv6 socket. It stalls, but presents one failed connection + // attempt on GetConnectionAttempts. + MockTransportClientSocketFactory::MOCK_STALLED_FAILING_CLIENT_SOCKET, + // This is the IPv4 socket. + MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET}; client_socket_factory_.set_client_socket_types(case_types, 2); @@ -861,6 +872,14 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv4FinishesFirst) { IPEndPoint endpoint; handle.socket()->GetLocalAddress(&endpoint); EXPECT_EQ(kIPv4AddressSize, endpoint.address().size()); + + // Check that the failed connection attempt on the main socket is collected. + ConnectionAttempts attempts; + handle.socket()->GetConnectionAttempts(&attempts); + ASSERT_EQ(1u, attempts.size()); + EXPECT_EQ(ERR_CONNECTION_FAILED, attempts[0].result); + EXPECT_EQ(kIPv6AddressSize, attempts[0].endpoint.address().size()); + EXPECT_EQ(2, client_socket_factory_.allocation_count()); } @@ -872,17 +891,16 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv6FinishesFirst) { ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); TransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); MockTransportClientSocketFactory::ClientSocketType case_types[] = { - // This is the IPv6 socket. - MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, - // This is the IPv4 socket. - MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET - }; + // This is the IPv6 socket. + MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET, + // This is the IPv4 socket. It stalls, but presents one failed connection + // attempt on GetConnectionATtempts. + MockTransportClientSocketFactory::MOCK_STALLED_FAILING_CLIENT_SOCKET}; client_socket_factory_.set_client_socket_types(case_types, 2); client_socket_factory_.set_delay(base::TimeDelta::FromMilliseconds( @@ -906,6 +924,15 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv6FinishesFirst) { IPEndPoint endpoint; handle.socket()->GetLocalAddress(&endpoint); EXPECT_EQ(kIPv6AddressSize, endpoint.address().size()); + + // Check that the failed connection attempt on the fallback socket is + // collected. + ConnectionAttempts attempts; + handle.socket()->GetConnectionAttempts(&attempts); + ASSERT_EQ(1u, attempts.size()); + EXPECT_EQ(ERR_CONNECTION_FAILED, attempts[0].result); + EXPECT_EQ(kIPv4AddressSize, attempts[0].endpoint.address().size()); + EXPECT_EQ(2, client_socket_factory_.allocation_count()); } @@ -914,7 +941,6 @@ TEST_F(TransportClientSocketPoolTest, IPv6NoIPv4AddressesToFallbackTo) { ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); TransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); @@ -940,6 +966,7 @@ TEST_F(TransportClientSocketPoolTest, IPv6NoIPv4AddressesToFallbackTo) { IPEndPoint endpoint; handle.socket()->GetLocalAddress(&endpoint); EXPECT_EQ(kIPv6AddressSize, endpoint.address().size()); + EXPECT_EQ(0u, handle.connection_attempts().size()); EXPECT_EQ(1, client_socket_factory_.allocation_count()); } @@ -948,7 +975,6 @@ TEST_F(TransportClientSocketPoolTest, IPv4HasNoFallback) { ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); TransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); @@ -973,6 +999,7 @@ TEST_F(TransportClientSocketPoolTest, IPv4HasNoFallback) { IPEndPoint endpoint; handle.socket()->GetLocalAddress(&endpoint); EXPECT_EQ(kIPv4AddressSize, endpoint.address().size()); + EXPECT_EQ(0u, handle.connection_attempts().size()); EXPECT_EQ(1, client_socket_factory_.allocation_count()); } @@ -983,7 +1010,6 @@ TEST_F(TransportClientSocketPoolTest, TCPFastOpenOnIPv4WithNoFallback) { ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); TransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); @@ -1008,7 +1034,6 @@ TEST_F(TransportClientSocketPoolTest, TCPFastOpenOnIPv6WithNoFallback) { ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); TransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); @@ -1036,7 +1061,6 @@ TEST_F(TransportClientSocketPoolTest, ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); TransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); @@ -1076,7 +1100,6 @@ TEST_F(TransportClientSocketPoolTest, ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); TransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); diff --git a/chromium/net/socket/transport_client_socket_unittest.cc b/chromium/net/socket/transport_client_socket_unittest.cc index d01cbad6dc8..9ab358cb891 100644 --- a/chromium/net/socket/transport_client_socket_unittest.cc +++ b/chromium/net/socket/transport_client_socket_unittest.cc @@ -2,21 +2,25 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "net/socket/tcp_client_socket.h" +#include <string> #include "base/basictypes.h" +#include "base/bind.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" +#include "base/run_loop.h" #include "net/base/address_list.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" -#include "net/base/net_log_unittest.h" #include "net/base/test_completion_callback.h" -#include "net/base/winsock_init.h" #include "net/dns/mock_host_resolver.h" +#include "net/log/net_log.h" +#include "net/log/test_net_log.h" +#include "net/log/test_net_log_entry.h" +#include "net/log/test_net_log_util.h" #include "net/socket/client_socket_factory.h" -#include "net/socket/tcp_listen_socket.h" +#include "net/socket/tcp_client_socket.h" +#include "net/socket/tcp_server_socket.h" #include "testing/gtest/include/gtest/gtest.h" #include "testing/platform_test.h" @@ -26,40 +30,19 @@ namespace { const char kServerReply[] = "HTTP/1.1 404 Not Found"; -enum ClientSocketTestTypes { - TCP, - SCTP -}; +enum ClientSocketTestTypes { TCP, SCTP }; } // namespace class TransportClientSocketTest - : public StreamListenSocket::Delegate, - public ::testing::TestWithParam<ClientSocketTestTypes> { + : public ::testing::TestWithParam<ClientSocketTestTypes> { public: TransportClientSocketTest() : listen_port_(0), socket_factory_(ClientSocketFactory::GetDefaultFactory()), - close_server_socket_on_next_send_(false) { - } - - virtual ~TransportClientSocketTest() { - } + close_server_socket_on_next_send_(false) {} - // Implement StreamListenSocket::Delegate methods - void DidAccept(StreamListenSocket* server, - scoped_ptr<StreamListenSocket> connection) override { - connected_sock_.reset( - static_cast<TCPListenSocket*>(connection.release())); - } - void DidRead(StreamListenSocket*, const char* str, int len) override { - // TODO(dkegel): this might not be long enough to tickle some bugs. - connected_sock_->Send(kServerReply, arraysize(kServerReply) - 1, - false /* Don't append line feed */); - if (close_server_socket_on_next_send_) - CloseServerSocket(); - } - void DidClose(StreamListenSocket* sock) override {} + virtual ~TransportClientSocketTest() {} // Testcase hooks void SetUp() override; @@ -69,12 +52,9 @@ class TransportClientSocketTest connected_sock_.reset(); } - void PauseServerReads() { - connected_sock_->PauseReads(); - } - - void ResumeServerReads() { - connected_sock_->ResumeReads(); + void AcceptCallback(int res) { + ASSERT_EQ(OK, res); + connect_loop_.Quit(); } int DrainClientSocket(IOBuffer* buf, @@ -82,96 +62,164 @@ class TransportClientSocketTest uint32 bytes_to_read, TestCompletionCallback* callback); - void SendClientRequest(); + // Establishes a connection to the server. + void EstablishConnection(TestCompletionCallback* callback); + + // Sends a request from the client to the server socket. Makes the server read + // the request and send a response. + void SendRequestAndResponse(); + + // Makes |connected_sock_| to read |expected_bytes_read| bytes. Returns the + // the data read as a string. + std::string ReadServerData(int expected_bytes_read); + + // Sends server response. + void SendServerResponse(); void set_close_server_socket_on_next_send(bool close) { close_server_socket_on_next_send_ = close; } protected: - int listen_port_; - CapturingNetLog net_log_; + base::RunLoop connect_loop_; + uint16 listen_port_; + TestNetLog net_log_; ClientSocketFactory* const socket_factory_; scoped_ptr<StreamSocket> sock_; + scoped_ptr<StreamSocket> connected_sock_; private: - scoped_ptr<TCPListenSocket> listen_sock_; - scoped_ptr<TCPListenSocket> connected_sock_; + scoped_ptr<TCPServerSocket> listen_sock_; bool close_server_socket_on_next_send_; }; void TransportClientSocketTest::SetUp() { ::testing::TestWithParam<ClientSocketTestTypes>::SetUp(); - // Find a free port to listen on - scoped_ptr<TCPListenSocket> sock; - int port; - // Range of ports to listen on. Shouldn't need to try many. - const int kMinPort = 10100; - const int kMaxPort = 10200; -#if defined(OS_WIN) - EnsureWinsockInit(); -#endif - for (port = kMinPort; port < kMaxPort; port++) { - sock = TCPListenSocket::CreateAndListen("127.0.0.1", port, this); - if (sock.get()) - break; - } - ASSERT_TRUE(sock.get() != NULL); - listen_sock_ = sock.Pass(); - listen_port_ = port; + // Open a server socket on an ephemeral port. + listen_sock_.reset(new TCPServerSocket(NULL, NetLog::Source())); + IPAddressNumber address; + ParseIPLiteralToNumber("127.0.0.1", &address); + IPEndPoint local_address(address, 0); + ASSERT_EQ(OK, listen_sock_->Listen(local_address, 1)); + // Get the server's address (including the actual port number). + ASSERT_EQ(OK, listen_sock_->GetLocalAddress(&local_address)); + listen_port_ = local_address.port(); + listen_sock_->Accept(&connected_sock_, + base::Bind(&TransportClientSocketTest::AcceptCallback, + base::Unretained(this))); AddressList addr; // MockHostResolver resolves everything to 127.0.0.1. scoped_ptr<HostResolver> resolver(new MockHostResolver()); HostResolver::RequestInfo info(HostPortPair("localhost", listen_port_)); TestCompletionCallback callback; - int rv = resolver->Resolve( - info, DEFAULT_PRIORITY, &addr, callback.callback(), NULL, BoundNetLog()); + int rv = resolver->Resolve(info, DEFAULT_PRIORITY, &addr, callback.callback(), + NULL, BoundNetLog()); CHECK_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); CHECK_EQ(rv, OK); - sock_ = - socket_factory_->CreateTransportClientSocket(addr, - &net_log_, - NetLog::Source()); + sock_ = socket_factory_->CreateTransportClientSocket(addr, &net_log_, + NetLog::Source()); } int TransportClientSocketTest::DrainClientSocket( - IOBuffer* buf, uint32 buf_len, - uint32 bytes_to_read, TestCompletionCallback* callback) { + IOBuffer* buf, + uint32 buf_len, + uint32 bytes_to_read, + TestCompletionCallback* callback) { int rv = OK; uint32 bytes_read = 0; while (bytes_read < bytes_to_read) { rv = sock_->Read(buf, buf_len, callback->callback()); EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - - if (rv == ERR_IO_PENDING) - rv = callback->WaitForResult(); - - EXPECT_GE(rv, 0); + rv = callback->GetResult(rv); + EXPECT_GT(rv, 0); bytes_read += rv; } return static_cast<int>(bytes_read); } -void TransportClientSocketTest::SendClientRequest() { +void TransportClientSocketTest::EstablishConnection( + TestCompletionCallback* callback) { + int rv = sock_->Connect(callback->callback()); + // Wait for |listen_sock_| to accept a connection. + connect_loop_.Run(); + // Now wait for the client socket to accept the connection. + EXPECT_EQ(OK, callback->GetResult(rv)); +} + +void TransportClientSocketTest::SendRequestAndResponse() { + // Send client request. const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - scoped_refptr<IOBuffer> request_buffer( - new IOBuffer(arraysize(request_text) - 1)); - TestCompletionCallback callback; - int rv; + int request_len = strlen(request_text); + scoped_refptr<DrainableIOBuffer> request_buffer( + new DrainableIOBuffer(new IOBuffer(request_len), request_len)); + memcpy(request_buffer->data(), request_text, request_len); + + int bytes_written = 0; + while (request_buffer->BytesRemaining() > 0) { + TestCompletionCallback write_callback; + int write_result = + sock_->Write(request_buffer.get(), request_buffer->BytesRemaining(), + write_callback.callback()); + write_result = write_callback.GetResult(write_result); + ASSERT_GT(write_result, 0); + ASSERT_LE(bytes_written + write_result, request_len); + request_buffer->DidConsume(write_result); + bytes_written += write_result; + } + ASSERT_EQ(request_len, bytes_written); - memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); - rv = sock_->Write( - request_buffer.get(), arraysize(request_text) - 1, callback.callback()); - EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + // Confirm that the server receives what client sent. + std::string data_received = ReadServerData(bytes_written); + ASSERT_TRUE(connected_sock_->IsConnectedAndIdle()); + ASSERT_EQ(request_text, data_received); - if (rv == ERR_IO_PENDING) - rv = callback.WaitForResult(); - EXPECT_EQ(rv, static_cast<int>(arraysize(request_text) - 1)); + // Write server response. + SendServerResponse(); +} + +void TransportClientSocketTest::SendServerResponse() { + // TODO(dkegel): this might not be long enough to tickle some bugs. + int reply_len = strlen(kServerReply); + scoped_refptr<DrainableIOBuffer> write_buffer( + new DrainableIOBuffer(new IOBuffer(reply_len), reply_len)); + memcpy(write_buffer->data(), kServerReply, reply_len); + int bytes_written = 0; + while (write_buffer->BytesRemaining() > 0) { + TestCompletionCallback write_callback; + int write_result = connected_sock_->Write(write_buffer.get(), + write_buffer->BytesRemaining(), + write_callback.callback()); + write_result = write_callback.GetResult(write_result); + ASSERT_GE(write_result, 0); + ASSERT_LE(bytes_written + write_result, reply_len); + write_buffer->DidConsume(write_result); + bytes_written += write_result; + } + if (close_server_socket_on_next_send_) + CloseServerSocket(); +} + +std::string TransportClientSocketTest::ReadServerData(int expected_bytes_read) { + int bytes_read = 0; + scoped_refptr<IOBufferWithSize> read_buffer( + new IOBufferWithSize(expected_bytes_read)); + while (bytes_read < expected_bytes_read) { + TestCompletionCallback read_callback; + int rv = connected_sock_->Read(read_buffer.get(), + expected_bytes_read - bytes_read, + read_callback.callback()); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); + rv = read_callback.GetResult(rv); + EXPECT_GE(rv, 0); + bytes_read += rv; + } + EXPECT_EQ(expected_bytes_read, bytes_read); + return std::string(read_buffer->data(), bytes_read); } // TODO(leighton): Add SCTP to this list when it is ready. @@ -184,13 +232,16 @@ TEST_P(TransportClientSocketTest, Connect) { EXPECT_FALSE(sock_->IsConnected()); int rv = sock_->Connect(callback.callback()); + // Wait for |listen_sock_| to accept a connection. + connect_loop_.Run(); - net::CapturingNetLog::CapturedEntryList net_log_entries; + TestNetLogEntry::List net_log_entries; net_log_.GetEntries(&net_log_entries); - EXPECT_TRUE(net::LogContainsBeginEvent( - net_log_entries, 0, net::NetLog::TYPE_SOCKET_ALIVE)); - EXPECT_TRUE(net::LogContainsBeginEvent( - net_log_entries, 1, net::NetLog::TYPE_TCP_CONNECT)); + EXPECT_TRUE( + LogContainsBeginEvent(net_log_entries, 0, NetLog::TYPE_SOCKET_ALIVE)); + EXPECT_TRUE( + LogContainsBeginEvent(net_log_entries, 1, NetLog::TYPE_TCP_CONNECT)); + // Now wait for the client socket to accept the connection. if (rv != OK) { ASSERT_EQ(rv, ERR_IO_PENDING); rv = callback.WaitForResult(); @@ -199,8 +250,8 @@ TEST_P(TransportClientSocketTest, Connect) { EXPECT_TRUE(sock_->IsConnected()); net_log_.GetEntries(&net_log_entries); - EXPECT_TRUE(net::LogContainsEndEvent( - net_log_entries, -1, net::NetLog::TYPE_TCP_CONNECT)); + EXPECT_TRUE( + LogContainsEndEvent(net_log_entries, -1, NetLog::TYPE_TCP_CONNECT)); sock_->Disconnect(); EXPECT_FALSE(sock_->IsConnected()); @@ -213,17 +264,14 @@ TEST_P(TransportClientSocketTest, IsConnected) { EXPECT_FALSE(sock_->IsConnected()); EXPECT_FALSE(sock_->IsConnectedAndIdle()); - int rv = sock_->Connect(callback.callback()); - if (rv != OK) { - ASSERT_EQ(rv, ERR_IO_PENDING); - rv = callback.WaitForResult(); - EXPECT_EQ(rv, OK); - } + + EstablishConnection(&callback); + EXPECT_TRUE(sock_->IsConnected()); EXPECT_TRUE(sock_->IsConnectedAndIdle()); // Send the request and wait for the server to respond. - SendClientRequest(); + SendRequestAndResponse(); // Drain a single byte so we know we've received some data. bytes_read = DrainClientSocket(buf.get(), 1, 1, &callback); @@ -234,9 +282,9 @@ TEST_P(TransportClientSocketTest, IsConnected) { EXPECT_TRUE(sock_->IsConnected()); EXPECT_FALSE(sock_->IsConnectedAndIdle()); - bytes_read = DrainClientSocket( - buf.get(), 4096, arraysize(kServerReply) - 2, &callback); - ASSERT_EQ(bytes_read, arraysize(kServerReply) - 2); + bytes_read = + DrainClientSocket(buf.get(), 4096, strlen(kServerReply) - 1, &callback); + ASSERT_EQ(bytes_read, strlen(kServerReply) - 1); // After draining the data, the socket should be back to connected // and idle. @@ -245,7 +293,7 @@ TEST_P(TransportClientSocketTest, IsConnected) { // This time close the server socket immediately after the server response. set_close_server_socket_on_next_send(true); - SendClientRequest(); + SendRequestAndResponse(); bytes_read = DrainClientSocket(buf.get(), 1, 1, &callback); ASSERT_EQ(bytes_read, 1u); @@ -254,16 +302,16 @@ TEST_P(TransportClientSocketTest, IsConnected) { EXPECT_TRUE(sock_->IsConnected()); EXPECT_FALSE(sock_->IsConnectedAndIdle()); - bytes_read = DrainClientSocket( - buf.get(), 4096, arraysize(kServerReply) - 2, &callback); - ASSERT_EQ(bytes_read, arraysize(kServerReply) - 2); + bytes_read = + DrainClientSocket(buf.get(), 4096, strlen(kServerReply) - 1, &callback); + ASSERT_EQ(bytes_read, strlen(kServerReply) - 1); // Once the data is drained, the socket should now be seen as not // connected. if (sock_->IsConnected()) { // In the unlikely event that the server's connection closure is not // processed in time, wait for the connection to be closed. - rv = sock_->Read(buf.get(), 4096, callback.callback()); + int rv = sock_->Read(buf.get(), 4096, callback.callback()); EXPECT_EQ(0, callback.GetResult(rv)); EXPECT_FALSE(sock_->IsConnected()); } @@ -272,24 +320,20 @@ TEST_P(TransportClientSocketTest, IsConnected) { TEST_P(TransportClientSocketTest, Read) { TestCompletionCallback callback; - int rv = sock_->Connect(callback.callback()); - if (rv != OK) { - ASSERT_EQ(rv, ERR_IO_PENDING); + EstablishConnection(&callback); - rv = callback.WaitForResult(); - EXPECT_EQ(rv, OK); - } - SendClientRequest(); + SendRequestAndResponse(); scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); - uint32 bytes_read = DrainClientSocket( - buf.get(), 4096, arraysize(kServerReply) - 1, &callback); - ASSERT_EQ(bytes_read, arraysize(kServerReply) - 1); + uint32 bytes_read = + DrainClientSocket(buf.get(), 4096, strlen(kServerReply), &callback); + ASSERT_EQ(bytes_read, strlen(kServerReply)); + ASSERT_EQ(std::string(kServerReply), std::string(buf->data(), bytes_read)); // All data has been read now. Read once more to force an ERR_IO_PENDING, and // then close the server socket, and note the close. - rv = sock_->Read(buf.get(), 4096, callback.callback()); + int rv = sock_->Read(buf.get(), 4096, callback.callback()); ASSERT_EQ(ERR_IO_PENDING, rv); CloseServerSocket(); EXPECT_EQ(0, callback.WaitForResult()); @@ -297,23 +341,17 @@ TEST_P(TransportClientSocketTest, Read) { TEST_P(TransportClientSocketTest, Read_SmallChunks) { TestCompletionCallback callback; - int rv = sock_->Connect(callback.callback()); - if (rv != OK) { - ASSERT_EQ(rv, ERR_IO_PENDING); + EstablishConnection(&callback); - rv = callback.WaitForResult(); - EXPECT_EQ(rv, OK); - } - SendClientRequest(); + SendRequestAndResponse(); scoped_refptr<IOBuffer> buf(new IOBuffer(1)); uint32 bytes_read = 0; - while (bytes_read < arraysize(kServerReply) - 1) { - rv = sock_->Read(buf.get(), 1, callback.callback()); + while (bytes_read < strlen(kServerReply)) { + int rv = sock_->Read(buf.get(), 1, callback.callback()); EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == ERR_IO_PENDING) - rv = callback.WaitForResult(); + rv = callback.GetResult(rv); ASSERT_EQ(1, rv); bytes_read += rv; @@ -322,7 +360,7 @@ TEST_P(TransportClientSocketTest, Read_SmallChunks) { // All data has been read now. Read once more to force an ERR_IO_PENDING, and // then close the server socket, and note the close. - rv = sock_->Read(buf.get(), 1, callback.callback()); + int rv = sock_->Read(buf.get(), 1, callback.callback()); ASSERT_EQ(ERR_IO_PENDING, rv); CloseServerSocket(); EXPECT_EQ(0, callback.WaitForResult()); @@ -330,59 +368,48 @@ TEST_P(TransportClientSocketTest, Read_SmallChunks) { TEST_P(TransportClientSocketTest, Read_Interrupted) { TestCompletionCallback callback; - int rv = sock_->Connect(callback.callback()); - if (rv != OK) { - ASSERT_EQ(ERR_IO_PENDING, rv); + EstablishConnection(&callback); - rv = callback.WaitForResult(); - EXPECT_EQ(rv, OK); - } - SendClientRequest(); + SendRequestAndResponse(); // Do a partial read and then exit. This test should not crash! scoped_refptr<IOBuffer> buf(new IOBuffer(16)); - rv = sock_->Read(buf.get(), 16, callback.callback()); + int rv = sock_->Read(buf.get(), 16, callback.callback()); EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == ERR_IO_PENDING) - rv = callback.WaitForResult(); + rv = callback.GetResult(rv); EXPECT_NE(0, rv); } -TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_ReadFirst) { +TEST_P(TransportClientSocketTest, FullDuplex_ReadFirst) { TestCompletionCallback callback; - int rv = sock_->Connect(callback.callback()); - if (rv != OK) { - ASSERT_EQ(rv, ERR_IO_PENDING); - - rv = callback.WaitForResult(); - EXPECT_EQ(rv, OK); - } + EstablishConnection(&callback); // Read first. There's no data, so it should return ERR_IO_PENDING. const int kBufLen = 4096; scoped_refptr<IOBuffer> buf(new IOBuffer(kBufLen)); - rv = sock_->Read(buf.get(), kBufLen, callback.callback()); + int rv = sock_->Read(buf.get(), kBufLen, callback.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); - PauseServerReads(); const int kWriteBufLen = 64 * 1024; scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kWriteBufLen)); char* request_data = request_buffer->data(); memset(request_data, 'A', kWriteBufLen); TestCompletionCallback write_callback; + int bytes_written = 0; while (true) { - rv = sock_->Write( - request_buffer.get(), kWriteBufLen, write_callback.callback()); + rv = sock_->Write(request_buffer.get(), kWriteBufLen, + write_callback.callback()); ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == ERR_IO_PENDING) { - ResumeServerReads(); + ReadServerData(bytes_written); + SendServerResponse(); rv = write_callback.WaitForResult(); break; } + bytes_written += rv; } // At this point, both read and write have returned ERR_IO_PENDING, and the @@ -393,30 +420,25 @@ TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_ReadFirst) { EXPECT_GE(rv, 0); } -TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_WriteFirst) { +TEST_P(TransportClientSocketTest, FullDuplex_WriteFirst) { TestCompletionCallback callback; - int rv = sock_->Connect(callback.callback()); - if (rv != OK) { - ASSERT_EQ(ERR_IO_PENDING, rv); - - rv = callback.WaitForResult(); - EXPECT_EQ(OK, rv); - } + EstablishConnection(&callback); - PauseServerReads(); const int kWriteBufLen = 64 * 1024; scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kWriteBufLen)); char* request_data = request_buffer->data(); memset(request_data, 'A', kWriteBufLen); TestCompletionCallback write_callback; + int bytes_written = 0; while (true) { - rv = sock_->Write( - request_buffer.get(), kWriteBufLen, write_callback.callback()); + int rv = sock_->Write(request_buffer.get(), kWriteBufLen, + write_callback.callback()); ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); if (rv == ERR_IO_PENDING) break; + bytes_written += rv; } // Now we have the Write() blocked on ERR_IO_PENDING. It's time to force the @@ -425,7 +447,7 @@ TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_WriteFirst) { const int kBufLen = 4096; scoped_refptr<IOBuffer> buf(new IOBuffer(kBufLen)); while (true) { - rv = sock_->Read(buf.get(), kBufLen, callback.callback()); + int rv = sock_->Read(buf.get(), kBufLen, callback.callback()); ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); if (rv == ERR_IO_PENDING) break; @@ -435,8 +457,9 @@ TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_WriteFirst) { // run the write and read callbacks to make sure they can handle full duplex // communications. - ResumeServerReads(); - rv = write_callback.WaitForResult(); + ReadServerData(bytes_written); + SendServerResponse(); + int rv = write_callback.WaitForResult(); EXPECT_GE(rv, 0); // It's possible the read is blocked because it's already read all the data. diff --git a/chromium/net/socket/unix_domain_client_socket_posix.cc b/chromium/net/socket/unix_domain_client_socket_posix.cc index 5adbca9979e..79aa275bf4b 100644 --- a/chromium/net/socket/unix_domain_client_socket_posix.cc +++ b/chromium/net/socket/unix_domain_client_socket_posix.cc @@ -9,6 +9,7 @@ #include "base/logging.h" #include "base/posix/eintr_wrapper.h" +#include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/net_util.h" #include "net/socket/socket_libevent.h" @@ -98,13 +99,25 @@ bool UnixDomainClientSocket::IsConnectedAndIdle() const { } int UnixDomainClientSocket::GetPeerAddress(IPEndPoint* address) const { - NOTIMPLEMENTED(); - return ERR_NOT_IMPLEMENTED; + // Unix domain sockets have no valid associated addr/port; + // return either not connected or address invalid. + DCHECK(address); + + if (!IsConnected()) + return ERR_SOCKET_NOT_CONNECTED; + + return ERR_ADDRESS_INVALID; } int UnixDomainClientSocket::GetLocalAddress(IPEndPoint* address) const { - NOTIMPLEMENTED(); - return ERR_NOT_IMPLEMENTED; + // Unix domain sockets have no valid associated addr/port; + // return either not connected or address invalid. + DCHECK(address); + + if (!socket_) + return ERR_SOCKET_NOT_CONNECTED; + + return ERR_ADDRESS_INVALID; } const BoundNetLog& UnixDomainClientSocket::NetLog() const { @@ -137,6 +150,11 @@ bool UnixDomainClientSocket::GetSSLInfo(SSLInfo* ssl_info) { return false; } +void UnixDomainClientSocket::GetConnectionAttempts( + ConnectionAttempts* out) const { + out->clear(); +} + int UnixDomainClientSocket::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { DCHECK(socket_); diff --git a/chromium/net/socket/unix_domain_client_socket_posix.h b/chromium/net/socket/unix_domain_client_socket_posix.h index 2a8bdb625c9..77ef9d57944 100644 --- a/chromium/net/socket/unix_domain_client_socket_posix.h +++ b/chromium/net/socket/unix_domain_client_socket_posix.h @@ -12,7 +12,7 @@ #include "base/memory/scoped_ptr.h" #include "net/base/completion_callback.h" #include "net/base/net_export.h" -#include "net/base/net_log.h" +#include "net/log/net_log.h" #include "net/socket/socket_descriptor.h" #include "net/socket/stream_socket.h" @@ -55,6 +55,9 @@ class NET_EXPORT UnixDomainClientSocket : public StreamSocket { bool WasNpnNegotiated() const override; NextProto GetNegotiatedProtocol() const override; bool GetSSLInfo(SSLInfo* ssl_info) override; + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} // Socket implementation. int Read(IOBuffer* buf, diff --git a/chromium/net/socket/unix_domain_listen_socket_posix.cc b/chromium/net/socket/unix_domain_listen_socket_posix.cc index 3e46439c8b5..333ef0b4d0f 100644 --- a/chromium/net/socket/unix_domain_listen_socket_posix.cc +++ b/chromium/net/socket/unix_domain_listen_socket_posix.cc @@ -127,41 +127,5 @@ void UnixDomainListenSocket::Accept() { socket_delegate_->DidAccept(this, sock.Pass()); } -UnixDomainListenSocketFactory::UnixDomainListenSocketFactory( - const std::string& path, - const UnixDomainListenSocket::AuthCallback& auth_callback) - : path_(path), - auth_callback_(auth_callback) {} - -UnixDomainListenSocketFactory::~UnixDomainListenSocketFactory() {} - -scoped_ptr<StreamListenSocket> UnixDomainListenSocketFactory::CreateAndListen( - StreamListenSocket::Delegate* delegate) const { - return UnixDomainListenSocket::CreateAndListen( - path_, delegate, auth_callback_).Pass(); -} - -#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) - -UnixDomainListenSocketWithAbstractNamespaceFactory:: -UnixDomainListenSocketWithAbstractNamespaceFactory( - const std::string& path, - const std::string& fallback_path, - const UnixDomainListenSocket::AuthCallback& auth_callback) - : UnixDomainListenSocketFactory(path, auth_callback), - fallback_path_(fallback_path) {} - -UnixDomainListenSocketWithAbstractNamespaceFactory:: -~UnixDomainListenSocketWithAbstractNamespaceFactory() {} - -scoped_ptr<StreamListenSocket> -UnixDomainListenSocketWithAbstractNamespaceFactory::CreateAndListen( - StreamListenSocket::Delegate* delegate) const { - return UnixDomainListenSocket::CreateAndListenWithAbstractNamespace( - path_, fallback_path_, delegate, auth_callback_); -} - -#endif - } // namespace deprecated } // namespace net diff --git a/chromium/net/socket/unix_domain_listen_socket_posix.h b/chromium/net/socket/unix_domain_listen_socket_posix.h index 82ec342edaa..da578ce5c11 100644 --- a/chromium/net/socket/unix_domain_listen_socket_posix.h +++ b/chromium/net/socket/unix_domain_listen_socket_posix.h @@ -22,36 +22,27 @@ #define SOCKET_ABSTRACT_NAMESPACE_SUPPORTED #endif +namespace remoting { +class GnubbyAuthHandlerPosix; +} + namespace net { namespace deprecated { // Unix Domain Socket Implementation. Supports abstract namespaces on Linux. +// This class is deprecated and will be removed once crbug.com/472766 is fixed. +// There should not be any new consumer of this class. class NET_EXPORT UnixDomainListenSocket : public StreamListenSocket { public: typedef UnixDomainServerSocket::AuthCallback AuthCallback; ~UnixDomainListenSocket() override; - // Note that the returned UnixDomainListenSocket instance does not take - // ownership of |del|. - static scoped_ptr<UnixDomainListenSocket> CreateAndListen( - const std::string& path, - StreamListenSocket::Delegate* del, - const AuthCallback& auth_callback); - -#if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) - // Same as above except that the created socket uses the abstract namespace - // which is a Linux-only feature. If |fallback_path| is not empty, - // make the second attempt with the provided fallback name. - static scoped_ptr<UnixDomainListenSocket> - CreateAndListenWithAbstractNamespace( - const std::string& path, - const std::string& fallback_path, - StreamListenSocket::Delegate* del, - const AuthCallback& auth_callback); -#endif - private: + // Note that friend classes are temporary until crbug.com/472766 is fixed. + friend class UnixDomainListenSocketTestHelper; + friend class remoting::GnubbyAuthHandlerPosix; + UnixDomainListenSocket(SocketDescriptor s, StreamListenSocket::Delegate* del, const AuthCallback& auth_callback); @@ -66,55 +57,28 @@ class NET_EXPORT UnixDomainListenSocket : public StreamListenSocket { // StreamListenSocket: void Accept() override; - AuthCallback auth_callback_; - - DISALLOW_COPY_AND_ASSIGN(UnixDomainListenSocket); -}; - -// Factory that can be used to instantiate UnixDomainListenSocket. -class NET_EXPORT UnixDomainListenSocketFactory - : public StreamListenSocketFactory { - public: - // Note that this class does not take ownership of the provided delegate. - UnixDomainListenSocketFactory( + // Note that the returned UnixDomainListenSocket instance does not take + // ownership of |del|. + static scoped_ptr<UnixDomainListenSocket> CreateAndListen( const std::string& path, - const UnixDomainListenSocket::AuthCallback& auth_callback); - ~UnixDomainListenSocketFactory() override; - - // StreamListenSocketFactory: - scoped_ptr<StreamListenSocket> CreateAndListen( - StreamListenSocket::Delegate* delegate) const override; - - protected: - const std::string path_; - const UnixDomainListenSocket::AuthCallback auth_callback_; - - private: - DISALLOW_COPY_AND_ASSIGN(UnixDomainListenSocketFactory); -}; + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback); #if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) -// Use this factory to instantiate UnixDomainListenSocket using the abstract -// namespace feature (only supported on Linux). -class NET_EXPORT UnixDomainListenSocketWithAbstractNamespaceFactory - : public UnixDomainListenSocketFactory { - public: - UnixDomainListenSocketWithAbstractNamespaceFactory( - const std::string& path, - const std::string& fallback_path, - const UnixDomainListenSocket::AuthCallback& auth_callback); - ~UnixDomainListenSocketWithAbstractNamespaceFactory() override; - - // UnixDomainListenSocketFactory: - scoped_ptr<StreamListenSocket> CreateAndListen( - StreamListenSocket::Delegate* delegate) const override; + // Same as above except that the created socket uses the abstract namespace + // which is a Linux-only feature. If |fallback_path| is not empty, + // make the second attempt with the provided fallback name. + static scoped_ptr<UnixDomainListenSocket> + CreateAndListenWithAbstractNamespace(const std::string& path, + const std::string& fallback_path, + StreamListenSocket::Delegate* del, + const AuthCallback& auth_callback); +#endif - private: - std::string fallback_path_; + AuthCallback auth_callback_; - DISALLOW_COPY_AND_ASSIGN(UnixDomainListenSocketWithAbstractNamespaceFactory); + DISALLOW_COPY_AND_ASSIGN(UnixDomainListenSocket); }; -#endif } // namespace deprecated } // namespace net diff --git a/chromium/net/socket/unix_domain_listen_socket_posix_unittest.cc b/chromium/net/socket/unix_domain_listen_socket_posix_unittest.cc index aaf362310c3..1513ca940e9 100644 --- a/chromium/net/socket/unix_domain_listen_socket_posix_unittest.cc +++ b/chromium/net/socket/unix_domain_listen_socket_posix_unittest.cc @@ -142,6 +142,8 @@ bool UserCanConnectCallback( return allow_user; } +} // namespace + class UnixDomainListenSocketTestHelper : public testing::Test { public: void CreateAndListen() { @@ -150,6 +152,15 @@ class UnixDomainListenSocketTestHelper : public testing::Test { socket_delegate_->OnListenCompleted(); } + scoped_ptr<UnixDomainListenSocket> CreateAndListenWithAbstractNamespace( + const std::string& path, + const std::string& fallback_path, + StreamListenSocket::Delegate* del, + const UnixDomainListenSocket::AuthCallback& auth_callback) { + return UnixDomainListenSocket::CreateAndListenInternal( + path, fallback_path, del, auth_callback, true); + } + protected: UnixDomainListenSocketTestHelper(const string& path_str, bool allow_user) : allow_user_(allow_user) { @@ -222,6 +233,8 @@ class UnixDomainListenSocketTestHelper : public testing::Test { scoped_ptr<UnixDomainListenSocket> socket_; }; +namespace { + class UnixDomainListenSocketTest : public UnixDomainListenSocketTestHelper { protected: UnixDomainListenSocketTest() @@ -260,28 +273,25 @@ TEST_F(UnixDomainListenSocketTestWithInvalidPath, // file. TEST_F(UnixDomainListenSocketTestWithInvalidPath, CreateAndListenWithAbstractNamespace) { - socket_ = UnixDomainListenSocket::CreateAndListenWithAbstractNamespace( + socket_ = CreateAndListenWithAbstractNamespace( file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); EXPECT_FALSE(socket_.get() == NULL); } TEST_F(UnixDomainListenSocketTest, TestFallbackName) { scoped_ptr<UnixDomainListenSocket> existing_socket = - UnixDomainListenSocket::CreateAndListenWithAbstractNamespace( + CreateAndListenWithAbstractNamespace( file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); EXPECT_FALSE(existing_socket.get() == NULL); // First, try to bind socket with the same name with no fallback name. - socket_ = - UnixDomainListenSocket::CreateAndListenWithAbstractNamespace( - file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); + socket_ = CreateAndListenWithAbstractNamespace( + file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); EXPECT_TRUE(socket_.get() == NULL); // Now with a fallback name. const char kFallbackSocketName[] = "socket_for_testing_2"; - socket_ = UnixDomainListenSocket::CreateAndListenWithAbstractNamespace( - file_path_.value(), - GetTempSocketPath(kFallbackSocketName).value(), - socket_delegate_.get(), - MakeAuthCallback()); + socket_ = CreateAndListenWithAbstractNamespace( + file_path_.value(), GetTempSocketPath(kFallbackSocketName).value(), + socket_delegate_.get(), MakeAuthCallback()); EXPECT_FALSE(socket_.get() == NULL); } #endif diff --git a/chromium/net/socket/unix_domain_server_socket_posix.cc b/chromium/net/socket/unix_domain_server_socket_posix.cc index 4d6328310ff..6866d3632c7 100644 --- a/chromium/net/socket/unix_domain_server_socket_posix.cc +++ b/chromium/net/socket/unix_domain_server_socket_posix.cc @@ -67,7 +67,7 @@ int UnixDomainServerSocket::Listen(const IPEndPoint& address, int backlog) { int UnixDomainServerSocket::ListenWithAddressAndPort( const std::string& unix_domain_path, - int port_unused, + uint16 port_unused, int backlog) { DCHECK(!listen_socket_); diff --git a/chromium/net/socket/unix_domain_server_socket_posix.h b/chromium/net/socket/unix_domain_server_socket_posix.h index 0a26eb3d375..1097548c513 100644 --- a/chromium/net/socket/unix_domain_server_socket_posix.h +++ b/chromium/net/socket/unix_domain_server_socket_posix.h @@ -53,7 +53,7 @@ class NET_EXPORT UnixDomainServerSocket : public ServerSocket { // ServerSocket implementation. int Listen(const IPEndPoint& address, int backlog) override; int ListenWithAddressAndPort(const std::string& unix_domain_path, - int port_unused, + uint16 port_unused, int backlog) override; int GetLocalAddress(IPEndPoint* address) const override; int Accept(scoped_ptr<StreamSocket>* socket, diff --git a/chromium/net/socket/websocket_endpoint_lock_manager.cc b/chromium/net/socket/websocket_endpoint_lock_manager.cc index e578bb2435b..caddd8d8786 100644 --- a/chromium/net/socket/websocket_endpoint_lock_manager.cc +++ b/chromium/net/socket/websocket_endpoint_lock_manager.cc @@ -6,12 +6,23 @@ #include <utility> +#include "base/bind.h" #include "base/logging.h" +#include "base/message_loop/message_loop.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" +#include "net/log/net_log.h" namespace net { +namespace { + +// This delay prevents DoS attacks. +// TODO(ricea): Replace this with randomised truncated exponential backoff. +// See crbug.com/377613. +const int kUnlockDelayInMs = 10; + +} // namespace + WebSocketEndpointLockManager::Waiter::~Waiter() { if (next()) { DCHECK(previous()); @@ -65,23 +76,31 @@ void WebSocketEndpointLockManager::UnlockSocket(StreamSocket* socket) { << lock_info_it->first.ToString() << " (" << socket_lock_info_map_.size() << " socket(s) left)"; socket_lock_info_map_.erase(socket_it); - DCHECK(socket == lock_info_it->second.socket); + DCHECK_EQ(socket, lock_info_it->second.socket); lock_info_it->second.socket = NULL; - UnlockEndpointByIterator(lock_info_it); + UnlockEndpointAfterDelay(lock_info_it->first); } void WebSocketEndpointLockManager::UnlockEndpoint(const IPEndPoint& endpoint) { LockInfoMap::iterator lock_info_it = lock_info_map_.find(endpoint); if (lock_info_it == lock_info_map_.end()) return; - - UnlockEndpointByIterator(lock_info_it); + if (lock_info_it->second.socket) + EraseSocket(lock_info_it); + UnlockEndpointAfterDelay(endpoint); } bool WebSocketEndpointLockManager::IsEmpty() const { return lock_info_map_.empty() && socket_lock_info_map_.empty(); } +base::TimeDelta WebSocketEndpointLockManager::SetUnlockDelayForTesting( + base::TimeDelta new_delay) { + base::TimeDelta old_delay = unlock_delay_; + unlock_delay_ = new_delay; + return old_delay; +} + WebSocketEndpointLockManager::LockInfo::LockInfo() : socket(NULL) {} WebSocketEndpointLockManager::LockInfo::~LockInfo() { DCHECK(!socket); @@ -92,17 +111,37 @@ WebSocketEndpointLockManager::LockInfo::LockInfo(const LockInfo& rhs) DCHECK(!rhs.queue); } -WebSocketEndpointLockManager::WebSocketEndpointLockManager() {} +WebSocketEndpointLockManager::WebSocketEndpointLockManager() + : unlock_delay_(base::TimeDelta::FromMilliseconds(kUnlockDelayInMs)), + pending_unlock_count_(0), + weak_factory_(this) { +} WebSocketEndpointLockManager::~WebSocketEndpointLockManager() { - DCHECK(lock_info_map_.empty()); + DCHECK_EQ(lock_info_map_.size(), pending_unlock_count_); DCHECK(socket_lock_info_map_.empty()); } -void WebSocketEndpointLockManager::UnlockEndpointByIterator( - LockInfoMap::iterator lock_info_it) { - if (lock_info_it->second.socket) - EraseSocket(lock_info_it); +void WebSocketEndpointLockManager::UnlockEndpointAfterDelay( + const IPEndPoint& endpoint) { + DVLOG(3) << "Delaying " << unlock_delay_.InMilliseconds() + << "ms before unlocking endpoint " << endpoint.ToString(); + ++pending_unlock_count_; + base::MessageLoop::current()->PostDelayedTask( + FROM_HERE, + base::Bind(&WebSocketEndpointLockManager::DelayedUnlockEndpoint, + weak_factory_.GetWeakPtr(), endpoint), + unlock_delay_); +} + +void WebSocketEndpointLockManager::DelayedUnlockEndpoint( + const IPEndPoint& endpoint) { + LockInfoMap::iterator lock_info_it = lock_info_map_.find(endpoint); + DCHECK_GT(pending_unlock_count_, 0U); + --pending_unlock_count_; + if (lock_info_it == lock_info_map_.end()) + return; + DCHECK(!lock_info_it->second.socket); LockInfo::WaiterQueue* queue = lock_info_it->second.queue.get(); DCHECK(queue); if (queue->empty()) { @@ -115,7 +154,6 @@ void WebSocketEndpointLockManager::UnlockEndpointByIterator( << " and activating next waiter"; Waiter* next_job = queue->head()->value(); next_job->RemoveFromList(); - // This must be last to minimise the excitement caused by re-entrancy. next_job->GotEndpointLock(); } diff --git a/chromium/net/socket/websocket_endpoint_lock_manager.h b/chromium/net/socket/websocket_endpoint_lock_manager.h index d5cad508d6a..bddd5455ccf 100644 --- a/chromium/net/socket/websocket_endpoint_lock_manager.h +++ b/chromium/net/socket/websocket_endpoint_lock_manager.h @@ -11,6 +11,7 @@ #include "base/logging.h" #include "base/macros.h" #include "base/memory/singleton.h" +#include "base/time/time.h" #include "net/base/ip_endpoint.h" #include "net/base/net_export.h" #include "net/socket/websocket_transport_client_socket_pool.h" @@ -19,8 +20,25 @@ namespace net { class StreamSocket; +// Keep track of ongoing WebSocket connections in order to satisfy the WebSocket +// connection throttling requirements described in RFC6455 4.1.2: +// +// 2. If the client already has a WebSocket connection to the remote +// host (IP address) identified by /host/ and port /port/ pair, even +// if the remote host is known by another name, the client MUST wait +// until that connection has been established or for that connection +// to have failed. There MUST be no more than one connection in a +// CONNECTING state. If multiple connections to the same IP address +// are attempted simultaneously, the client MUST serialize them so +// that there is no more than one connection at a time running +// through the following steps. +// +// This class is neither thread-safe nor thread-compatible. +// TODO(ricea): Make this class thread-compatible by making it not be a +// singleton. class NET_EXPORT_PRIVATE WebSocketEndpointLockManager { public: + // Implement this interface to wait for an endpoint to be available. class NET_EXPORT_PRIVATE Waiter : public base::LinkNode<Waiter> { public: // If the node is in a list, removes it. @@ -45,22 +63,28 @@ class NET_EXPORT_PRIVATE WebSocketEndpointLockManager { // UnlockSocket(). void RememberSocket(StreamSocket* socket, const IPEndPoint& endpoint); - // Releases the lock on the endpoint that was associated with |socket| by - // RememberSocket(). If appropriate, triggers the next socket connection. - // Should be called exactly once for each |socket| that was passed to - // RememberSocket(). Does nothing if UnlockEndpoint() has been called since + // Removes the socket association that was recorded by RememberSocket(), then + // asynchronously releases the lock on the endpoint after a delay. If + // appropriate, calls |waiter->GetEndpointLock()| when the lock is + // released. Should be called exactly once for each |socket| that was passed + // to RememberSocket(). Does nothing if UnlockEndpoint() has been called since // the call to RememberSocket(). void UnlockSocket(StreamSocket* socket); - // Releases the lock on |endpoint|. Does nothing if |endpoint| is not locked. - // Removes any socket association that was recorded with RememberSocket(). If - // appropriate, calls |waiter->GotEndpointLock()|. + // Asynchronously releases the lock on |endpoint| after a delay. Does nothing + // if |endpoint| is not locked. Removes any socket association that was + // recorded with RememberSocket(). If appropriate, calls + // |waiter->GotEndpointLock()| when the lock is released. void UnlockEndpoint(const IPEndPoint& endpoint); // Checks that |lock_info_map_| and |socket_lock_info_map_| are empty. For // tests. bool IsEmpty() const; + // Changes the value of the unlock delay. Returns the previous value of the + // delay. + base::TimeDelta SetUnlockDelayForTesting(base::TimeDelta new_delay); + private: struct LockInfo { typedef base::LinkedList<Waiter> WaiterQueue; @@ -97,7 +121,8 @@ class NET_EXPORT_PRIVATE WebSocketEndpointLockManager { WebSocketEndpointLockManager(); ~WebSocketEndpointLockManager(); - void UnlockEndpointByIterator(LockInfoMap::iterator lock_info_it); + void UnlockEndpointAfterDelay(const IPEndPoint& endpoint); + void DelayedUnlockEndpoint(const IPEndPoint& endpoint); void EraseSocket(LockInfoMap::iterator lock_info_it); // If an entry is present in the map for a particular endpoint, then that @@ -111,6 +136,16 @@ class NET_EXPORT_PRIVATE WebSocketEndpointLockManager { // is non-NULL if and only if there is an entry in this map for the socket. SocketLockInfoMap socket_lock_info_map_; + // Time to wait between a call to Unlock* and actually unlocking the socket. + base::TimeDelta unlock_delay_; + + // Number of sockets currently pending unlock. + size_t pending_unlock_count_; + + // The messsage loop holding the unlock delay callback may outlive this + // object. + base::WeakPtrFactory<WebSocketEndpointLockManager> weak_factory_; + friend struct DefaultSingletonTraits<WebSocketEndpointLockManager>; DISALLOW_COPY_AND_ASSIGN(WebSocketEndpointLockManager); diff --git a/chromium/net/socket/websocket_endpoint_lock_manager_unittest.cc b/chromium/net/socket/websocket_endpoint_lock_manager_unittest.cc index 1626aa90201..29fc067b733 100644 --- a/chromium/net/socket/websocket_endpoint_lock_manager_unittest.cc +++ b/chromium/net/socket/websocket_endpoint_lock_manager_unittest.cc @@ -4,6 +4,9 @@ #include "net/socket/websocket_endpoint_lock_manager.h" +#include "base/message_loop/message_loop.h" +#include "base/run_loop.h" +#include "base/time/time.h" #include "net/base/net_errors.h" #include "net/socket/next_proto.h" #include "net/socket/socket_test_util.h" @@ -51,6 +54,14 @@ class FakeStreamSocket : public StreamSocket { bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } + void GetConnectionAttempts(ConnectionAttempts* out) const override { + out->clear(); + } + + void ClearConnectionAttempts() override {} + + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} + // Socket implementation int Read(IOBuffer* buf, int buf_len, @@ -89,11 +100,30 @@ class FakeWaiter : public WebSocketEndpointLockManager::Waiter { bool called_; }; +class BlockingWaiter : public FakeWaiter { + public: + void WaitForLock() { + while (!called()) { + run_loop_.Run(); + } + } + + void GotEndpointLock() override { + FakeWaiter::GotEndpointLock(); + run_loop_.Quit(); + } + + private: + base::RunLoop run_loop_; +}; + class WebSocketEndpointLockManagerTest : public ::testing::Test { protected: WebSocketEndpointLockManagerTest() : instance_(WebSocketEndpointLockManager::GetInstance()) {} ~WebSocketEndpointLockManagerTest() override { + // Permit any pending asynchronous unlock operations to complete. + RunUntilIdle(); // If this check fails then subsequent tests may fail. CHECK(instance_->IsEmpty()); } @@ -109,10 +139,14 @@ class WebSocketEndpointLockManagerTest : public ::testing::Test { void UnlockDummyEndpoint(int times) { for (int i = 0; i < times; ++i) { instance()->UnlockEndpoint(DummyEndpoint()); + RunUntilIdle(); } } + static void RunUntilIdle() { base::RunLoop().RunUntilIdle(); } + WebSocketEndpointLockManager* const instance_; + ScopedWebSocketEndpointZeroUnlockDelay zero_unlock_delay_; }; TEST_F(WebSocketEndpointLockManagerTest, GetInstanceWorks) { @@ -131,6 +165,7 @@ TEST_F(WebSocketEndpointLockManagerTest, LockEndpointReturnsOkOnce) { TEST_F(WebSocketEndpointLockManagerTest, GotEndpointLockNotCalledOnOk) { FakeWaiter waiter; EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiter)); + RunUntilIdle(); EXPECT_FALSE(waiter.called()); UnlockDummyEndpoint(1); @@ -141,6 +176,7 @@ TEST_F(WebSocketEndpointLockManagerTest, GotEndpointLockNotCalledImmediately) { EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0])); EXPECT_EQ(ERR_IO_PENDING, instance()->LockEndpoint(DummyEndpoint(), &waiters[1])); + RunUntilIdle(); EXPECT_FALSE(waiters[1].called()); UnlockDummyEndpoint(2); @@ -152,6 +188,7 @@ TEST_F(WebSocketEndpointLockManagerTest, GotEndpointLockCalledWhenUnlocked) { EXPECT_EQ(ERR_IO_PENDING, instance()->LockEndpoint(DummyEndpoint(), &waiters[1])); instance()->UnlockEndpoint(DummyEndpoint()); + RunUntilIdle(); EXPECT_TRUE(waiters[1].called()); UnlockDummyEndpoint(1); @@ -169,6 +206,7 @@ TEST_F(WebSocketEndpointLockManagerTest, } instance()->UnlockEndpoint(DummyEndpoint()); + RunUntilIdle(); FakeWaiter second_lock_holder; EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &second_lock_holder)); @@ -185,6 +223,7 @@ TEST_F(WebSocketEndpointLockManagerTest, RememberSocketWorks) { instance()->RememberSocket(&dummy_socket, DummyEndpoint()); instance()->UnlockSocket(&dummy_socket); + RunUntilIdle(); EXPECT_TRUE(waiters[1].called()); UnlockDummyEndpoint(1); @@ -199,6 +238,7 @@ TEST_F(WebSocketEndpointLockManagerTest, SocketAssociationForgottenOnUnlock) { EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiter)); instance()->RememberSocket(&dummy_socket, DummyEndpoint()); instance()->UnlockEndpoint(DummyEndpoint()); + RunUntilIdle(); EXPECT_TRUE(instance()->IsEmpty()); } @@ -213,12 +253,77 @@ TEST_F(WebSocketEndpointLockManagerTest, NextWaiterCanCallRememberSocketAgain) { instance()->RememberSocket(&dummy_sockets[0], DummyEndpoint()); instance()->UnlockEndpoint(DummyEndpoint()); + RunUntilIdle(); EXPECT_TRUE(waiters[1].called()); instance()->RememberSocket(&dummy_sockets[1], DummyEndpoint()); UnlockDummyEndpoint(1); } +// Calling UnlockSocket() after UnlockEndpoint() does nothing. +TEST_F(WebSocketEndpointLockManagerTest, + UnlockSocketAfterUnlockEndpointDoesNothing) { + FakeWaiter waiters[3]; + FakeStreamSocket dummy_socket; + + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0])); + EXPECT_EQ(ERR_IO_PENDING, + instance()->LockEndpoint(DummyEndpoint(), &waiters[1])); + EXPECT_EQ(ERR_IO_PENDING, + instance()->LockEndpoint(DummyEndpoint(), &waiters[2])); + instance()->RememberSocket(&dummy_socket, DummyEndpoint()); + instance()->UnlockEndpoint(DummyEndpoint()); + instance()->UnlockSocket(&dummy_socket); + RunUntilIdle(); + EXPECT_TRUE(waiters[1].called()); + EXPECT_FALSE(waiters[2].called()); + + UnlockDummyEndpoint(2); +} + +// UnlockEndpoint() should always be asynchronous. +TEST_F(WebSocketEndpointLockManagerTest, UnlockEndpointIsAsynchronous) { + FakeWaiter waiters[2]; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiters[0])); + EXPECT_EQ(ERR_IO_PENDING, + instance()->LockEndpoint(DummyEndpoint(), &waiters[1])); + + instance()->UnlockEndpoint(DummyEndpoint()); + EXPECT_FALSE(waiters[1].called()); + RunUntilIdle(); + EXPECT_TRUE(waiters[1].called()); + + UnlockDummyEndpoint(1); +} + +// UnlockEndpoint() should normally have a delay. +TEST_F(WebSocketEndpointLockManagerTest, UnlockEndpointIsDelayed) { + using base::TimeTicks; + + // This 1ms delay is too short for very slow environments (usually those + // running memory checkers). In those environments, the code takes >1ms to run + // and no delay is needed. Rather than increase the delay and slow down the + // test everywhere, the test doesn't explicitly verify that a delay has been + // applied. Instead it just verifies that the whole thing took >=1ms. 1ms is + // easily enough for normal compiles even on Android, so the fact that there + // is a delay is still checked on every platform. + const base::TimeDelta unlock_delay = base::TimeDelta::FromMilliseconds(1); + instance()->SetUnlockDelayForTesting(unlock_delay); + FakeWaiter fake_waiter; + BlockingWaiter blocking_waiter; + EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &fake_waiter)); + EXPECT_EQ(ERR_IO_PENDING, + instance()->LockEndpoint(DummyEndpoint(), &blocking_waiter)); + + TimeTicks before_unlock = TimeTicks::Now(); + instance()->UnlockEndpoint(DummyEndpoint()); + blocking_waiter.WaitForLock(); + TimeTicks after_unlock = TimeTicks::Now(); + EXPECT_GE(after_unlock - before_unlock, unlock_delay); + instance()->SetUnlockDelayForTesting(base::TimeDelta()); + UnlockDummyEndpoint(1); +} + } // namespace } // namespace net diff --git a/chromium/net/socket/websocket_transport_client_socket_pool.cc b/chromium/net/socket/websocket_transport_client_socket_pool.cc index 15ec028cb18..ce433680038 100644 --- a/chromium/net/socket/websocket_transport_client_socket_pool.cc +++ b/chromium/net/socket/websocket_transport_client_socket_pool.cc @@ -13,7 +13,7 @@ #include "base/time/time.h" #include "base/values.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" +#include "net/log/net_log.h" #include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_pool_base.h" #include "net/socket/websocket_endpoint_lock_manager.h" @@ -228,18 +228,15 @@ int WebSocketTransportConnectJob::ConnectInternal() { WebSocketTransportClientSocketPool::WebSocketTransportClientSocketPool( int max_sockets, int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, HostResolver* host_resolver, ClientSocketFactory* client_socket_factory, NetLog* net_log) : TransportClientSocketPool(max_sockets, max_sockets_per_group, - histograms, host_resolver, client_socket_factory, net_log), connect_job_delegate_(this), - histograms_(histograms), pool_net_log_(net_log), client_socket_factory_(client_socket_factory), host_resolver_(host_resolver), @@ -455,11 +452,6 @@ TimeDelta WebSocketTransportClientSocketPool::ConnectionTimeout() const { return TimeDelta::FromSeconds(kTransportConnectJobTimeoutInSeconds); } -ClientSocketPoolHistograms* WebSocketTransportClientSocketPool::histograms() - const { - return histograms_; -} - bool WebSocketTransportClientSocketPool::IsStalled() const { return !stalled_request_queue_.empty(); } diff --git a/chromium/net/socket/websocket_transport_client_socket_pool.h b/chromium/net/socket/websocket_transport_client_socket_pool.h index f0a94be417f..e5ddd938c59 100644 --- a/chromium/net/socket/websocket_transport_client_socket_pool.h +++ b/chromium/net/socket/websocket_transport_client_socket_pool.h @@ -17,7 +17,7 @@ #include "base/time/time.h" #include "base/timer/timer.h" #include "net/base/net_export.h" -#include "net/base/net_log.h" +#include "net/log/net_log.h" #include "net/socket/client_socket_pool.h" #include "net/socket/client_socket_pool_base.h" #include "net/socket/transport_client_socket_pool.h" @@ -25,7 +25,6 @@ namespace net { class ClientSocketFactory; -class ClientSocketPoolHistograms; class HostResolver; class NetLog; class WebSocketEndpointLockManager; @@ -118,7 +117,6 @@ class NET_EXPORT_PRIVATE WebSocketTransportClientSocketPool public: WebSocketTransportClientSocketPool(int max_sockets, int max_sockets_per_group, - ClientSocketPoolHistograms* histograms, HostResolver* host_resolver, ClientSocketFactory* client_socket_factory, NetLog* net_log); @@ -159,7 +157,6 @@ class NET_EXPORT_PRIVATE WebSocketTransportClientSocketPool const std::string& type, bool include_nested_pools) const override; base::TimeDelta ConnectionTimeout() const override; - ClientSocketPoolHistograms* histograms() const override; // HigherLayeredPool implementation. bool IsStalled() const override; @@ -228,7 +225,6 @@ class NET_EXPORT_PRIVATE WebSocketTransportClientSocketPool PendingConnectsMap pending_connects_; StalledRequestQueue stalled_request_queue_; StalledRequestMap stalled_request_map_; - ClientSocketPoolHistograms* const histograms_; NetLog* const pool_net_log_; ClientSocketFactory* const client_socket_factory_; HostResolver* const host_resolver_; diff --git a/chromium/net/socket/websocket_transport_client_socket_pool_unittest.cc b/chromium/net/socket/websocket_transport_client_socket_pool_unittest.cc index 2189181b9fc..2df66063316 100644 --- a/chromium/net/socket/websocket_transport_client_socket_pool_unittest.cc +++ b/chromium/net/socket/websocket_transport_client_socket_pool_unittest.cc @@ -15,7 +15,6 @@ #include "base/run_loop.h" #include "base/strings/stringprintf.h" #include "base/time/time.h" -#include "net/base/capturing_net_log.h" #include "net/base/ip_endpoint.h" #include "net/base/load_timing_info.h" #include "net/base/load_timing_info_test_util.h" @@ -23,8 +22,8 @@ #include "net/base/net_util.h" #include "net/base/test_completion_callback.h" #include "net/dns/mock_host_resolver.h" +#include "net/log/test_net_log.h" #include "net/socket/client_socket_handle.h" -#include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket_test_util.h" #include "net/socket/stream_socket.h" #include "net/socket/transport_client_socket_pool_test_util.h" @@ -48,7 +47,7 @@ void RunLoopForTimePeriod(base::TimeDelta period) { run_loop.Run(); } -class WebSocketTransportClientSocketPoolTest : public testing::Test { +class WebSocketTransportClientSocketPoolTest : public ::testing::Test { protected: WebSocketTransportClientSocketPoolTest() : params_(new TransportSocketParams( @@ -57,21 +56,24 @@ class WebSocketTransportClientSocketPoolTest : public testing::Test { false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), - histograms_(new ClientSocketPoolHistograms("TCPUnitTest")), host_resolver_(new MockHostResolver), client_socket_factory_(&net_log_), pool_(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL) {} ~WebSocketTransportClientSocketPoolTest() override { + RunUntilIdle(); + // ReleaseAllConnections() calls RunUntilIdle() after releasing each + // connection. ReleaseAllConnections(ClientSocketPoolTest::NO_KEEP_ALIVE); EXPECT_TRUE(WebSocketEndpointLockManager::GetInstance()->IsEmpty()); } + static void RunUntilIdle() { base::RunLoop().RunUntilIdle(); } + int StartRequest(const std::string& group_name, RequestPriority priority) { scoped_refptr<TransportSocketParams> params( new TransportSocketParams( @@ -101,13 +103,13 @@ class WebSocketTransportClientSocketPoolTest : public testing::Test { ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); } size_t completion_count() const { return test_base_.completion_count(); } - CapturingNetLog net_log_; + TestNetLog net_log_; scoped_refptr<TransportSocketParams> params_; - scoped_ptr<ClientSocketPoolHistograms> histograms_; scoped_ptr<MockHostResolver> host_resolver_; MockTransportClientSocketFactory client_socket_factory_; WebSocketTransportClientSocketPool pool_; ClientSocketPoolTest test_base_; + ScopedWebSocketEndpointZeroUnlockDelay zero_unlock_delay_; private: DISALLOW_COPY_AND_ASSIGN(WebSocketTransportClientSocketPoolTest); @@ -502,7 +504,7 @@ TEST_F(WebSocketTransportClientSocketPoolTest, LockReleasedOnHandleReset) { EXPECT_EQ(OK, request(0)->WaitForResult()); EXPECT_FALSE(request(1)->handle()->is_initialized()); request(0)->handle()->Reset(); - base::RunLoop().RunUntilIdle(); + RunUntilIdle(); EXPECT_TRUE(request(1)->handle()->is_initialized()); } @@ -518,7 +520,7 @@ TEST_F(WebSocketTransportClientSocketPoolTest, LockReleasedOnHandleDelete) { EXPECT_EQ(OK, callback.WaitForResult()); EXPECT_FALSE(request(0)->handle()->is_initialized()); handle.reset(); - base::RunLoop().RunUntilIdle(); + RunUntilIdle(); EXPECT_TRUE(request(0)->handle()->is_initialized()); } @@ -531,7 +533,7 @@ TEST_F(WebSocketTransportClientSocketPoolTest, EXPECT_EQ(OK, request(0)->WaitForResult()); EXPECT_FALSE(request(1)->handle()->is_initialized()); WebSocketTransportClientSocketPool::UnlockEndpoint(request(0)->handle()); - base::RunLoop().RunUntilIdle(); + RunUntilIdle(); EXPECT_TRUE(request(1)->handle()->is_initialized()); } @@ -548,7 +550,7 @@ TEST_F(WebSocketTransportClientSocketPoolTest, EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); - base::RunLoop().RunUntilIdle(); + RunUntilIdle(); pool_.CancelRequest("a", request(0)->handle()); EXPECT_EQ(OK, request(1)->WaitForResult()); } @@ -559,7 +561,6 @@ TEST_F(WebSocketTransportClientSocketPoolTest, IPv6FallbackSocketIPv4FinishesFirst) { WebSocketTransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); @@ -600,7 +601,6 @@ TEST_F(WebSocketTransportClientSocketPoolTest, IPv6FallbackSocketIPv6FinishesFirst) { WebSocketTransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); @@ -640,7 +640,6 @@ TEST_F(WebSocketTransportClientSocketPoolTest, IPv6NoIPv4AddressesToFallbackTo) { WebSocketTransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); @@ -672,7 +671,6 @@ TEST_F(WebSocketTransportClientSocketPoolTest, TEST_F(WebSocketTransportClientSocketPoolTest, IPv4HasNoFallback) { WebSocketTransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); @@ -705,7 +703,6 @@ TEST_F(WebSocketTransportClientSocketPoolTest, IPv4HasNoFallback) { TEST_F(WebSocketTransportClientSocketPoolTest, IPv6InstantFail) { WebSocketTransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); @@ -742,7 +739,6 @@ TEST_F(WebSocketTransportClientSocketPoolTest, IPv6InstantFail) { TEST_F(WebSocketTransportClientSocketPoolTest, IPv6RapidFail) { WebSocketTransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); @@ -769,9 +765,9 @@ TEST_F(WebSocketTransportClientSocketPoolTest, IPv6RapidFail) { EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.socket()); - base::Time start(base::Time::NowFromSystemTime()); + base::TimeTicks start(base::TimeTicks::Now()); EXPECT_EQ(OK, callback.WaitForResult()); - EXPECT_LT(base::Time::NowFromSystemTime() - start, + EXPECT_LT(base::TimeTicks::Now() - start, base::TimeDelta::FromMilliseconds( TransportConnectJobHelper::kIPv6FallbackTimerInMs)); ASSERT_TRUE(handle.socket()); @@ -787,7 +783,6 @@ TEST_F(WebSocketTransportClientSocketPoolTest, IPv6RapidFail) { TEST_F(WebSocketTransportClientSocketPoolTest, FirstSuccessWins) { WebSocketTransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); @@ -826,7 +821,6 @@ TEST_F(WebSocketTransportClientSocketPoolTest, FirstSuccessWins) { TEST_F(WebSocketTransportClientSocketPoolTest, LastFailureWins) { WebSocketTransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); @@ -853,14 +847,14 @@ TEST_F(WebSocketTransportClientSocketPoolTest, LastFailureWins) { TestCompletionCallback callback; ClientSocketHandle handle; - base::Time start(base::Time::NowFromSystemTime()); + base::TimeTicks start(base::TimeTicks::Now()); int rv = handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); - EXPECT_GE(base::Time::NowFromSystemTime() - start, delay * 5); + EXPECT_GE(base::TimeTicks::Now() - start, delay * 5); } // Global timeout for all connects applies. This test is disabled by default @@ -869,7 +863,6 @@ TEST_F(WebSocketTransportClientSocketPoolTest, LastFailureWins) { TEST_F(WebSocketTransportClientSocketPoolTest, DISABLED_OverallTimeoutApplies) { WebSocketTransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, - histograms_.get(), host_resolver_.get(), &client_socket_factory_, NULL); @@ -902,8 +895,9 @@ TEST_F(WebSocketTransportClientSocketPoolTest, DISABLED_OverallTimeoutApplies) { TEST_F(WebSocketTransportClientSocketPoolTest, MaxSocketsEnforced) { host_resolver_->set_synchronous_mode(true); for (int i = 0; i < kMaxSockets; ++i) { - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + ASSERT_EQ(OK, StartRequest("a", kDefaultPriority)); WebSocketTransportClientSocketPool::UnlockEndpoint(request(i)->handle()); + RunUntilIdle(); } EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); } @@ -914,13 +908,13 @@ TEST_F(WebSocketTransportClientSocketPoolTest, MaxSocketsEnforcedWhenPending) { } // Now there are 32 sockets waiting to connect, and one stalled. for (int i = 0; i < kMaxSockets; ++i) { - base::RunLoop().RunUntilIdle(); + RunUntilIdle(); EXPECT_TRUE(request(i)->handle()->is_initialized()); EXPECT_TRUE(request(i)->handle()->socket()); WebSocketTransportClientSocketPool::UnlockEndpoint(request(i)->handle()); } // Now there are 32 sockets connected, and one stalled. - base::RunLoop().RunUntilIdle(); + RunUntilIdle(); EXPECT_FALSE(request(kMaxSockets)->handle()->is_initialized()); EXPECT_FALSE(request(kMaxSockets)->handle()->socket()); } @@ -928,8 +922,9 @@ TEST_F(WebSocketTransportClientSocketPoolTest, MaxSocketsEnforcedWhenPending) { TEST_F(WebSocketTransportClientSocketPoolTest, StalledSocketReleased) { host_resolver_->set_synchronous_mode(true); for (int i = 0; i < kMaxSockets; ++i) { - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + ASSERT_EQ(OK, StartRequest("a", kDefaultPriority)); WebSocketTransportClientSocketPool::UnlockEndpoint(request(i)->handle()); + RunUntilIdle(); } EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); @@ -953,7 +948,7 @@ TEST_F(WebSocketTransportClientSocketPoolTest, } EXPECT_EQ(OK, request(0)->WaitForResult()); request(1)->handle()->Reset(); - base::RunLoop().RunUntilIdle(); + RunUntilIdle(); EXPECT_FALSE(pool_.IsStalled()); } @@ -972,6 +967,7 @@ TEST_F(WebSocketTransportClientSocketPoolTest, EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); } request(kMaxSockets)->handle()->Reset(); + RunUntilIdle(); EXPECT_FALSE(pool_.IsStalled()); } @@ -1114,6 +1110,7 @@ TEST_F(WebSocketTransportClientSocketPoolTest, CancelRequestReclaimsSockets) { request(0)->handle()->Reset(); // calls CancelRequest() + RunUntilIdle(); // We should now be able to create a new connection without blocking on the // endpoint lock. EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); @@ -1123,11 +1120,12 @@ TEST_F(WebSocketTransportClientSocketPoolTest, CancelRequestReclaimsSockets) { // Endpoint, not two. TEST_F(WebSocketTransportClientSocketPoolTest, EndpointLockIsOnlyReleasedOnce) { host_resolver_->set_synchronous_mode(true); - EXPECT_EQ(OK, StartRequest("a", kDefaultPriority)); + ASSERT_EQ(OK, StartRequest("a", kDefaultPriority)); EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); // First socket completes handshake. WebSocketTransportClientSocketPool::UnlockEndpoint(request(0)->handle()); + RunUntilIdle(); // First socket is closed. request(0)->handle()->Reset(); // Second socket should have been released. diff --git a/chromium/net/socket/websocket_transport_connect_sub_job.cc b/chromium/net/socket/websocket_transport_connect_sub_job.cc index fbe8bbcc82c..25c0744cfac 100644 --- a/chromium/net/socket/websocket_transport_connect_sub_job.cc +++ b/chromium/net/socket/websocket_transport_connect_sub_job.cc @@ -7,7 +7,7 @@ #include "base/logging.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" +#include "net/log/net_log.h" #include "net/socket/client_socket_factory.h" #include "net/socket/websocket_endpoint_lock_manager.h" |