diff options
author | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2018-12-10 16:19:40 +0100 |
---|---|---|
committer | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2018-12-10 16:01:50 +0000 |
commit | 51f6c2793adab2d864b3d2b360000ef8db1d3e92 (patch) | |
tree | 835b3b4446b012c75e80177cef9fbe6972cc7dbe /chromium/services | |
parent | 6036726eb981b6c4b42047513b9d3f4ac865daac (diff) | |
download | qtwebengine-chromium-51f6c2793adab2d864b3d2b360000ef8db1d3e92.tar.gz |
BASELINE: Update Chromium to 71.0.3578.93
Change-Id: I6a32086c33670e1b033f8b10e6bf1fd4da1d105d
Reviewed-by: Alexandru Croitor <alexandru.croitor@qt.io>
Diffstat (limited to 'chromium/services')
532 files changed, 23995 insertions, 3994 deletions
diff --git a/chromium/services/BUILD.gn b/chromium/services/BUILD.gn index 2e1205ac974..07666f11e4b 100644 --- a/chromium/services/BUILD.gn +++ b/chromium/services/BUILD.gn @@ -153,7 +153,7 @@ if (is_android) { "//services/device/public/mojom:mojom_java", "//services/shape_detection:shape_detection_java", "//skia/public/interfaces:interfaces_java", - "//third_party/android_tools:android_support_annotations_java", + "//third_party/android_deps:android_support_annotations_java", ] } diff --git a/chromium/services/audio/BUILD.gn b/chromium/services/audio/BUILD.gn index bd04a85de6d..58ade003eb4 100644 --- a/chromium/services/audio/BUILD.gn +++ b/chromium/services/audio/BUILD.gn @@ -120,6 +120,7 @@ source_set("lib") { public_deps += [ "//sandbox/win:sandbox" ] } configs += [ + "//build/config/compiler:wexit_time_destructors", "//media:media_config", "//media/audio:platform_config", ] diff --git a/chromium/services/audio/OWNERS b/chromium/services/audio/OWNERS index 596e00e4c80..3c9700a16d7 100644 --- a/chromium/services/audio/OWNERS +++ b/chromium/services/audio/OWNERS @@ -6,6 +6,7 @@ miu@chromium.org per-file manifest.json=set noparent per-file manifest.json=file://ipc/SECURITY_OWNERS +per-file audio_sandbox_hook_linux.*=file://sandbox/linux/OWNERS per-file audio_sandbox_win.*=file://sandbox/win/OWNERS # COMPONENT: Internals>Media>Audio diff --git a/chromium/services/audio/device_notifier.h b/chromium/services/audio/device_notifier.h index 5898e2cab13..4f5f994e91e 100644 --- a/chromium/services/audio/device_notifier.h +++ b/chromium/services/audio/device_notifier.h @@ -8,7 +8,7 @@ #include <memory> #include "base/containers/flat_map.h" -#include "base/system_monitor/system_monitor.h" +#include "base/system/system_monitor.h" #include "mojo/public/cpp/bindings/binding_set.h" #include "services/audio/public/mojom/device_notifications.mojom.h" #include "services/audio/traced_service_ref.h" diff --git a/chromium/services/audio/device_notifier_unittest.cc b/chromium/services/audio/device_notifier_unittest.cc index 86b7c4ff22d..1bee5bcdea7 100644 --- a/chromium/services/audio/device_notifier_unittest.cc +++ b/chromium/services/audio/device_notifier_unittest.cc @@ -7,7 +7,7 @@ #include <memory> #include <utility> -#include "base/system_monitor/system_monitor.h" +#include "base/system/system_monitor.h" #include "base/test/scoped_task_environment.h" #include "services/audio/public/mojom/device_notifications.mojom.h" #include "services/audio/traced_service_ref.h" diff --git a/chromium/services/audio/group_coordinator-impl.h b/chromium/services/audio/group_coordinator-impl.h index dab5633f0e0..6e10c3bbc3f 100644 --- a/chromium/services/audio/group_coordinator-impl.h +++ b/chromium/services/audio/group_coordinator-impl.h @@ -5,6 +5,21 @@ #ifndef SERVICES_AUDIO_GROUP_COORDINATOR_IMPL_H_ #define SERVICES_AUDIO_GROUP_COORDINATOR_IMPL_H_ +#include "base/compiler_specific.h" +#include "base/no_destructor.h" + +#if DCHECK_IS_ON() +#define DCHECK_INCREMENT_MUTATION_COUNT() ++mutation_count_ +#define DCHECK_REMEMBER_CURRENT_MUTATION_COUNT() \ + const auto change_number = mutation_count_ +#define DCHECK_MUTATION_COUNT_UNCHANGED() \ + DCHECK_EQ(mutation_count_, change_number) +#else +#define DCHECK_INCREMENT_MUTATION_COUNT() +#define DCHECK_REMEMBER_CURRENT_MUTATION_COUNT() +#define DCHECK_MUTATION_COUNT_UNCHANGED() +#endif + namespace audio { template <typename Member> @@ -29,9 +44,12 @@ void GroupCoordinator<Member>::RegisterMember( std::vector<Member*>& members = it->second.members; DCHECK(!base::ContainsValue(members, member)); members.push_back(member); + DCHECK_INCREMENT_MUTATION_COUNT(); + DCHECK_REMEMBER_CURRENT_MUTATION_COUNT(); for (Observer* observer : it->second.observers) { observer->OnMemberJoinedGroup(member); + DCHECK_MUTATION_COUNT_UNCHANGED(); } } @@ -47,9 +65,12 @@ void GroupCoordinator<Member>::UnregisterMember( const auto member_it = std::find(members.begin(), members.end(), member); DCHECK(member_it != members.end()); members.erase(member_it); + DCHECK_INCREMENT_MUTATION_COUNT(); + DCHECK_REMEMBER_CURRENT_MUTATION_COUNT(); for (Observer* observer : group_it->second.observers) { observer->OnMemberLeftGroup(member); + DCHECK_MUTATION_COUNT_UNCHANGED(); } MaybePruneGroupMapEntry(group_it); @@ -65,6 +86,7 @@ void GroupCoordinator<Member>::AddObserver( std::vector<Observer*>& observers = FindGroup(group_id)->second.observers; DCHECK(!base::ContainsValue(observers, observer)); observers.push_back(observer); + DCHECK_INCREMENT_MUTATION_COUNT(); } template <typename Member> @@ -79,12 +101,28 @@ void GroupCoordinator<Member>::RemoveObserver( const auto it = std::find(observers.begin(), observers.end(), observer); DCHECK(it != observers.end()); observers.erase(it); + DCHECK_INCREMENT_MUTATION_COUNT(); MaybePruneGroupMapEntry(group_it); } template <typename Member> -const std::vector<Member*>& GroupCoordinator<Member>::GetCurrentMembers( +void GroupCoordinator<Member>::ForEachMemberInGroup( + const base::UnguessableToken& group_id, + base::RepeatingCallback<void(Member*)> callback) const { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + DCHECK_REMEMBER_CURRENT_MUTATION_COUNT(); + for (Member* member : this->GetCurrentMembersUnsafe(group_id)) { + callback.Run(member); + // Note: If this fails, then not only is there a re-entrancy problem, but + // also the iterator being used by this for-loop is no longer valid! + DCHECK_MUTATION_COUNT_UNCHANGED(); + } +} + +template <typename Member> +const std::vector<Member*>& GroupCoordinator<Member>::GetCurrentMembersUnsafe( const base::UnguessableToken& group_id) const { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); @@ -94,8 +132,8 @@ const std::vector<Member*>& GroupCoordinator<Member>::GetCurrentMembers( } } - static const std::vector<Member*> empty_set; - return empty_set; + static const base::NoDestructor<std::vector<Member*>> empty_set; + return *empty_set; } template <typename Member> @@ -111,6 +149,7 @@ GroupCoordinator<Member>::FindGroup(const base::UnguessableToken& group_id) { groups_.emplace_back(); const auto new_it = groups_.end() - 1; new_it->first = group_id; + DCHECK_INCREMENT_MUTATION_COUNT(); return new_it; } @@ -119,6 +158,7 @@ void GroupCoordinator<Member>::MaybePruneGroupMapEntry( typename GroupMap::iterator it) { if (it->second.members.empty() && it->second.observers.empty()) { groups_.erase(it); + DCHECK_INCREMENT_MUTATION_COUNT(); } } @@ -138,4 +178,10 @@ operator=(GroupCoordinator::Group&& other) = default; } // namespace audio +#if DCHECK_IS_ON() +#undef DCHECK_INCREMENT_MUTATION_COUNT +#undef DCHECK_REMEMBER_CURRENT_MUTATION_COUNT +#undef DCHECK_MUTATION_COUNT_UNCHANGED +#endif + #endif // SERVICES_AUDIO_GROUP_COORDINATOR_IMPL_H_ diff --git a/chromium/services/audio/group_coordinator.h b/chromium/services/audio/group_coordinator.h index 14f49c49306..3f8ce11ba2b 100644 --- a/chromium/services/audio/group_coordinator.h +++ b/chromium/services/audio/group_coordinator.h @@ -9,6 +9,8 @@ #include <utility> #include <vector> +#include "base/callback.h" +#include "base/logging.h" #include "base/macros.h" #include "base/sequence_checker.h" #include "base/stl_util.h" @@ -44,10 +46,16 @@ class GroupCoordinator { void RemoveObserver(const base::UnguessableToken& group_id, Observer* observer); + // Runs a |callback| for each member associated with the given |group_id|. + void ForEachMemberInGroup( + const base::UnguessableToken& group_id, + base::RepeatingCallback<void(Member*)> callback) const; + + protected: // Returns the current members in the group having the given |group_id|. Note // that the validity of the returned reference is uncertain once any of the // other non-const methods are called. - const std::vector<Member*>& GetCurrentMembers( + const std::vector<Member*>& GetCurrentMembersUnsafe( const base::UnguessableToken& group_id) const; private: @@ -75,6 +83,13 @@ class GroupCoordinator { GroupMap groups_; +#if DCHECK_IS_ON() + // Incremented with each mutation, and used to sanity-check that there aren't + // any possible re-entrancy bugs. It's okay if this rolls over, since the + // implementation is only doing DCHECK_EQ's. + size_t mutation_count_ = 0; +#endif + SEQUENCE_CHECKER(sequence_checker_); DISALLOW_COPY_AND_ASSIGN(GroupCoordinator); diff --git a/chromium/services/audio/group_coordinator_unittest.cc b/chromium/services/audio/group_coordinator_unittest.cc index 39eb6f54028..91b81c5598c 100644 --- a/chromium/services/audio/group_coordinator_unittest.cc +++ b/chromium/services/audio/group_coordinator_unittest.cc @@ -21,9 +21,17 @@ using testing::_; namespace audio { +class TestGroupCoordinator : public GroupCoordinator<MockGroupMember> { + public: + const std::vector<MockGroupMember*>& GetCurrentMembers( + const base::UnguessableToken& group_id) const { + return GetCurrentMembersUnsafe(group_id); + } +}; + namespace { -class MockGroupObserver : public GroupCoordinator<MockGroupMember>::Observer { +class MockGroupObserver : public TestGroupCoordinator::Observer { public: MockGroupObserver() = default; ~MockGroupObserver() override = default; @@ -36,7 +44,7 @@ class MockGroupObserver : public GroupCoordinator<MockGroupMember>::Observer { }; TEST(GroupCoordinatorTest, NeverUsed) { - GroupCoordinator<MockGroupMember> coordinator; + TestGroupCoordinator coordinator; } TEST(GroupCoordinatorTest, RegistersMembersInSameGroup) { @@ -56,7 +64,7 @@ TEST(GroupCoordinatorTest, RegistersMembersInSameGroup) { EXPECT_CALL(observer, OnMemberLeftGroup(&member2)) .InSequence(join_leave_sequence); - GroupCoordinator<MockGroupMember> coordinator; + TestGroupCoordinator coordinator; coordinator.AddObserver(group_id, &observer); coordinator.RegisterMember(group_id, &member1); coordinator.RegisterMember(group_id, &member2); @@ -103,7 +111,7 @@ TEST(GroupCoordinatorTest, RegistersMembersInDifferentGroups) { EXPECT_CALL(observer_b, OnMemberLeftGroup(&member_b_1)) .InSequence(join_leave_sequence_b); - GroupCoordinator<MockGroupMember> coordinator; + TestGroupCoordinator coordinator; coordinator.AddObserver(group_id_a, &observer_a); coordinator.AddObserver(group_id_b, &observer_b); coordinator.RegisterMember(group_id_a, &member_a_1); @@ -140,7 +148,7 @@ TEST(GroupCoordinatorTest, TracksMembersWithoutAnObserverPresent) { StrictMock<MockGroupMember> member1; StrictMock<MockGroupMember> member2; - GroupCoordinator<MockGroupMember> coordinator; + TestGroupCoordinator coordinator; coordinator.RegisterMember(group_id, &member1); coordinator.RegisterMember(group_id, &member2); @@ -173,7 +181,7 @@ TEST(GroupCoordinatorTest, NotifiesOnlyWhileObserving) { .InSequence(join_leave_sequence); EXPECT_CALL(observer, OnMemberLeftGroup(&member2)).Times(0); - GroupCoordinator<MockGroupMember> coordinator; + TestGroupCoordinator coordinator; coordinator.RegisterMember(group_id, &member1); EXPECT_EQ(std::vector<MockGroupMember*>({&member1}), coordinator.GetCurrentMembers(group_id)); diff --git a/chromium/services/audio/local_muter.cc b/chromium/services/audio/local_muter.cc index 290c51f5c81..7e05374ebdb 100644 --- a/chromium/services/audio/local_muter.cc +++ b/chromium/services/audio/local_muter.cc @@ -16,10 +16,10 @@ LocalMuter::LocalMuter(LoopbackCoordinator* coordinator, DCHECK(coordinator_); coordinator_->AddObserver(group_id_, this); - for (LoopbackGroupMember* member : - coordinator_->GetCurrentMembers(group_id_)) { - member->StartMuting(); - } + coordinator_->ForEachMemberInGroup( + group_id_, base::BindRepeating([](LoopbackGroupMember* member) { + member->StartMuting(); + })); bindings_.set_connection_error_handler( base::BindRepeating(&LocalMuter::OnBindingLost, base::Unretained(this))); @@ -28,11 +28,10 @@ LocalMuter::LocalMuter(LoopbackCoordinator* coordinator, LocalMuter::~LocalMuter() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - for (LoopbackGroupMember* member : - coordinator_->GetCurrentMembers(group_id_)) { - member->StopMuting(); - } - + coordinator_->ForEachMemberInGroup( + group_id_, base::BindRepeating([](LoopbackGroupMember* member) { + member->StopMuting(); + })); coordinator_->RemoveObserver(group_id_, this); } diff --git a/chromium/services/audio/loopback_stream.cc b/chromium/services/audio/loopback_stream.cc index 866ca18925f..26c7b7eff66 100644 --- a/chromium/services/audio/loopback_stream.cc +++ b/chromium/services/audio/loopback_stream.cc @@ -108,9 +108,8 @@ LoopbackStream::~LoopbackStream() { if (network_) { if (network_->is_started()) { coordinator_->RemoveObserver(group_id_, this); - for (LoopbackGroupMember* member : - coordinator_->GetCurrentMembers(group_id_)) { - OnMemberLeftGroup(member); + while (!snoopers_.empty()) { + OnMemberLeftGroup(snoopers_.begin()->first); } } DCHECK(snoopers_.empty()); @@ -129,10 +128,9 @@ void LoopbackStream::Record() { // Begin snooping on all group members. This will set up the mixer network // and begin accumulating audio data in the Snoopers' buffers. DCHECK(snoopers_.empty()); - for (LoopbackGroupMember* member : - coordinator_->GetCurrentMembers(group_id_)) { - OnMemberJoinedGroup(member); - } + coordinator_->ForEachMemberInGroup( + group_id_, base::BindRepeating(&LoopbackStream::OnMemberJoinedGroup, + base::Unretained(this))); coordinator_->AddObserver(group_id_, this); // Start the data flow. diff --git a/chromium/services/audio/output_controller.cc b/chromium/services/audio/output_controller.cc index a8a77a1890d..643880330fc 100644 --- a/chromium/services/audio/output_controller.cc +++ b/chromium/services/audio/output_controller.cc @@ -10,6 +10,7 @@ #include "base/bind.h" #include "base/bind_helpers.h" +#include "base/compiler_specific.h" #include "base/metrics/histogram_macros.h" #include "base/numerics/safe_conversions.h" #include "base/stl_util.h" @@ -53,6 +54,31 @@ void LogStreamCreationResult(bool for_device_change, } } +void SanitizeAudioBus(media::AudioBus* bus) { + size_t channel_size = bus->frames(); + for (int i = 0; i < bus->channels(); ++i) { + float* channel = bus->channel(i); + for (size_t j = 0; j < channel_size; ++j) { + // First check for all the invalid cases with a single conditional to + // optimize for the typical (data ok) case. Different cases are handled + // inside of the conditional. The condition is written like this to catch + // NaN. It cannot be simplified to "channel[j] < -1.f || channel[j] > + // 1.f", which isn't equivalent. + if (UNLIKELY(!(channel[j] >= -1.f && channel[j] <= 1.f))) { + // Don't just set all bad values to 0. If a value like 1.0001 is + // produced due to floating-point shenanigans, 1 will sound better than + // 0. + if (channel[j] < -1.f) { + channel[j] = -1.f; + } else { + // channel[j] > 1 or NaN. + channel[j] = 1.f; + } + } + } + } +} + } // namespace OutputController::ErrorStatisticsTracker::ErrorStatisticsTracker() @@ -101,6 +127,7 @@ OutputController::OutputController( params_(params), handler_(handler), task_runner_(audio_manager->GetTaskRunner()), + construction_time_(base::TimeTicks::Now()), output_device_id_(output_device_id), stream_(NULL), disable_local_output_(false), @@ -127,6 +154,8 @@ OutputController::~OutputController() { DCHECK_EQ(nullptr, stream_); DCHECK(snoopers_.empty()); DCHECK(should_duplicate_.IsZero()); + UMA_HISTOGRAM_LONG_TIMES("Media.AudioOutputController.LifeTime", + base::TimeTicks::Now() - construction_time_); } bool OutputController::Create(bool is_for_device_change) { @@ -178,7 +207,7 @@ bool OutputController::Create(bool is_for_device_change) { LogStreamCreationResult(is_for_device_change, STREAM_CREATION_OK); - audio_manager_->AddOutputDeviceChangeListener(this); + audio_manager_->AddOutputDeviceChangeListener(this); // We have successfully opened the stream. Set the initial volume. stream_->SetVolume(volume_); @@ -190,10 +219,13 @@ bool OutputController::Create(bool is_for_device_change) { // Ensure new monitors know that we're active. stream_monitor_coordinator_->AddObserver(processing_id_, this); // Ensure existing monitors do as well. - for (StreamMonitor* monitor : - stream_monitor_coordinator_->GetCurrentMembers(processing_id_)) { - monitor->OnStreamActive(this); - } + stream_monitor_coordinator_->ForEachMemberInGroup( + processing_id_, + base::BindRepeating( + [](OutputController* controller, StreamMonitor* monitor) { + monitor->OnStreamActive(controller); + }, + this)); } return true; @@ -319,10 +351,13 @@ int OutputController::OnMoreData(base::TimeDelta delay, const base::TimeTicks reference_time = delay_timestamp + delay; - { + if (!dest->is_bitstream_format()) { base::AutoLock lock(realtime_snooper_lock_); - for (Snooper* snooper : realtime_snoopers_) { - snooper->OnData(*dest, reference_time, volume_); + if (!realtime_snoopers_.empty()) { + SanitizeAudioBus(dest); + for (Snooper* snooper : realtime_snoopers_) { + snooper->OnData(*dest, reference_time, volume_); + } } } @@ -333,7 +368,7 @@ int OutputController::OnMoreData(base::TimeDelta delay, sync_reader_->RequestMoreData(delay, delay_timestamp, prior_frames_skipped); - if (!should_duplicate_.IsZero()) { + if (!should_duplicate_.IsZero() && !dest->is_bitstream_format()) { std::unique_ptr<media::AudioBus> copy(media::AudioBus::Create(params_)); dest->CopyTo(copy.get()); task_runner_->PostTask( @@ -412,10 +447,13 @@ void OutputController::StopCloseAndClearStream() { // Don't send out activation messages for now. stream_monitor_coordinator_->RemoveObserver(processing_id_, this); // Ensure everyone monitoring us knows we're no-longer active. - for (StreamMonitor* monitor : - stream_monitor_coordinator_->GetCurrentMembers(processing_id_)) { - monitor->OnStreamInactive(this); - } + stream_monitor_coordinator_->ForEachMemberInGroup( + processing_id_, + base::BindRepeating( + [](OutputController* controller, StreamMonitor* monitor) { + monitor->OnStreamInactive(controller); + }, + this)); } StopStream(); @@ -563,7 +601,6 @@ void OutputController::OnDeviceChange() { } } - std::pair<float, bool> OutputController::ReadCurrentPowerAndClip() { DCHECK(will_monitor_audio_levels()); return power_monitor_.ReadCurrentPowerAndClip(); diff --git a/chromium/services/audio/output_controller.h b/chromium/services/audio/output_controller.h index 3589f132950..772d71e0e57 100644 --- a/chromium/services/audio/output_controller.h +++ b/chromium/services/audio/output_controller.h @@ -242,6 +242,10 @@ class OutputController : public media::AudioOutputStream::AudioSourceCallback, // via tasks run by this TaskRunner. const scoped_refptr<base::SingleThreadTaskRunner> task_runner_; + // Time when the controller is constructed. Used to record its lifetime on + // destruction. + const base::TimeTicks construction_time_; + // Specifies the device id of the output device to open or empty for the // default output device. const std::string output_device_id_; diff --git a/chromium/services/audio/public/cpp/BUILD.gn b/chromium/services/audio/public/cpp/BUILD.gn index 55f4c71843b..d06313e2218 100644 --- a/chromium/services/audio/public/cpp/BUILD.gn +++ b/chromium/services/audio/public/cpp/BUILD.gn @@ -20,6 +20,8 @@ source_set("cpp") { "output_device.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + public_deps = [ "//base", "//media", diff --git a/chromium/services/audio/public/cpp/output_device_unittest.cc b/chromium/services/audio/public/cpp/output_device_unittest.cc index 258fd47a9b0..a7186a54cfb 100644 --- a/chromium/services/audio/public/cpp/output_device_unittest.cc +++ b/chromium/services/audio/public/cpp/output_device_unittest.cc @@ -182,7 +182,13 @@ TEST_F(AudioServiceOutputDeviceTest, CreatePlayPause) { task_env_.RunUntilIdle(); } -TEST_F(AudioServiceOutputDeviceTest, VerifyDataFlow) { +// Flaky on Linux Chromium OS ASan LSan (https://crbug.com/889845) +#if defined(OS_CHROMEOS) && defined(ADDRESS_SANITIZER) +#define MAYBE_VerifyDataFlow DISABLED_VerifyDataFlow +#else +#define MAYBE_VerifyDataFlow VerifyDataFlow +#endif +TEST_F(AudioServiceOutputDeviceTest, MAYBE_VerifyDataFlow) { auto params(media::AudioParameters::UnavailableDeviceParams()); params.set_frames_per_buffer(kFrames); ASSERT_EQ(2, params.channels()); diff --git a/chromium/services/audio/service.cc b/chromium/services/audio/service.cc index 62b2bafd461..a9850d81437 100644 --- a/chromium/services/audio/service.cc +++ b/chromium/services/audio/service.cc @@ -9,7 +9,7 @@ #include "base/logging.h" #include "base/macros.h" #include "base/single_thread_task_runner.h" -#include "base/system_monitor/system_monitor.h" +#include "base/system/system_monitor.h" #include "base/time/default_tick_clock.h" #include "base/trace_event/trace_event.h" #include "media/audio/audio_manager.h" diff --git a/chromium/services/catalog/BUILD.gn b/chromium/services/catalog/BUILD.gn index 2a27fe01708..033b0446f20 100644 --- a/chromium/services/catalog/BUILD.gn +++ b/chromium/services/catalog/BUILD.gn @@ -36,6 +36,8 @@ component("lib") { "service_options.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ ":constants", "//base", diff --git a/chromium/services/catalog/entry.cc b/chromium/services/catalog/entry.cc index 682bf029ef2..d5ea10c39b2 100644 --- a/chromium/services/catalog/entry.cc +++ b/chromium/services/catalog/entry.cc @@ -184,20 +184,22 @@ std::unique_ptr<Entry> Entry::Deserialize(const base::Value& manifest_root) { << instance_sharing; } - if (const base::Value* allow_other_user_ids_value = - options->FindKey("allow_other_user_ids")) - options_struct.allow_other_user_ids = - allow_other_user_ids_value->GetBool(); - - if (const base::Value* allow_other_instance_names_value = - options->FindKey("allow_other_instance_names")) - options_struct.allow_other_instance_names = - allow_other_instance_names_value->GetBool(); - - if (const base::Value* instance_for_client_process_value = - options->FindKey("instance_for_client_process")) - options_struct.instance_for_client_process = - instance_for_client_process_value->GetBool(); + if (const base::Value* can_connect_to_other_services_as_any_user_value = + options->FindKey("can_connect_to_other_services_as_any_user")) + options_struct.can_connect_to_other_services_as_any_user = + can_connect_to_other_services_as_any_user_value->GetBool(); + + if (const base::Value* + can_connect_to_other_services_with_any_instance_name_value = + options->FindKey( + "can_connect_to_other_services_with_any_instance_name")) + options_struct.can_connect_to_other_services_with_any_instance_name = + can_connect_to_other_services_with_any_instance_name_value->GetBool(); + + if (const base::Value* can_create_other_service_instances_value = + options->FindKey("can_create_other_service_instances")) + options_struct.can_create_other_service_instances = + can_create_other_service_instances_value->GetBool(); entry->AddOptions(std::move(options_struct)); } diff --git a/chromium/services/catalog/entry_unittest.cc b/chromium/services/catalog/entry_unittest.cc index eb0429a8b02..f7ab0662502 100644 --- a/chromium/services/catalog/entry_unittest.cc +++ b/chromium/services/catalog/entry_unittest.cc @@ -75,9 +75,10 @@ TEST_F(EntryTest, Options) { EXPECT_EQ(ServiceOptions::InstanceSharingType::SINGLETON, entry->options().instance_sharing); - EXPECT_TRUE(entry->options().allow_other_user_ids); - EXPECT_TRUE(entry->options().allow_other_instance_names); - EXPECT_TRUE(entry->options().instance_for_client_process); + EXPECT_TRUE(entry->options().can_connect_to_other_services_as_any_user); + EXPECT_TRUE( + entry->options().can_connect_to_other_services_with_any_instance_name); + EXPECT_TRUE(entry->options().can_create_other_service_instances); EXPECT_EQ("", entry->sandbox_type()); } diff --git a/chromium/services/catalog/public/cpp/BUILD.gn b/chromium/services/catalog/public/cpp/BUILD.gn index febd1d72645..7eb6bdae97c 100644 --- a/chromium/services/catalog/public/cpp/BUILD.gn +++ b/chromium/services/catalog/public/cpp/BUILD.gn @@ -10,6 +10,8 @@ source_set("cpp") { "resource_loader.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ "//base", "//components/services/filesystem/public/interfaces", diff --git a/chromium/services/catalog/service_options.h b/chromium/services/catalog/service_options.h index 4499b4b3f94..a38b2e8b98f 100644 --- a/chromium/services/catalog/service_options.h +++ b/chromium/services/catalog/service_options.h @@ -15,9 +15,9 @@ struct ServiceOptions { }; InstanceSharingType instance_sharing = InstanceSharingType::NONE; - bool allow_other_user_ids = false; - bool allow_other_instance_names = false; - bool instance_for_client_process = false; + bool can_connect_to_other_services_as_any_user = false; + bool can_connect_to_other_services_with_any_instance_name = false; + bool can_create_other_service_instances = false; }; } // namespace catalog diff --git a/chromium/services/catalog/test_data/options b/chromium/services/catalog/test_data/options index 64ee077897c..203a4e6538d 100644 --- a/chromium/services/catalog/test_data/options +++ b/chromium/services/catalog/test_data/options @@ -3,9 +3,9 @@ "display_name": "Foo", "options": { "instance_sharing": "singleton", - "allow_other_user_ids": true, - "allow_other_instance_names": true, - "instance_for_client_process": true + "can_connect_to_other_services_as_any_user": true, + "can_connect_to_other_services_with_any_instance_name": true, + "can_create_other_service_instances": true }, "interface_provider_specs": { } } diff --git a/chromium/services/content/BUILD.gn b/chromium/services/content/BUILD.gn index a9b95dee876..f972bd4f278 100644 --- a/chromium/services/content/BUILD.gn +++ b/chromium/services/content/BUILD.gn @@ -26,6 +26,8 @@ source_set("impl") { "service.cc", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + public_deps = [ "//base", "//services/content/public/cpp:buildflags", diff --git a/chromium/services/content/DEPS b/chromium/services/content/DEPS index 7c7f37f94c6..f3ff2940ecd 100644 --- a/chromium/services/content/DEPS +++ b/chromium/services/content/DEPS @@ -1,4 +1,6 @@ include_rules = [ + "+services/network/public", + "+ui/aura", "+ui/base", "+ui/gfx", diff --git a/chromium/services/content/navigable_contents_delegate.h b/chromium/services/content/navigable_contents_delegate.h index 9e8208b1b59..55ed240f791 100644 --- a/chromium/services/content/navigable_contents_delegate.h +++ b/chromium/services/content/navigable_contents_delegate.h @@ -5,6 +5,7 @@ #ifndef SERVICES_CONTENT_NAVIGABLE_CONTENTS_DELEGATE_H_ #define SERVICES_CONTENT_NAVIGABLE_CONTENTS_DELEGATE_H_ +#include "services/content/public/mojom/navigable_contents.mojom.h" #include "ui/gfx/native_widget_types.h" class GURL; @@ -31,7 +32,7 @@ class NavigableContentsDelegate { virtual gfx::NativeView GetNativeView() = 0; // Navigates the content object to a new URL. - virtual void Navigate(const GURL& url) = 0; + virtual void Navigate(const GURL& url, mojom::NavigateParamsPtr params) = 0; }; } // namespace content diff --git a/chromium/services/content/navigable_contents_impl.cc b/chromium/services/content/navigable_contents_impl.cc index e518888c827..da1f99ce4c0 100644 --- a/chromium/services/content/navigable_contents_impl.cc +++ b/chromium/services/content/navigable_contents_impl.cc @@ -35,7 +35,8 @@ NavigableContentsImpl::NavigableContentsImpl( binding_(this, std::move(request)), client_(std::move(client)), delegate_( - service_->delegate()->CreateNavigableContentsDelegate(client_.get())), + service_->delegate()->CreateNavigableContentsDelegate(*params, + client_.get())), native_content_view_(delegate_->GetNativeView()) { binding_.set_connection_error_handler(base::BindRepeating( &Service::RemoveNavigableContents, base::Unretained(service_), this)); @@ -43,12 +44,13 @@ NavigableContentsImpl::NavigableContentsImpl( NavigableContentsImpl::~NavigableContentsImpl() = default; -void NavigableContentsImpl::Navigate(const GURL& url) { +void NavigableContentsImpl::Navigate(const GURL& url, + mojom::NavigateParamsPtr params) { // Ignore non-HTTP/HTTPS requests for now. if (!url.SchemeIsHTTPOrHTTPS()) return; - delegate_->Navigate(url); + delegate_->Navigate(url, std::move(params)); } void NavigableContentsImpl::CreateView(bool in_service_process, @@ -85,10 +87,10 @@ void NavigableContentsImpl::CreateView(bool in_service_process, void NavigableContentsImpl::OnEmbedTokenReceived( CreateViewCallback callback, const base::UnguessableToken& token) { -#if defined(TOOLKIT_VIEWS) - if (native_content_view_) - native_content_view_->Show(); -#endif // defined(TOOLKIT_VIEWS) +#if defined(TOOLKIT_VIEWS) && defined(USE_AURA) + DCHECK(native_content_view_); + native_content_view_->Show(); +#endif // defined(TOOLKIT_VIEWS) && defined(USE_AURA) std::move(callback).Run(token); } #endif // BUILDFLAG(ENABLE_REMOTE_NAVIGABLE_CONTENTS_VIEW) @@ -96,13 +98,9 @@ void NavigableContentsImpl::OnEmbedTokenReceived( void NavigableContentsImpl::EmbedInProcessClientView( NavigableContentsView* view) { DCHECK(native_content_view_); -#if defined(TOOLKIT_VIEWS) - DCHECK(!local_view_host_); - local_view_host_ = std::make_unique<views::NativeViewHost>(); - local_view_host_->set_owned_by_client(); - view->view()->AddChildView(local_view_host_.get()); - view->view()->Layout(); - local_view_host_->Attach(native_content_view_); +#if defined(TOOLKIT_VIEWS) && defined(USE_AURA) + view->native_view()->AddChild(native_content_view_); + native_content_view_->Show(); #else // TODO(https://crbug.com/855092): Support embedding of other native client // views without Views + Aura. diff --git a/chromium/services/content/navigable_contents_impl.h b/chromium/services/content/navigable_contents_impl.h index 54d0a6b37de..160257a049b 100644 --- a/chromium/services/content/navigable_contents_impl.h +++ b/chromium/services/content/navigable_contents_impl.h @@ -15,7 +15,6 @@ #include "ui/gfx/native_widget_types.h" namespace views { -class NativeViewHost; class RemoteViewProvider; } @@ -38,7 +37,7 @@ class NavigableContentsImpl : public mojom::NavigableContents { private: // mojom::NavigableContents: - void Navigate(const GURL& url) override; + void Navigate(const GURL& url, mojom::NavigateParamsPtr params) override; void CreateView(bool in_service_process, CreateViewCallback callback) override; @@ -62,13 +61,6 @@ class NavigableContentsImpl : public mojom::NavigableContents { std::unique_ptr<views::RemoteViewProvider> remote_view_provider_; #endif -#if defined(TOOLKIT_VIEWS) - // Used to support local view embedding in cases where remote embedding is - // not supported and the client controlling this NavigableContents is running - // within the same process as the Content Service. - std::unique_ptr<views::NativeViewHost> local_view_host_; -#endif - base::WeakPtrFactory<NavigableContentsImpl> weak_ptr_factory_{this}; DISALLOW_COPY_AND_ASSIGN(NavigableContentsImpl); diff --git a/chromium/services/content/public/cpp/BUILD.gn b/chromium/services/content/public/cpp/BUILD.gn index 171af0cbef9..66cd4a1d50e 100644 --- a/chromium/services/content/public/cpp/BUILD.gn +++ b/chromium/services/content/public/cpp/BUILD.gn @@ -14,6 +14,7 @@ component("cpp") { public = [ "navigable_contents.h", + "navigable_contents_observer.h", "navigable_contents_view.h", ] @@ -22,12 +23,17 @@ component("cpp") { "navigable_contents_view.cc", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + defines = [ "IS_CONTENT_SERVICE_CPP_IMPL" ] public_deps = [ ":buildflags", "//base", + "//net", "//services/content/public/mojom", + "//ui/gfx:native_widget_types", + "//ui/gfx/geometry", "//url", ] @@ -45,4 +51,8 @@ component("cpp") { ] } } + + if (use_aura) { + deps += [ "//ui/aura" ] + } } diff --git a/chromium/services/content/public/cpp/DEPS b/chromium/services/content/public/cpp/DEPS index 51a5fb7e88f..8640a7517e0 100644 --- a/chromium/services/content/public/cpp/DEPS +++ b/chromium/services/content/public/cpp/DEPS @@ -1,5 +1,8 @@ include_rules = [ + "+net/http/http_response_headers.h", + "+services/ws/public", + "+ui/aura", "+ui/views", ] diff --git a/chromium/services/content/public/cpp/navigable_contents.cc b/chromium/services/content/public/cpp/navigable_contents.cc index 3f1c8dd1ced..cfce9ca9c89 100644 --- a/chromium/services/content/public/cpp/navigable_contents.cc +++ b/chromium/services/content/public/cpp/navigable_contents.cc @@ -10,15 +10,27 @@ namespace content { NavigableContents::NavigableContents(mojom::NavigableContentsFactory* factory) + : NavigableContents(factory, mojom::NavigableContentsParams::New()) {} + +NavigableContents::NavigableContents(mojom::NavigableContentsFactory* factory, + mojom::NavigableContentsParamsPtr params) : client_binding_(this) { mojom::NavigableContentsClientPtr client; client_binding_.Bind(mojo::MakeRequest(&client)); - factory->CreateContents(mojom::NavigableContentsParams::New(), - mojo::MakeRequest(&contents_), std::move(client)); + factory->CreateContents(std::move(params), mojo::MakeRequest(&contents_), + std::move(client)); } NavigableContents::~NavigableContents() = default; +void NavigableContents::AddObserver(NavigableContentsObserver* observer) { + observers_.AddObserver(observer); +} + +void NavigableContents::RemoveObserver(NavigableContentsObserver* observer) { + observers_.RemoveObserver(observer); +} + NavigableContentsView* NavigableContents::GetView() { if (!view_) { view_ = base::WrapUnique(new NavigableContentsView); @@ -31,12 +43,33 @@ NavigableContentsView* NavigableContents::GetView() { } void NavigableContents::Navigate(const GURL& url) { - contents_->Navigate(url); + NavigateWithParams(url, mojom::NavigateParams::New()); +} + +void NavigableContents::NavigateWithParams(const GURL& url, + mojom::NavigateParamsPtr params) { + contents_->Navigate(url, std::move(params)); +} + +void NavigableContents::DidFinishNavigation( + const GURL& url, + bool is_main_frame, + bool is_error_page, + const scoped_refptr<net::HttpResponseHeaders>& response_headers) { + for (auto& observer : observers_) { + observer.DidFinishNavigation(url, is_main_frame, is_error_page, + response_headers.get()); + } } void NavigableContents::DidStopLoading() { - if (did_stop_loading_callback_) - did_stop_loading_callback_.Run(); + for (auto& observer : observers_) + observer.DidStopLoading(); +} + +void NavigableContents::DidAutoResizeView(const gfx::Size& new_size) { + for (auto& observer : observers_) + observer.DidAutoResizeView(new_size); } void NavigableContents::OnEmbedTokenReceived( diff --git a/chromium/services/content/public/cpp/navigable_contents.h b/chromium/services/content/public/cpp/navigable_contents.h index d9b5d6e68b8..7078b2d3a86 100644 --- a/chromium/services/content/public/cpp/navigable_contents.h +++ b/chromium/services/content/public/cpp/navigable_contents.h @@ -10,7 +10,9 @@ #include "base/callback.h" #include "base/component_export.h" #include "base/macros.h" +#include "base/observer_list.h" #include "mojo/public/cpp/bindings/binding.h" +#include "services/content/public/cpp/navigable_contents_observer.h" #include "services/content/public/mojom/navigable_contents.mojom.h" #include "services/content/public/mojom/navigable_contents_factory.mojom.h" @@ -28,8 +30,14 @@ class COMPONENT_EXPORT(CONTENT_SERVICE_CPP) NavigableContents public: // Constructs a new NavigableContents using |factory|. explicit NavigableContents(mojom::NavigableContentsFactory* factory); + NavigableContents(mojom::NavigableContentsFactory* factory, + mojom::NavigableContentsParamsPtr params); ~NavigableContents() override; + // These methods NavigableContentsObservers registered on this object. + void AddObserver(NavigableContentsObserver* observer); + void RemoveObserver(NavigableContentsObserver* observer); + // Returns a NavigableContentsView which renders this NavigableContents's // currently navigated contents. This widget can be parented and displayed // anywhere within the application's own window tree. @@ -42,15 +50,17 @@ class COMPONENT_EXPORT(CONTENT_SERVICE_CPP) NavigableContents // Begins an attempt to asynchronously navigate this NavigableContents to // |url|. void Navigate(const GURL& url); - - void set_did_stop_loading_callback_for_testing( - base::RepeatingClosure callback) { - did_stop_loading_callback_ = std::move(callback); - } + void NavigateWithParams(const GURL& url, mojom::NavigateParamsPtr params); private: // mojom::NavigableContentsClient: + void DidFinishNavigation( + const GURL& url, + bool is_main_frame, + bool is_error_page, + const scoped_refptr<net::HttpResponseHeaders>& response_headers) override; void DidStopLoading() override; + void DidAutoResizeView(const gfx::Size& new_size) override; void OnEmbedTokenReceived(const base::UnguessableToken& token); @@ -58,7 +68,7 @@ class COMPONENT_EXPORT(CONTENT_SERVICE_CPP) NavigableContents mojo::Binding<mojom::NavigableContentsClient> client_binding_; std::unique_ptr<NavigableContentsView> view_; - base::RepeatingClosure did_stop_loading_callback_; + base::ReentrantObserverList<NavigableContentsObserver> observers_; DISALLOW_COPY_AND_ASSIGN(NavigableContents); }; diff --git a/chromium/services/content/public/cpp/navigable_contents_observer.h b/chromium/services/content/public/cpp/navigable_contents_observer.h new file mode 100644 index 00000000000..c863de441b6 --- /dev/null +++ b/chromium/services/content/public/cpp/navigable_contents_observer.h @@ -0,0 +1,30 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_CONTENT_PUBLIC_CPP_NAVIGABLE_CONTENTS_OBSERVER_H_ +#define SERVICES_CONTENT_PUBLIC_CPP_NAVIGABLE_CONTENTS_OBSERVER_H_ + +#include "base/component_export.h" +#include "base/observer_list_types.h" +#include "net/http/http_response_headers.h" +#include "ui/gfx/geometry/size.h" +#include "url/gurl.h" + +namespace content { + +class COMPONENT_EXPORT(CONTENT_SERVICE_CPP) NavigableContentsObserver + : public base::CheckedObserver { + public: + virtual void DidFinishNavigation( + const GURL& url, + bool is_main_frame, + bool is_error_page, + const net::HttpResponseHeaders* response_headers) {} + virtual void DidStopLoading() {} + virtual void DidAutoResizeView(const gfx::Size& new_size) {} +}; + +} // namespace content + +#endif // SERVICES_CONTENT_PUBLIC_CPP_NAVIGABLE_CONTENTS_OBSERVER_H_ diff --git a/chromium/services/content/public/cpp/navigable_contents_view.cc b/chromium/services/content/public/cpp/navigable_contents_view.cc index fd5c1f6d651..44596695d09 100644 --- a/chromium/services/content/public/cpp/navigable_contents_view.cc +++ b/chromium/services/content/public/cpp/navigable_contents_view.cc @@ -19,10 +19,16 @@ #if BUILDFLAG(ENABLE_REMOTE_NAVIGABLE_CONTENTS_VIEW) #include "services/ws/public/mojom/window_tree_constants.mojom.h" // nogncheck #include "ui/base/ui_base_features.h" // nogncheck +#include "ui/views/controls/native/native_view_host.h" // nogncheck #include "ui/views/mus/remote_view/remote_view_host.h" // nogncheck #endif // BUILDFLAG(ENABLE_REMOTE_NAVIGABLE_CONTENTS_VIEW) #endif // defined(TOOLKIT_VIEWS) +#if defined(USE_AURA) +#include "ui/aura/layout_manager.h" // nogncheck +#include "ui/aura/window.h" // nogncheck +#endif + namespace content { namespace { @@ -41,6 +47,71 @@ base::AtomicFlag& GetInServiceProcessFlag() { return *in_service_process; } +#if BUILDFLAG(ENABLE_REMOTE_NAVIGABLE_CONTENTS_VIEW) + +std::unique_ptr<NavigableContentsView::RemoteViewManager>& +GetRemoteViewManager() { + static base::NoDestructor< + std::unique_ptr<NavigableContentsView::RemoteViewManager>> + manager; + return *manager; +} + +#endif // BUILDFLAG(ENABLE_REMOTE_NAVIGABLE_CONTENTS_VIEW) + +#if defined(TOOLKIT_VIEWS) && defined(USE_AURA) + +// Keeps child windows sized to the same bounds as the owning window. +class LocalWindowLayoutManager : public aura::LayoutManager { + public: + explicit LocalWindowLayoutManager(aura::Window* owner) : owner_(owner) {} + ~LocalWindowLayoutManager() override = default; + + // aura::LayoutManger: + void OnWindowResized() override { ResizeChildren(); } + void OnWindowAddedToLayout(aura::Window* child) override { ResizeChildren(); } + void OnWillRemoveWindowFromLayout(aura::Window* child) override {} + void OnWindowRemovedFromLayout(aura::Window* child) override {} + void OnChildWindowVisibilityChanged(aura::Window* child, + bool visible) override {} + void SetChildBounds(aura::Window* child, + const gfx::Rect& requested_bounds) override {} + + private: + void ResizeChildren() { + for (auto* child : owner_->children()) + SetChildBoundsDirect(child, owner_->bounds()); + } + + aura::Window* const owner_; + + DISALLOW_COPY_AND_ASSIGN(LocalWindowLayoutManager); +}; + +// Owns an Aura window which parents another Aura window in the same process, +// corresponding to a web contents view hosted in the process. +class LocalViewHost : public views::NativeViewHost { + public: + explicit LocalViewHost(aura::Window* window) : window_(window) { + window_->SetLayoutManager(new LocalWindowLayoutManager(window_)); + } + + ~LocalViewHost() override = default; + + // views::View: + void AddedToWidget() override { + if (!native_view()) + Attach(window_); + } + + private: + aura::Window* const window_; + + DISALLOW_COPY_AND_ASSIGN(LocalViewHost); +}; + +#endif // defined(TOOLKIT_VIEWS) && defined(USE_AURA) + } // namespace NavigableContentsView::~NavigableContentsView() = default; @@ -56,29 +127,57 @@ bool NavigableContentsView::IsClientRunningInServiceProcess() { } NavigableContentsView::NavigableContentsView() { -#if defined(TOOLKIT_VIEWS) - view_ = std::make_unique<views::View>(); - view_->set_owned_by_client(); - view_->SetLayoutManager(std::make_unique<views::FillLayout>()); +#if defined(TOOLKIT_VIEWS) && defined(USE_AURA) #if BUILDFLAG(ENABLE_REMOTE_NAVIGABLE_CONTENTS_VIEW) if (!IsClientRunningInServiceProcess()) { - DCHECK(!remote_view_host_); - remote_view_host_ = new views::RemoteViewHost; - view_->AddChildView(remote_view_host_); + RemoteViewManager* manager = GetRemoteViewManager().get(); + if (manager) + view_ = manager->CreateRemoteViewHost(); + else + view_ = std::make_unique<views::RemoteViewHost>(); + view_->set_owned_by_client(); + return; } #endif // BUILDFLAG(ENABLE_REMOTE_NAVIGABLE_CONTENTS_VIEW) -#endif // defined(TOOLKIT_VIEWS) + + window_ = std::make_unique<aura::Window>(nullptr); + window_->set_owned_by_parent(false); + window_->SetName("NavigableContentsViewWindow"); + window_->SetType(aura::client::WINDOW_TYPE_CONTROL); + window_->Init(ui::LAYER_NOT_DRAWN); + window_->Show(); + + view_ = std::make_unique<LocalViewHost>(window_.get()); + view_->set_owned_by_client(); +#endif // defined(TOOLKIT_VIEWS) && defined(USE_AURA) } +#if BUILDFLAG(ENABLE_REMOTE_NAVIGABLE_CONTENTS_VIEW) + +// static +void NavigableContentsView::SetRemoteViewManager( + std::unique_ptr<RemoteViewManager> manager) { + GetRemoteViewManager() = std::move(manager); +} + +#endif // BUILDFLAG(ENABLE_REMOTE_NAVIGABLE_CONTENTS_VIEW) + void NavigableContentsView::EmbedUsingToken( const base::UnguessableToken& token) { #if defined(TOOLKIT_VIEWS) #if BUILDFLAG(ENABLE_REMOTE_NAVIGABLE_CONTENTS_VIEW) - if (remote_view_host_) { - const uint32_t kEmbedFlags = - ws::mojom::kEmbedFlagEmbedderInterceptsEvents | - ws::mojom::kEmbedFlagEmbedderControlsVisibility; - remote_view_host_->EmbedUsingToken(token, kEmbedFlags, base::DoNothing()); + if (!IsClientRunningInServiceProcess()) { + RemoteViewManager* manager = GetRemoteViewManager().get(); + if (manager) { + manager->EmbedUsingToken(view_.get(), token); + } else { + constexpr uint32_t kEmbedFlags = + ws::mojom::kEmbedFlagEmbedderInterceptsEvents | + ws::mojom::kEmbedFlagEmbedderControlsVisibility; + static_cast<views::RemoteViewHost*>(view_.get()) + ->EmbedUsingToken(token, kEmbedFlags, base::DoNothing()); + } + return; } #endif // BUILDFLAG(ENABLE_REMOTE_NAVIGABLE_CONTENTS_VIEW) @@ -94,6 +193,9 @@ void NavigableContentsView::EmbedUsingToken( return; } + // Invoke a callback provided by the Content Service's host environment. This + // should parent a web content view to our own |view()|, as well as set + // |native_view_| to the corresponding web contents' own NativeView. auto callback = std::move(it->second); embeddings.erase(it); std::move(callback).Run(this); diff --git a/chromium/services/content/public/cpp/navigable_contents_view.h b/chromium/services/content/public/cpp/navigable_contents_view.h index fd1d7573990..4b1dab9777a 100644 --- a/chromium/services/content/public/cpp/navigable_contents_view.h +++ b/chromium/services/content/public/cpp/navigable_contents_view.h @@ -7,13 +7,23 @@ #include <memory> +#include "base/callback.h" #include "base/component_export.h" #include "base/unguessable_token.h" #include "services/content/public/cpp/buildflags.h" +#include "ui/gfx/native_widget_types.h" + +#if defined(TOOLKIT_VIEWS) && defined(USE_AURA) +#include "ui/views/controls/native/native_view_host.h" // nogncheck +#endif + +namespace aura { +class Window; +} namespace views { -class RemoteViewHost; class View; +class NativeViewHost; } // namespace views namespace content { @@ -26,9 +36,32 @@ class NavigableContentsImpl; // either Views, UIKit, AppKit, or the Android Framework. // // TODO(https://crbug.com/855092): Actually support UI frameworks other than -// Views UI. +// Views UI on Aura. class COMPONENT_EXPORT(CONTENT_SERVICE_CPP) NavigableContentsView { public: +#if BUILDFLAG(ENABLE_REMOTE_NAVIGABLE_CONTENTS_VIEW) + // May be used if the Content Service client is running within a process whose + // UI environment requires a different remote View implementation from + // the default one. For example, on Chrome OS when Ash and the Window Service + // are running in the same process, the default implementation + // (views::RemoteViewHost) will not work. + class RemoteViewManager { + public: + virtual ~RemoteViewManager() {} + + // Creates a new NativeViewHost suitable for remote embedding. + virtual std::unique_ptr<views::NativeViewHost> CreateRemoteViewHost() = 0; + + // Initiates an embedding of a remote client -- identified by |token| -- + // within |view_host|. Note that |view_host| is always an object returned by + // |CreateRemoteViewHost()| on the same RemoteViewManager. + virtual void EmbedUsingToken(views::NativeViewHost* view_host, + const base::UnguessableToken& token) = 0; + }; + + static void SetRemoteViewManager(std::unique_ptr<RemoteViewManager> manager); +#endif + ~NavigableContentsView(); // Used to set/query whether the calling process is the same process in which @@ -39,9 +72,11 @@ class COMPONENT_EXPORT(CONTENT_SERVICE_CPP) NavigableContentsView { static void SetClientRunningInServiceProcess(); static bool IsClientRunningInServiceProcess(); -#if defined(TOOLKIT_VIEWS) +#if defined(TOOLKIT_VIEWS) && defined(USE_AURA) views::View* view() const { return view_.get(); } -#endif + + gfx::NativeView native_view() const { return view_->native_view(); } +#endif // defined(TOOLKIT_VIEWS) && defined(USE_AURA) private: friend class NavigableContents; @@ -59,15 +94,11 @@ class COMPONENT_EXPORT(CONTENT_SERVICE_CPP) NavigableContentsView { const base::UnguessableToken& token, base::OnceCallback<void(NavigableContentsView*)> callback); -#if defined(TOOLKIT_VIEWS) - // This NavigableContents's View. Only initialized if |GetView()| is called, - // and only on platforms which support Views UI. - std::unique_ptr<views::View> view_; - -#if BUILDFLAG(ENABLE_REMOTE_NAVIGABLE_CONTENTS_VIEW) - views::RemoteViewHost* remote_view_host_ = nullptr; -#endif -#endif // BUILDFLAG(TOOLKIT_VIEWS) +#if defined(TOOLKIT_VIEWS) && defined(USE_AURA) + // This NavigableContents's Window and corresponding View. + std::unique_ptr<aura::Window> window_; + std::unique_ptr<views::NativeViewHost> view_; +#endif // defined(TOOLKIT_VIEWS) && defined(USE_AURA) DISALLOW_COPY_AND_ASSIGN(NavigableContentsView); }; diff --git a/chromium/services/content/public/mojom/BUILD.gn b/chromium/services/content/public/mojom/BUILD.gn index 830829538cb..58b7e0dabd6 100644 --- a/chromium/services/content/public/mojom/BUILD.gn +++ b/chromium/services/content/public/mojom/BUILD.gn @@ -19,6 +19,8 @@ mojom_component("mojom") { public_deps = [ "//mojo/public/mojom/base", + "//services/network/public/mojom:websocket_mojom", + "//ui/gfx/geometry/mojo", "//url/mojom:url_mojom_gurl", ] diff --git a/chromium/services/content/public/mojom/navigable_contents.mojom b/chromium/services/content/public/mojom/navigable_contents.mojom index 620280eec57..2797df25f11 100644 --- a/chromium/services/content/public/mojom/navigable_contents.mojom +++ b/chromium/services/content/public/mojom/navigable_contents.mojom @@ -5,15 +5,25 @@ module content.mojom; import "mojo/public/mojom/base/unguessable_token.mojom"; +import "services/network/public/mojom/network_param.mojom"; +import "ui/gfx/geometry/mojo/geometry.mojom"; import "url/mojom/url.mojom"; +// Parameters used to configure the behavior of |NavigableContents.Navigate|. +struct NavigateParams { + // Indicates that upon successful navigation, the session history should be + // cleared, resulting in the navigated page being the first and only entry in + // the session's history. + bool should_clear_session_history = false; +}; + // The primary interface an application uses to drive a top-level, navigable // content object. Typically this would correspond to e.g. a browser tab, but // it is not strictly necessary that the contents have any graphical presence // within the client application. interface NavigableContents { // Initiates a navigation to |url|. - Navigate(url.mojom.Url url); + Navigate(url.mojom.Url url, NavigateParams params); // Creates a visual representation of the navigated contents, which is // maintained by the Content Service. Responds with a |embed_token| which can @@ -34,7 +44,20 @@ interface NavigableContents { // A client interface used by the Content Service to push contents-scoped events // back to the application. interface NavigableContentsClient { + // Notifies the client that a navigation has finished. + DidFinishNavigation(url.mojom.Url url, + bool is_main_frame, + bool is_error_page, + network.mojom.HttpResponseHeaders? response_headers); + // Notifies the client that the NavigableContents has stopped loading // resources pertaining to a prior navigation request. DidStopLoading(); + + // Indicates that the navigated contents changed in such a way as to elicit + // automatic resizing of the containing view. Only fired if + // |NavigableContentsParams.enable_view_auto_resize| was set to |true| when + // creating the corresponding NavigableContents. The client may use this as a + // signal to, e.g., resize a UI element containing the content view. + DidAutoResizeView(gfx.mojom.Size new_size); }; diff --git a/chromium/services/content/public/mojom/navigable_contents_factory.mojom b/chromium/services/content/public/mojom/navigable_contents_factory.mojom index b79c633d040..acb42101eea 100644 --- a/chromium/services/content/public/mojom/navigable_contents_factory.mojom +++ b/chromium/services/content/public/mojom/navigable_contents_factory.mojom @@ -7,7 +7,12 @@ module content.mojom; import "services/content/public/mojom/navigable_contents.mojom"; // Parameters used to configure a newly created NavigableContents. -struct NavigableContentsParams {}; +struct NavigableContentsParams { + // Enables auto-resizing of any view created for this NavigableContents. If + // |true|, the corresponding NavigableContentsClient will receive + // |DidAutoResizeView()| notifications whenever such resizing happens. + bool enable_view_auto_resize = false; +}; // NavigableContentsFactory is the primary interface through which a new // NavigableContents interface is bound to a new concrete navigable contents diff --git a/chromium/services/content/service_delegate.h b/chromium/services/content/service_delegate.h index 650a2e1b6be..9749bf27218 100644 --- a/chromium/services/content/service_delegate.h +++ b/chromium/services/content/service_delegate.h @@ -6,6 +6,7 @@ #define SERVICES_CONTENT_SERVICE_DELEGATE_H_ #include "services/content/public/mojom/navigable_contents.mojom.h" +#include "services/content/public/mojom/navigable_contents_factory.mojom.h" namespace content { @@ -33,7 +34,8 @@ class ServiceDelegate { // |client| is a NavigableContentsClient interface the implementation can use // to communicate with the client of this contents. virtual std::unique_ptr<NavigableContentsDelegate> - CreateNavigableContentsDelegate(mojom::NavigableContentsClient* client) = 0; + CreateNavigableContentsDelegate(const mojom::NavigableContentsParams& params, + mojom::NavigableContentsClient* client) = 0; }; }; // namespace content diff --git a/chromium/services/content/service_unittest.cc b/chromium/services/content/service_unittest.cc index d0e08dac12e..a97e3608cde 100644 --- a/chromium/services/content/service_unittest.cc +++ b/chromium/services/content/service_unittest.cc @@ -30,7 +30,13 @@ class TestNavigableContentsClient : public mojom::NavigableContentsClient { private: // mojom::NavigableContentsClient: + void DidFinishNavigation(const GURL& url, + bool is_main_frame, + bool is_error_page, + const scoped_refptr<net::HttpResponseHeaders>& + response_headers) override {} void DidStopLoading() override {} + void DidAutoResizeView(const gfx::Size& new_size) override {} DISALLOW_COPY_AND_ASSIGN(TestNavigableContentsClient); }; @@ -47,7 +53,7 @@ class TestNavigableContentsDelegate : public NavigableContentsDelegate { } // NavigableContentsDelegate: - void Navigate(const GURL& url) override { + void Navigate(const GURL& url, mojom::NavigateParamsPtr params) override { last_navigated_url_ = url; if (navigation_callback_) navigation_callback_.Run(); @@ -76,6 +82,7 @@ class TestServiceDelegate : public ServiceDelegate { void WillDestroyServiceInstance(Service* service) override {} std::unique_ptr<NavigableContentsDelegate> CreateNavigableContentsDelegate( + const mojom::NavigableContentsParams& params, mojom::NavigableContentsClient* client) override { auto delegate = std::make_unique<TestNavigableContentsDelegate>(); if (navigable_contents_delegate_created_callback_) @@ -148,7 +155,7 @@ TEST_F(ContentServiceTest, NavigableContentsCreation) { navigation_loop.QuitClosure()); const GURL kTestUrl("https://example.com/"); - contents->Navigate(kTestUrl); + contents->Navigate(kTestUrl, mojom::NavigateParams::New()); navigation_loop.Run(); EXPECT_EQ(kTestUrl, navigable_contents_delegate->last_navigated_url()); diff --git a/chromium/services/data_decoder/BUILD.gn b/chromium/services/data_decoder/BUILD.gn index c1ed0e8ce4f..152e514b2a8 100644 --- a/chromium/services/data_decoder/BUILD.gn +++ b/chromium/services/data_decoder/BUILD.gn @@ -18,6 +18,8 @@ source_set("lib") { "xml_parser.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ "//base", "//mojo/public/cpp/bindings", diff --git a/chromium/services/data_decoder/OWNERS b/chromium/services/data_decoder/OWNERS index e86a4042186..804dc4b1eaa 100644 --- a/chromium/services/data_decoder/OWNERS +++ b/chromium/services/data_decoder/OWNERS @@ -1,6 +1,5 @@ per-file manifest.json=set noparent per-file manifest.json=file://ipc/SECURITY_OWNERS -bauerb@chromium.org jcivelli@chromium.org rsesek@chromium.org diff --git a/chromium/services/data_decoder/public/cpp/BUILD.gn b/chromium/services/data_decoder/public/cpp/BUILD.gn index fd01152275e..c4377b3fe27 100644 --- a/chromium/services/data_decoder/public/cpp/BUILD.gn +++ b/chromium/services/data_decoder/public/cpp/BUILD.gn @@ -19,6 +19,8 @@ source_set("cpp") { "safe_xml_parser.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + public_deps = [ "//services/data_decoder/public/mojom", "//services/service_manager/public/cpp", diff --git a/chromium/services/data_decoder/public/cpp/android/java/src/org/chromium/services/data_decoder/JsonSanitizer.java b/chromium/services/data_decoder/public/cpp/android/java/src/org/chromium/services/data_decoder/JsonSanitizer.java new file mode 100644 index 00000000000..b3c4bfaa15f --- /dev/null +++ b/chromium/services/data_decoder/public/cpp/android/java/src/org/chromium/services/data_decoder/JsonSanitizer.java @@ -0,0 +1,195 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.services.data_decoder; + +import android.util.JsonReader; +import android.util.JsonToken; +import android.util.JsonWriter; +import android.util.MalformedJsonException; + +import org.chromium.base.StreamUtil; +import org.chromium.base.annotations.CalledByNative; +import org.chromium.base.annotations.JNINamespace; + +import java.io.IOException; +import java.io.StringReader; +import java.io.StringWriter; + +/** + * Sanitizes and normalizes a JSON string by parsing it, checking for wellformedness, and + * serializing it again. This class is meant to be used from native code. + */ +@JNINamespace("data_decoder") +public class JsonSanitizer { + // Disallow instantiating the class. + private JsonSanitizer() {} + + /** + * The maximum nesting depth to which the native JSON parser restricts input in order to avoid + * stack overflows. + */ + private static final int MAX_NESTING_DEPTH = 200; + + /** + * Validates input JSON string and returns the sanitized version of the string that's safe to + * parse. + * + * @param unsafeJson The input string to validate and sanitize. + * @return The sanitized version of the input string. + */ + public static String sanitize(String unsafeJson) throws IOException, IllegalStateException { + JsonReader reader = new JsonReader(new StringReader(unsafeJson)); + StringWriter stringWriter = new StringWriter(unsafeJson.length()); + JsonWriter writer = new JsonWriter(stringWriter); + StackChecker stackChecker = new StackChecker(); + String result = null; + try { + boolean end = false; + while (!end) { + JsonToken token = reader.peek(); + switch (token) { + case BEGIN_ARRAY: + stackChecker.increaseAndCheck(); + reader.beginArray(); + writer.beginArray(); + break; + case END_ARRAY: + stackChecker.decrease(); + reader.endArray(); + writer.endArray(); + break; + case BEGIN_OBJECT: + stackChecker.increaseAndCheck(); + reader.beginObject(); + writer.beginObject(); + break; + case END_OBJECT: + stackChecker.decrease(); + reader.endObject(); + writer.endObject(); + break; + case NAME: + writer.name(sanitizeString(reader.nextName())); + break; + case STRING: + writer.value(sanitizeString(reader.nextString())); + break; + case NUMBER: { + // Read the value as a string, then try to parse it first as a long, then as + // a double. + String value = reader.nextString(); + try { + writer.value(Long.parseLong(value)); + } catch (NumberFormatException e) { + writer.value(Double.parseDouble(value)); + } + break; + } + case BOOLEAN: + writer.value(reader.nextBoolean()); + break; + case NULL: + reader.nextNull(); + writer.nullValue(); + break; + case END_DOCUMENT: + end = true; + break; + default: + assert false : token; + } + } + result = stringWriter.toString(); + } finally { + StreamUtil.closeQuietly(reader); + StreamUtil.closeQuietly(writer); + } + return result; + } + + @CalledByNative + public static void sanitize(long nativePtr, String unsafeJson) { + String result = null; + try { + result = sanitize(unsafeJson); + } catch (IOException | IllegalStateException e) { + nativeOnError(nativePtr, e.getMessage()); + return; + } + nativeOnSuccess(nativePtr, result); + } + + /** + * Helper class to check nesting depth of JSON expressions. + */ + private static class StackChecker { + private int mStackDepth; + + public void increaseAndCheck() { + if (++mStackDepth >= MAX_NESTING_DEPTH) { + throw new IllegalStateException("Too much nesting"); + } + } + + public void decrease() { + mStackDepth--; + } + } + + private static String sanitizeString(String string) throws MalformedJsonException { + if (!checkString(string)) { + throw new MalformedJsonException("Invalid escape sequence"); + } + return string; + } + + /** + * Checks whether a given String is well-formed UTF-16, i.e. all surrogates appear in high-low + * pairs and each code point is a valid character. + * + * @param string The string to check. + * @return Whether the given string is well-formed UTF-16. + */ + private static boolean checkString(String string) { + int length = string.length(); + for (int i = 0; i < length; i++) { + char c = string.charAt(i); + // Check that surrogates only appear in pairs of a high surrogate followed by a low + // surrogate. + // A lone low surrogate is not allowed. + if (Character.isLowSurrogate(c)) return false; + + int codePoint; + if (Character.isHighSurrogate(c)) { + // A high surrogate has to be followed by a low surrogate. + char high = c; + if (++i >= length) return false; + + char low = string.charAt(i); + if (!Character.isLowSurrogate(low)) return false; + + // Decode the high-low pair into a code point. + codePoint = Character.toCodePoint(high, low); + } else { + // The code point is neither a low surrogate nor a high surrogate, so we just need + // to check that it's a valid character. + codePoint = c; + } + + if (!isUnicodeCharacter(codePoint)) return false; + } + return true; + } + + private static boolean isUnicodeCharacter(int codePoint) { + // See the native method base::IsValidCharacter(). + return codePoint < 0xD800 || (codePoint >= 0xE000 && codePoint < 0xFDD0) + || (codePoint > 0xFDEF && codePoint <= 0x10FFFF && (codePoint & 0xFFFE) != 0xFFFE); + } + + private static native void nativeOnSuccess(long id, String json); + + private static native void nativeOnError(long id, String error); +} diff --git a/chromium/services/device/BUILD.gn b/chromium/services/device/BUILD.gn index 83dce91ec81..38c9e43e45f 100644 --- a/chromium/services/device/BUILD.gn +++ b/chromium/services/device/BUILD.gn @@ -25,8 +25,12 @@ source_set("lib") { "device_service.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ "//base", + "//device/usb/mojo", + "//device/usb/public/mojom", "//services/device/fingerprint", "//services/device/generic_sensor", "//services/device/geolocation", @@ -51,7 +55,10 @@ source_set("lib") { } if (is_chromeos && use_dbus) { - deps += [ "//services/device/media_transfer_protocol" ] + deps += [ + "//services/device/bluetooth:bluetooth_system", + "//services/device/media_transfer_protocol", + ] } if (is_serial_enabled_platform) { @@ -140,6 +147,7 @@ source_set("tests") { deps += [ "//chromeos", "//dbus", + "//services/device/bluetooth:bluetooth_system_tests", "//services/device/fingerprint", "//third_party/protobuf:protobuf_lite", ] diff --git a/chromium/services/device/android/java/src/org/chromium/services/device/InterfaceRegistrar.java b/chromium/services/device/android/java/src/org/chromium/services/device/InterfaceRegistrar.java new file mode 100644 index 00000000000..7f447d8d958 --- /dev/null +++ b/chromium/services/device/android/java/src/org/chromium/services/device/InterfaceRegistrar.java @@ -0,0 +1,32 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.services.device; + +import org.chromium.base.annotations.CalledByNative; +import org.chromium.base.annotations.JNINamespace; +import org.chromium.device.battery.BatteryMonitorFactory; +import org.chromium.device.mojom.BatteryMonitor; +import org.chromium.device.mojom.NfcProvider; +import org.chromium.device.mojom.VibrationManager; +import org.chromium.device.nfc.NfcDelegate; +import org.chromium.device.nfc.NfcProviderImpl; +import org.chromium.device.vibration.VibrationManagerImpl; +import org.chromium.mojo.system.impl.CoreImpl; +import org.chromium.services.service_manager.InterfaceRegistry; + +@JNINamespace("device") +class InterfaceRegistrar { + @CalledByNative + static void createInterfaceRegistryForContext( + int nativeHandle, NfcDelegate nfcDelegate) { + // Note: The bindings code manages the lifetime of this object, so it + // is not necessary to hold on to a reference to it explicitly. + InterfaceRegistry registry = InterfaceRegistry.create( + CoreImpl.getInstance().acquireNativeHandle(nativeHandle).toMessagePipeHandle()); + registry.addInterface(BatteryMonitor.MANAGER, new BatteryMonitorFactory()); + registry.addInterface(NfcProvider.MANAGER, new NfcProviderImpl.Factory(nfcDelegate)); + registry.addInterface(VibrationManager.MANAGER, new VibrationManagerImpl.Factory()); + } +} diff --git a/chromium/services/device/battery/android/java/src/org/chromium/device/battery/BatteryMonitorFactory.java b/chromium/services/device/battery/android/java/src/org/chromium/device/battery/BatteryMonitorFactory.java new file mode 100644 index 00000000000..f90a513037f --- /dev/null +++ b/chromium/services/device/battery/android/java/src/org/chromium/device/battery/BatteryMonitorFactory.java @@ -0,0 +1,71 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.battery; + +import org.chromium.base.Log; +import org.chromium.base.ThreadUtils; +import org.chromium.device.battery.BatteryStatusManager.BatteryStatusCallback; +import org.chromium.device.mojom.BatteryMonitor; +import org.chromium.device.mojom.BatteryStatus; +import org.chromium.services.service_manager.InterfaceFactory; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; + +/** + * Factory that creates instances of BatteryMonitor implementations and notifies them about battery + * status changes. + */ +public class BatteryMonitorFactory implements InterfaceFactory<BatteryMonitor> { + private static final String TAG = "BattMonitorFactory"; + + // Backing source of battery information. + private final BatteryStatusManager mManager; + // Monitors currently interested in the battery status notifications. + private final HashSet<BatteryMonitorImpl> mSubscribedMonitors = + new HashSet<BatteryMonitorImpl>(); + + private final BatteryStatusCallback mCallback = new BatteryStatusCallback() { + @Override + public void onBatteryStatusChanged(BatteryStatus batteryStatus) { + ThreadUtils.assertOnUiThread(); + + List<BatteryMonitorImpl> monitors = new ArrayList<>(mSubscribedMonitors); + for (BatteryMonitorImpl monitor : monitors) { + monitor.didChange(batteryStatus); + } + } + }; + + public BatteryMonitorFactory() { + mManager = new BatteryStatusManager(mCallback); + } + + @Override + public BatteryMonitor createImpl() { + ThreadUtils.assertOnUiThread(); + + if (mSubscribedMonitors.isEmpty() && !mManager.start()) { + Log.e(TAG, "BatteryStatusManager failed to start."); + } + // TODO(ppi): record the "BatteryStatus.StartAndroid" histogram here once we have a Java API + // for UMA - http://crbug.com/442300. + + BatteryMonitorImpl monitor = new BatteryMonitorImpl(this); + mSubscribedMonitors.add(monitor); + return monitor; + } + + void unsubscribe(BatteryMonitorImpl monitor) { + ThreadUtils.assertOnUiThread(); + + assert mSubscribedMonitors.contains(monitor); + mSubscribedMonitors.remove(monitor); + if (mSubscribedMonitors.isEmpty()) { + mManager.stop(); + } + } +} diff --git a/chromium/services/device/battery/android/java/src/org/chromium/device/battery/BatteryMonitorImpl.java b/chromium/services/device/battery/android/java/src/org/chromium/device/battery/BatteryMonitorImpl.java new file mode 100644 index 00000000000..7f44e3335a4 --- /dev/null +++ b/chromium/services/device/battery/android/java/src/org/chromium/device/battery/BatteryMonitorImpl.java @@ -0,0 +1,78 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.battery; + +import org.chromium.base.Log; +import org.chromium.device.mojom.BatteryMonitor; +import org.chromium.device.mojom.BatteryStatus; +import org.chromium.mojo.system.MojoException; + +/** + * Android implementation of the battery monitor interface defined in + * services/device/public/mojom/battery_monitor.mojom. + */ +public class BatteryMonitorImpl implements BatteryMonitor { + private static final String TAG = "BatteryMonitorImpl"; + + // Factory that created this instance and notifies it about battery status changes. + private final BatteryMonitorFactory mFactory; + private QueryNextStatusResponse mCallback; + private BatteryStatus mStatus; + private boolean mHasStatusToReport; + private boolean mSubscribed; + + public BatteryMonitorImpl(BatteryMonitorFactory batteryMonitorFactory) { + mFactory = batteryMonitorFactory; + mHasStatusToReport = false; + mSubscribed = true; + } + + private void unsubscribe() { + if (mSubscribed) { + mFactory.unsubscribe(this); + mSubscribed = false; + } + } + + @Override + public void close() { + unsubscribe(); + } + + @Override + public void onConnectionError(MojoException e) { + unsubscribe(); + } + + @Override + public void queryNextStatus(QueryNextStatusResponse callback) { + if (mCallback != null) { + Log.e(TAG, "Overlapped call to queryNextStatus!"); + unsubscribe(); + return; + } + + mCallback = callback; + + if (mHasStatusToReport) { + reportStatus(); + } + } + + void didChange(BatteryStatus batteryStatus) { + mStatus = batteryStatus; + mHasStatusToReport = true; + + if (mCallback != null) { + reportStatus(); + } + } + + void reportStatus() { + mCallback.call(mStatus); + mCallback = null; + mHasStatusToReport = false; + } +} diff --git a/chromium/services/device/battery/android/java/src/org/chromium/device/battery/BatteryStatusManager.java b/chromium/services/device/battery/android/java/src/org/chromium/device/battery/BatteryStatusManager.java new file mode 100644 index 00000000000..36fa24f5e98 --- /dev/null +++ b/chromium/services/device/battery/android/java/src/org/chromium/device/battery/BatteryStatusManager.java @@ -0,0 +1,192 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.battery; + +import android.annotation.TargetApi; +import android.content.BroadcastReceiver; +import android.content.Context; +import android.content.Intent; +import android.content.IntentFilter; +import android.os.BatteryManager; +import android.os.Build; + +import org.chromium.base.ContextUtils; +import org.chromium.base.Log; +import org.chromium.base.VisibleForTesting; +import org.chromium.device.mojom.BatteryStatus; + +import javax.annotation.Nullable; + +/** + * Data source for battery status information. This class registers for battery status notifications + * from the system and calls the callback passed on construction whenever a notification is + * received. + */ +class BatteryStatusManager { + private static final String TAG = "BatteryStatusManager"; + + interface BatteryStatusCallback { + void onBatteryStatusChanged(BatteryStatus batteryStatus); + } + + private final BatteryStatusCallback mCallback; + private final IntentFilter mFilter = new IntentFilter(Intent.ACTION_BATTERY_CHANGED); + private final BroadcastReceiver mReceiver = new BroadcastReceiver() { + @Override + public void onReceive(Context context, Intent intent) { + BatteryStatusManager.this.onReceive(intent); + } + }; + + // This is to workaround a Galaxy Nexus bug, see the comment in the constructor. + private final boolean mIgnoreBatteryPresentState; + + // Only used in L (API level 21) and higher. + private AndroidBatteryManagerWrapper mAndroidBatteryManager; + + private boolean mEnabled; + + @VisibleForTesting + static class AndroidBatteryManagerWrapper { + private final BatteryManager mBatteryManager; + + protected AndroidBatteryManagerWrapper(BatteryManager batteryManager) { + mBatteryManager = batteryManager; + } + + @TargetApi(Build.VERSION_CODES.LOLLIPOP) + public int getIntProperty(int id) { + return mBatteryManager.getIntProperty(id); + } + } + + private BatteryStatusManager(BatteryStatusCallback callback, boolean ignoreBatteryPresentState, + @Nullable AndroidBatteryManagerWrapper batteryManager) { + mCallback = callback; + mIgnoreBatteryPresentState = ignoreBatteryPresentState; + mAndroidBatteryManager = batteryManager; + } + + BatteryStatusManager(BatteryStatusCallback callback) { + // BatteryManager.EXTRA_PRESENT appears to be unreliable on Galaxy Nexus, + // Android 4.2.1, it always reports false. See http://crbug.com/384348. + this(callback, Build.MODEL.equals("Galaxy Nexus"), + Build.VERSION.SDK_INT >= Build.VERSION_CODES.LOLLIPOP + ? new AndroidBatteryManagerWrapper( + (BatteryManager) ContextUtils.getApplicationContext() + .getSystemService(Context.BATTERY_SERVICE)) + : null); + } + + /** + * Creates a BatteryStatusManager without the Galaxy Nexus workaround for consistency in + * testing. + */ + static BatteryStatusManager createBatteryStatusManagerForTesting(Context context, + BatteryStatusCallback callback, @Nullable AndroidBatteryManagerWrapper batteryManager) { + return new BatteryStatusManager(callback, false, batteryManager); + } + + /** + * Starts listening for intents. + * @return True on success. + */ + boolean start() { + if (!mEnabled + && ContextUtils.getApplicationContext().registerReceiver(mReceiver, mFilter) + != null) { + // success + mEnabled = true; + } + return mEnabled; + } + + /** + * Stops listening to intents. + */ + void stop() { + if (mEnabled) { + ContextUtils.getApplicationContext().unregisterReceiver(mReceiver); + mEnabled = false; + } + } + + @VisibleForTesting + void onReceive(Intent intent) { + if (!intent.getAction().equals(Intent.ACTION_BATTERY_CHANGED)) { + Log.e(TAG, "Unexpected intent."); + return; + } + + boolean present = mIgnoreBatteryPresentState + ? true + : intent.getBooleanExtra(BatteryManager.EXTRA_PRESENT, false); + int pluggedStatus = intent.getIntExtra(BatteryManager.EXTRA_PLUGGED, -1); + + if (!present || pluggedStatus == -1) { + // No battery or no plugged status: return default values. + mCallback.onBatteryStatusChanged(new BatteryStatus()); + return; + } + + int current = intent.getIntExtra(BatteryManager.EXTRA_LEVEL, -1); + int max = intent.getIntExtra(BatteryManager.EXTRA_SCALE, -1); + double level = (double) current / (double) max; + if (level < 0 || level > 1) { + // Sanity check, assume default value in this case. + level = 1.0; + } + + // Currently Android (below L) does not provide charging/discharging time, as a work-around + // we could compute it manually based on the evolution of level delta. + // TODO(timvolodine): add proper projection for chargingTime, dischargingTime + // (see crbug.com/401553). + boolean charging = pluggedStatus != 0; + int status = intent.getIntExtra(BatteryManager.EXTRA_STATUS, -1); + boolean batteryFull = status == BatteryManager.BATTERY_STATUS_FULL; + double chargingTimeSeconds = (charging && batteryFull) ? 0 : Double.POSITIVE_INFINITY; + double dischargingTimeSeconds = Double.POSITIVE_INFINITY; + + BatteryStatus batteryStatus = new BatteryStatus(); + batteryStatus.charging = charging; + batteryStatus.chargingTime = chargingTimeSeconds; + batteryStatus.dischargingTime = dischargingTimeSeconds; + batteryStatus.level = level; + + if (mAndroidBatteryManager != null) { + updateBatteryStatusForLollipop(batteryStatus); + } + + mCallback.onBatteryStatusChanged(batteryStatus); + } + + private void updateBatteryStatusForLollipop(BatteryStatus batteryStatus) { + assert mAndroidBatteryManager != null; + + // On Lollipop we can provide a better estimate for chargingTime and dischargingTime. + double remainingCapacityRatio = + mAndroidBatteryManager.getIntProperty(BatteryManager.BATTERY_PROPERTY_CAPACITY) + / 100.0; + double batteryCapacityMicroAh = mAndroidBatteryManager.getIntProperty( + BatteryManager.BATTERY_PROPERTY_CHARGE_COUNTER); + double averageCurrentMicroA = mAndroidBatteryManager.getIntProperty( + BatteryManager.BATTERY_PROPERTY_CURRENT_AVERAGE); + + if (batteryStatus.charging) { + if (batteryStatus.chargingTime == Double.POSITIVE_INFINITY + && averageCurrentMicroA > 0) { + double chargeFromEmptyHours = batteryCapacityMicroAh / averageCurrentMicroA; + batteryStatus.chargingTime = + Math.ceil((1 - remainingCapacityRatio) * chargeFromEmptyHours * 3600.0); + } + } else { + if (averageCurrentMicroA < 0) { + double dischargeFromFullHours = batteryCapacityMicroAh / -averageCurrentMicroA; + batteryStatus.dischargingTime = + Math.floor(remainingCapacityRatio * dischargeFromFullHours * 3600.0); + } + } + } +} diff --git a/chromium/services/device/battery/android/javatests/src/org/chromium/device/battery/BatteryStatusManagerTest.java b/chromium/services/device/battery/android/javatests/src/org/chromium/device/battery/BatteryStatusManagerTest.java new file mode 100644 index 00000000000..03bd067132e --- /dev/null +++ b/chromium/services/device/battery/android/javatests/src/org/chromium/device/battery/BatteryStatusManagerTest.java @@ -0,0 +1,252 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.battery; + +import android.content.Intent; +import android.os.BatteryManager; +import android.os.Build; +import android.support.test.InstrumentationRegistry; +import android.support.test.filters.SmallTest; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.chromium.base.test.BaseJUnit4ClassRunner; +import org.chromium.device.mojom.BatteryStatus; + +/** + * Test suite for BatteryStatusManager. + */ +@RunWith(BaseJUnit4ClassRunner.class) +public class BatteryStatusManagerTest { + // Values reported in the most recent callback from |mManager|. + private boolean mCharging = false; + private double mChargingTime = 0; + private double mDischargingTime = 0; + private double mLevel = 0; + + private BatteryStatusManager.BatteryStatusCallback mCallback = + new BatteryStatusManager.BatteryStatusCallback() { + @Override + public void onBatteryStatusChanged(BatteryStatus batteryStatus) { + mCharging = batteryStatus.charging; + mChargingTime = batteryStatus.chargingTime; + mDischargingTime = batteryStatus.dischargingTime; + mLevel = batteryStatus.level; + } + }; + + private BatteryStatusManager mManager; + + private void verifyValues( + boolean charging, double chargingTime, double dischargingTime, double level) { + Assert.assertEquals(charging, mCharging); + Assert.assertEquals(chargingTime, mChargingTime); + Assert.assertEquals(dischargingTime, mDischargingTime); + Assert.assertEquals(level, mLevel); + } + + private static class FakeAndroidBatteryManager + extends BatteryStatusManager.AndroidBatteryManagerWrapper { + private int mChargeCounter; + private int mCapacity; + private int mAverageCurrent; + + private FakeAndroidBatteryManager() { + super(null); + } + + @Override + public int getIntProperty(int id) { + switch (id) { + case BatteryManager.BATTERY_PROPERTY_CHARGE_COUNTER: + return mChargeCounter; + case BatteryManager.BATTERY_PROPERTY_CAPACITY: + return mCapacity; + case BatteryManager.BATTERY_PROPERTY_CURRENT_AVERAGE: + return mAverageCurrent; + } + Assert.fail(); + return 0; + } + + public FakeAndroidBatteryManager setIntProperty(int id, int value) { + switch (id) { + case BatteryManager.BATTERY_PROPERTY_CHARGE_COUNTER: + mChargeCounter = value; + return this; + case BatteryManager.BATTERY_PROPERTY_CAPACITY: + mCapacity = value; + return this; + case BatteryManager.BATTERY_PROPERTY_CURRENT_AVERAGE: + mAverageCurrent = value; + return this; + } + Assert.fail(); + return this; + } + } + + @Before + public void setUp() throws Exception { + initializeBatteryManager(null); + } + + public void initializeBatteryManager(FakeAndroidBatteryManager managerForTesting) { + mManager = BatteryStatusManager.createBatteryStatusManagerForTesting( + InstrumentationRegistry.getContext(), mCallback, managerForTesting); + } + + @Test + @SmallTest + public void testOnReceiveBatteryNotPluggedIn() { + Intent intent = new Intent(Intent.ACTION_BATTERY_CHANGED); + intent.putExtra(BatteryManager.EXTRA_PRESENT, true); + intent.putExtra(BatteryManager.EXTRA_PLUGGED, 0); + intent.putExtra(BatteryManager.EXTRA_LEVEL, 10); + intent.putExtra(BatteryManager.EXTRA_SCALE, 100); + intent.putExtra(BatteryManager.EXTRA_STATUS, BatteryManager.BATTERY_STATUS_NOT_CHARGING); + + mManager.onReceive(intent); + verifyValues(false, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, 0.1); + } + + @Test + @SmallTest + public void testOnReceiveBatteryPluggedInACCharging() { + Intent intent = new Intent(Intent.ACTION_BATTERY_CHANGED); + intent.putExtra(BatteryManager.EXTRA_PRESENT, true); + intent.putExtra(BatteryManager.EXTRA_PLUGGED, BatteryManager.BATTERY_PLUGGED_AC); + intent.putExtra(BatteryManager.EXTRA_LEVEL, 50); + intent.putExtra(BatteryManager.EXTRA_SCALE, 100); + intent.putExtra(BatteryManager.EXTRA_STATUS, BatteryManager.BATTERY_STATUS_CHARGING); + + mManager.onReceive(intent); + verifyValues(true, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, 0.5); + } + + @Test + @SmallTest + public void testOnReceiveBatteryPluggedInACNotCharging() { + Intent intent = new Intent(Intent.ACTION_BATTERY_CHANGED); + intent.putExtra(BatteryManager.EXTRA_PRESENT, true); + intent.putExtra(BatteryManager.EXTRA_PLUGGED, BatteryManager.BATTERY_PLUGGED_AC); + intent.putExtra(BatteryManager.EXTRA_LEVEL, 50); + intent.putExtra(BatteryManager.EXTRA_SCALE, 100); + intent.putExtra(BatteryManager.EXTRA_STATUS, BatteryManager.BATTERY_STATUS_NOT_CHARGING); + + mManager.onReceive(intent); + verifyValues(true, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY, 0.5); + } + + @Test + @SmallTest + public void testOnReceiveBatteryPluggedInUSBFull() { + Intent intent = new Intent(Intent.ACTION_BATTERY_CHANGED); + intent.putExtra(BatteryManager.EXTRA_PRESENT, true); + intent.putExtra(BatteryManager.EXTRA_PLUGGED, BatteryManager.BATTERY_PLUGGED_USB); + intent.putExtra(BatteryManager.EXTRA_LEVEL, 100); + intent.putExtra(BatteryManager.EXTRA_SCALE, 100); + intent.putExtra(BatteryManager.EXTRA_STATUS, BatteryManager.BATTERY_STATUS_FULL); + + mManager.onReceive(intent); + verifyValues(true, 0, Double.POSITIVE_INFINITY, 1); + } + + @Test + @SmallTest + public void testOnReceiveNoBattery() { + Intent intent = new Intent(Intent.ACTION_BATTERY_CHANGED); + intent.putExtra(BatteryManager.EXTRA_PRESENT, false); + intent.putExtra(BatteryManager.EXTRA_PLUGGED, BatteryManager.BATTERY_PLUGGED_USB); + + mManager.onReceive(intent); + verifyValues(true, 0, Double.POSITIVE_INFINITY, 1); + } + + @Test + @SmallTest + public void testOnReceiveNoPluggedStatus() { + Intent intent = new Intent(Intent.ACTION_BATTERY_CHANGED); + intent.putExtra(BatteryManager.EXTRA_PRESENT, true); + + mManager.onReceive(intent); + verifyValues(true, 0, Double.POSITIVE_INFINITY, 1); + } + + @Test + @SmallTest + public void testStartStopSucceeds() { + Assert.assertTrue(mManager.start()); + mManager.stop(); + } + + @Test + @SmallTest + public void testLollipopChargingTimeEstimate() { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.LOLLIPOP) return; + + Intent intent = new Intent(Intent.ACTION_BATTERY_CHANGED); + intent.putExtra(BatteryManager.EXTRA_PRESENT, true); + intent.putExtra(BatteryManager.EXTRA_PLUGGED, BatteryManager.BATTERY_PLUGGED_USB); + intent.putExtra(BatteryManager.EXTRA_LEVEL, 50); + intent.putExtra(BatteryManager.EXTRA_SCALE, 100); + + initializeBatteryManager( + new FakeAndroidBatteryManager() + .setIntProperty(BatteryManager.BATTERY_PROPERTY_CHARGE_COUNTER, 1000) + .setIntProperty(BatteryManager.BATTERY_PROPERTY_CAPACITY, 50) + .setIntProperty(BatteryManager.BATTERY_PROPERTY_CURRENT_AVERAGE, 100)); + + mManager.onReceive(intent); + verifyValues(true, 0.5 * 10 * 3600, Double.POSITIVE_INFINITY, 0.5); + } + + @Test + @SmallTest + public void testLollipopDischargingTimeEstimate() { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.LOLLIPOP) return; + + Intent intent = new Intent(Intent.ACTION_BATTERY_CHANGED); + intent.putExtra(BatteryManager.EXTRA_PRESENT, true); + intent.putExtra(BatteryManager.EXTRA_PLUGGED, 0); + intent.putExtra(BatteryManager.EXTRA_LEVEL, 60); + intent.putExtra(BatteryManager.EXTRA_SCALE, 100); + intent.putExtra(BatteryManager.EXTRA_STATUS, BatteryManager.BATTERY_STATUS_NOT_CHARGING); + + initializeBatteryManager( + new FakeAndroidBatteryManager() + .setIntProperty(BatteryManager.BATTERY_PROPERTY_CHARGE_COUNTER, 1000) + .setIntProperty(BatteryManager.BATTERY_PROPERTY_CAPACITY, 60) + .setIntProperty(BatteryManager.BATTERY_PROPERTY_CURRENT_AVERAGE, -100)); + + mManager.onReceive(intent); + verifyValues(false, Double.POSITIVE_INFINITY, 0.6 * 10 * 3600, 0.6); + } + + @Test + @SmallTest + public void testLollipopDischargingTimeEstimateRounding() { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.LOLLIPOP) return; + + Intent intent = new Intent(Intent.ACTION_BATTERY_CHANGED); + intent.putExtra(BatteryManager.EXTRA_PRESENT, true); + intent.putExtra(BatteryManager.EXTRA_PLUGGED, 0); + intent.putExtra(BatteryManager.EXTRA_LEVEL, 90); + intent.putExtra(BatteryManager.EXTRA_SCALE, 100); + intent.putExtra(BatteryManager.EXTRA_STATUS, BatteryManager.BATTERY_STATUS_NOT_CHARGING); + + initializeBatteryManager( + new FakeAndroidBatteryManager() + .setIntProperty(BatteryManager.BATTERY_PROPERTY_CHARGE_COUNTER, 1999) + .setIntProperty(BatteryManager.BATTERY_PROPERTY_CAPACITY, 90) + .setIntProperty(BatteryManager.BATTERY_PROPERTY_CURRENT_AVERAGE, -1000)); + + mManager.onReceive(intent); + verifyValues(false, Double.POSITIVE_INFINITY, Math.floor(0.9 * 1.999 * 3600), 0.9); + } +} diff --git a/chromium/services/device/bluetooth/BUILD.gn b/chromium/services/device/bluetooth/BUILD.gn new file mode 100644 index 00000000000..e598c5f51c3 --- /dev/null +++ b/chromium/services/device/bluetooth/BUILD.gn @@ -0,0 +1,46 @@ +# Copyright 2018 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +import("//build/config/features.gni") + +source_set("bluetooth_system") { + visibility = [ + "//services/device:lib", + "//services/device/bluetooth:bluetooth_system_tests", + ] + + sources = [ + "bluetooth_system.cc", + "bluetooth_system.h", + "bluetooth_system_factory.cc", + "bluetooth_system_factory.h", + ] + + public_deps = [ + "//services/device/public/mojom", + ] + + deps = [ + "//base", + "//dbus", + "//device/bluetooth", + ] +} + +source_set("bluetooth_system_tests") { + testonly = true + + sources = [ + "bluetooth_system_unittest.cc", + ] + + deps = [ + ":bluetooth_system", + "//dbus", + "//device/bluetooth", + "//net", + "//services/device:test_support", + "//testing/gtest", + ] +} diff --git a/chromium/services/device/bluetooth/DEPS b/chromium/services/device/bluetooth/DEPS new file mode 100644 index 00000000000..18af175ef38 --- /dev/null +++ b/chromium/services/device/bluetooth/DEPS @@ -0,0 +1,9 @@ +include_rules = [ + "+dbus", +] + +specific_include_rules = { + "bluetooth_system_unittest.cc": [ + "+third_party/cros_system_api/dbus/service_constants.h" + ], +} diff --git a/chromium/services/device/bluetooth/bluetooth_system.cc b/chromium/services/device/bluetooth/bluetooth_system.cc new file mode 100644 index 00000000000..f9464ed5e79 --- /dev/null +++ b/chromium/services/device/bluetooth/bluetooth_system.cc @@ -0,0 +1,110 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/device/bluetooth/bluetooth_system.h" + +#include <memory> +#include <utility> +#include <vector> + +#include "dbus/object_path.h" +#include "device/bluetooth/dbus/bluetooth_adapter_client.h" +#include "device/bluetooth/dbus/bluez_dbus_manager.h" +#include "mojo/public/cpp/bindings/strong_binding.h" + +namespace device { + +void BluetoothSystem::Create(mojom::BluetoothSystemRequest request, + mojom::BluetoothSystemClientPtr client) { + mojo::MakeStrongBinding(std::make_unique<BluetoothSystem>(std::move(client)), + std::move(request)); +} + +BluetoothSystem::BluetoothSystem(mojom::BluetoothSystemClientPtr client) { + client_ptr_ = std::move(client); + GetBluetoothAdapterClient()->AddObserver(this); + + std::vector<dbus::ObjectPath> object_paths = + GetBluetoothAdapterClient()->GetAdapters(); + if (object_paths.empty()) + return; + + active_adapter_ = object_paths[0]; + auto* properties = + GetBluetoothAdapterClient()->GetProperties(active_adapter_.value()); + state_ = properties->powered.value() ? State::kPoweredOn : State::kPoweredOff; +} + +BluetoothSystem::~BluetoothSystem() = default; + +void BluetoothSystem::AdapterAdded(const dbus::ObjectPath& object_path) { + if (active_adapter_) + return; + + active_adapter_ = object_path; + UpdateStateAndNotifyIfNecessary(); +} + +void BluetoothSystem::AdapterRemoved(const dbus::ObjectPath& object_path) { + DCHECK(active_adapter_); + + if (active_adapter_.value() != object_path) + return; + + active_adapter_ = base::nullopt; + + std::vector<dbus::ObjectPath> object_paths = + GetBluetoothAdapterClient()->GetAdapters(); + for (const auto& new_object_path : object_paths) { + // The removed adapter is still included in GetAdapters(). + if (new_object_path == object_path) + continue; + + active_adapter_ = new_object_path; + break; + } + + UpdateStateAndNotifyIfNecessary(); +} + +void BluetoothSystem::AdapterPropertyChanged( + const dbus::ObjectPath& object_path, + const std::string& property_name) { + DCHECK(active_adapter_); + if (active_adapter_.value() != object_path) + return; + + auto* properties = + GetBluetoothAdapterClient()->GetProperties(active_adapter_.value()); + + if (properties->powered.name() == property_name) + UpdateStateAndNotifyIfNecessary(); +} + +void BluetoothSystem::GetState(GetStateCallback callback) { + std::move(callback).Run(state_); +} + +bluez::BluetoothAdapterClient* BluetoothSystem::GetBluetoothAdapterClient() { + // Use AlternateBluetoothAdapterClient to avoid interfering with users of the + // regular BluetoothAdapterClient. + return bluez::BluezDBusManager::Get()->GetAlternateBluetoothAdapterClient(); +} + +void BluetoothSystem::UpdateStateAndNotifyIfNecessary() { + State old_state = state_; + if (active_adapter_) { + auto* properties = + GetBluetoothAdapterClient()->GetProperties(active_adapter_.value()); + state_ = + properties->powered.value() ? State::kPoweredOn : State::kPoweredOff; + } else { + state_ = State::kUnavailable; + } + + if (old_state != state_) + client_ptr_->OnStateChanged(state_); +} + +} // namespace device diff --git a/chromium/services/device/bluetooth/bluetooth_system.h b/chromium/services/device/bluetooth/bluetooth_system.h new file mode 100644 index 00000000000..546ab10e745 --- /dev/null +++ b/chromium/services/device/bluetooth/bluetooth_system.h @@ -0,0 +1,58 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_DEVICE_BLUETOOTH_BLUETOOTH_SYSTEM_H_ +#define SERVICES_DEVICE_BLUETOOTH_BLUETOOTH_SYSTEM_H_ + +#include "base/macros.h" +#include "base/optional.h" +#include "dbus/object_path.h" +#include "device/bluetooth/dbus/bluetooth_adapter_client.h" +#include "services/device/public/mojom/bluetooth_system.mojom.h" + +namespace bluez { +class BluetoothAdapterClient; +} + +namespace device { + +class BluetoothSystem : public mojom::BluetoothSystem, + public bluez::BluetoothAdapterClient::Observer { + public: + static void Create(mojom::BluetoothSystemRequest request, + mojom::BluetoothSystemClientPtr client); + + explicit BluetoothSystem(mojom::BluetoothSystemClientPtr client); + ~BluetoothSystem() override; + + // bluez::BluetoothAdapterClient::Observer + void AdapterAdded(const dbus::ObjectPath& object_path) override; + void AdapterRemoved(const dbus::ObjectPath& object_path) override; + void AdapterPropertyChanged(const dbus::ObjectPath& object_path, + const std::string& property_name) override; + + // mojom::BluetoothSystem + void GetState(GetStateCallback callback) override; + + private: + bluez::BluetoothAdapterClient* GetBluetoothAdapterClient(); + + void UpdateStateAndNotifyIfNecessary(); + + mojom::BluetoothSystemClientPtr client_ptr_; + + // The ObjectPath of the adapter being used. Updated as BT adapters are + // added and removed. nullopt if there is no adapter. + base::Optional<dbus::ObjectPath> active_adapter_; + + // State of |active_adapter_| or kUnavailable if there is no + // |active_adapter_|. + State state_ = State::kUnavailable; + + DISALLOW_COPY_AND_ASSIGN(BluetoothSystem); +}; + +} // namespace device + +#endif // SERVICES_DEVICE_BLUETOOTH_BLUETOOTH_SYSTEM_H_ diff --git a/chromium/services/device/bluetooth/bluetooth_system_factory.cc b/chromium/services/device/bluetooth/bluetooth_system_factory.cc new file mode 100644 index 00000000000..47f731c9730 --- /dev/null +++ b/chromium/services/device/bluetooth/bluetooth_system_factory.cc @@ -0,0 +1,31 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/device/bluetooth/bluetooth_system_factory.h" + +#include <memory> +#include <utility> + +#include "mojo/public/cpp/bindings/strong_binding.h" +#include "services/device/bluetooth/bluetooth_system.h" + +namespace device { + +void BluetoothSystemFactory::CreateFactory( + mojom::BluetoothSystemFactoryRequest request) { + mojo::MakeStrongBinding(std::make_unique<BluetoothSystemFactory>(), + std::move(request)); +} + +BluetoothSystemFactory::BluetoothSystemFactory() = default; + +BluetoothSystemFactory::~BluetoothSystemFactory() = default; + +void BluetoothSystemFactory::Create( + mojom::BluetoothSystemRequest system_request, + mojom::BluetoothSystemClientPtr system_client) { + BluetoothSystem::Create(std::move(system_request), std::move(system_client)); +} + +} // namespace device diff --git a/chromium/services/device/bluetooth/bluetooth_system_factory.h b/chromium/services/device/bluetooth/bluetooth_system_factory.h new file mode 100644 index 00000000000..05c4b8216cc --- /dev/null +++ b/chromium/services/device/bluetooth/bluetooth_system_factory.h @@ -0,0 +1,30 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_DEVICE_BLUETOOTH_BLUETOOTH_SYSTEM_FACTORY_H_ +#define SERVICES_DEVICE_BLUETOOTH_BLUETOOTH_SYSTEM_FACTORY_H_ + +#include "base/macros.h" +#include "services/device/public/mojom/bluetooth_system.mojom.h" + +namespace device { + +class BluetoothSystemFactory : public mojom::BluetoothSystemFactory { + public: + static void CreateFactory(mojom::BluetoothSystemFactoryRequest request); + + BluetoothSystemFactory(); + ~BluetoothSystemFactory() override; + + // mojom::BluetoothSystemFactory + void Create(mojom::BluetoothSystemRequest system_request, + mojom::BluetoothSystemClientPtr system_client) override; + + private: + DISALLOW_COPY_AND_ASSIGN(BluetoothSystemFactory); +}; + +} // namespace device + +#endif // SERVICES_DEVICE_BLUETOOTH_BLUETOOTH_SYSTEM_FACTORY_H_ diff --git a/chromium/services/device/bluetooth/bluetooth_system_unittest.cc b/chromium/services/device/bluetooth/bluetooth_system_unittest.cc new file mode 100644 index 00000000000..c9617d73fd4 --- /dev/null +++ b/chromium/services/device/bluetooth/bluetooth_system_unittest.cc @@ -0,0 +1,670 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/device/bluetooth/bluetooth_system.h" + +#include <utility> + +#include "base/observer_list.h" +#include "base/run_loop.h" +#include "base/strings/stringprintf.h" +#include "device/bluetooth/dbus/bluetooth_adapter_client.h" +#include "device/bluetooth/dbus/bluez_dbus_manager.h" +#include "mojo/public/cpp/bindings/binding.h" +#include "services/device/device_service_test_base.h" +#include "services/device/public/mojom/bluetooth_system.mojom.h" +#include "services/device/public/mojom/constants.mojom.h" +#include "third_party/cros_system_api/dbus/service_constants.h" + +namespace device { + +constexpr const char kFooObjectPathStr[] = "fake/hci0"; +constexpr const char kBarObjectPathStr[] = "fake/hci1"; + +namespace { + +// Exposes high-level methods to simulate Bluetooth events e.g. a new adapter +// was added, adapter power state changed, etc. +// +// As opposed to FakeBluetoothAdapterClient, the other fake implementation of +// BluetoothAdapterClient, this class does not have any built-in behavior +// e.g. it won't start triggering device discovery events when StartDiscovery is +// called. It's up to its users to call the relevant Simulate*() method to +// trigger each event. +class DEVICE_BLUETOOTH_EXPORT TestBluetoothAdapterClient + : public bluez::BluetoothAdapterClient { + public: + struct Properties : public bluez::BluetoothAdapterClient::Properties { + explicit Properties(const PropertyChangedCallback& callback) + : BluetoothAdapterClient::Properties( + nullptr, /* object_proxy */ + bluetooth_adapter::kBluetoothAdapterInterface, + callback) {} + ~Properties() override = default; + + // dbus::PropertySet override + void Get(dbus::PropertyBase* property, + dbus::PropertySet::GetCallback callback) override { + DVLOG(1) << "Get " << property->name(); + NOTIMPLEMENTED(); + } + + void GetAll() override { + DVLOG(1) << "GetAll"; + NOTIMPLEMENTED(); + } + + void Set(dbus::PropertyBase* property, + dbus::PropertySet::SetCallback callback) override { + DVLOG(1) << "Set " << property->name(); + NOTIMPLEMENTED(); + } + }; + + TestBluetoothAdapterClient() = default; + ~TestBluetoothAdapterClient() override = default; + + // Simulates a new adapter with |object_path_str|. Its properties are empty, + // 0, or false. + void SimulateAdapterAdded(const std::string& object_path_str) { + dbus::ObjectPath object_path(object_path_str); + + ObjectPathToProperties::iterator it; + bool was_inserted; + std::tie(it, was_inserted) = adapter_object_paths_to_properties_.emplace( + object_path, + base::BindRepeating(&TestBluetoothAdapterClient::OnPropertyChanged, + base::Unretained(this), object_path)); + DCHECK(was_inserted); + + for (auto& observer : observers_) + observer.AdapterAdded(object_path); + } + + // Simulates the adapter at |object_path_str| being removed. + void SimulateAdapterRemoved(const std::string& object_path_str) { + dbus::ObjectPath object_path(object_path_str); + + // Properties are set to empty, 0, or false right before AdapterRemoved is + // called. + GetProperties(object_path)->powered.ReplaceValue(false); + + // When BlueZ calls into AdapterRemoved, the adapter is still exposed + // through GetAdapters() and its properties are still accessible. + for (auto& observer : observers_) + observer.AdapterRemoved(object_path); + + size_t removed = adapter_object_paths_to_properties_.erase(object_path); + DCHECK_EQ(1u, removed); + } + + // Simulates adapter at |object_path_str| changing its powered state to + // |powered|. + void SimulateAdapterPowerStateChanged(const std::string& object_path_str, + bool powered) { + GetProperties(dbus::ObjectPath(object_path_str)) + ->powered.ReplaceValue(powered); + } + + // BluetoothAdapterClient: + void Init(dbus::Bus* bus, + const std::string& bluetooth_service_name) override {} + + void AddObserver(Observer* observer) override { + observers_.AddObserver(observer); + } + + void RemoveObserver(Observer* observer) override { + observers_.RemoveObserver(observer); + } + + std::vector<dbus::ObjectPath> GetAdapters() override { + std::vector<dbus::ObjectPath> object_paths; + for (const auto& object_path_to_property : + adapter_object_paths_to_properties_) { + object_paths.push_back(object_path_to_property.first); + } + return object_paths; + } + + Properties* GetProperties(const dbus::ObjectPath& object_path) override { + auto it = adapter_object_paths_to_properties_.find(object_path); + if (it == adapter_object_paths_to_properties_.end()) + return nullptr; + return &(it->second); + } + + void StartDiscovery(const dbus::ObjectPath& object_path, + const base::Closure& callback, + ErrorCallback error_callback) override { + NOTIMPLEMENTED(); + } + + void StopDiscovery(const dbus::ObjectPath& object_path, + const base::Closure& callback, + ErrorCallback error_callback) override { + NOTIMPLEMENTED(); + } + + void PauseDiscovery(const dbus::ObjectPath& object_path, + const base::Closure& callback, + ErrorCallback error_callback) override { + NOTIMPLEMENTED(); + } + + void UnpauseDiscovery(const dbus::ObjectPath& object_path, + const base::Closure& callback, + ErrorCallback error_callback) override { + NOTIMPLEMENTED(); + } + + void RemoveDevice(const dbus::ObjectPath& object_path, + const dbus::ObjectPath& device_path, + const base::Closure& callback, + ErrorCallback error_callback) override { + NOTIMPLEMENTED(); + } + + void SetDiscoveryFilter(const dbus::ObjectPath& object_path, + const DiscoveryFilter& discovery_filter, + const base::Closure& callback, + ErrorCallback error_callback) override { + NOTIMPLEMENTED(); + } + + void CreateServiceRecord(const dbus::ObjectPath& object_path, + const bluez::BluetoothServiceRecordBlueZ& record, + const ServiceRecordCallback& callback, + ErrorCallback error_callback) override { + NOTIMPLEMENTED(); + } + + void RemoveServiceRecord(const dbus::ObjectPath& object_path, + uint32_t handle, + const base::Closure& callback, + ErrorCallback error_callback) override { + NOTIMPLEMENTED(); + } + + private: + void OnPropertyChanged(const dbus::ObjectPath& object_path, + const std::string& property_name) { + for (auto& observer : observers_) { + observer.AdapterPropertyChanged(object_path, property_name); + } + } + + using ObjectPathToProperties = std::map<dbus::ObjectPath, Properties>; + ObjectPathToProperties adapter_object_paths_to_properties_; + + base::ObserverList<Observer>::Unchecked observers_; +}; + +} // namespace + +class BluetoothSystemTest : public DeviceServiceTestBase, + public mojom::BluetoothSystemClient { + public: + BluetoothSystemTest() = default; + ~BluetoothSystemTest() override = default; + + void SetUp() override { + DeviceServiceTestBase::SetUp(); + connector()->BindInterface(mojom::kServiceName, &system_factory_); + + auto test_bluetooth_adapter_client = + std::make_unique<TestBluetoothAdapterClient>(); + test_bluetooth_adapter_client_ = test_bluetooth_adapter_client.get(); + + std::unique_ptr<bluez::BluezDBusManagerSetter> dbus_setter = + bluez::BluezDBusManager::GetSetterForTesting(); + dbus_setter->SetAlternateBluetoothAdapterClient( + std::move(test_bluetooth_adapter_client)); + } + + void StateCallback(base::OnceClosure quit_closure, + mojom::BluetoothSystem::State state) { + get_state_result_ = state; + std::move(quit_closure).Run(); + } + + // mojom::BluetoothSystemClient + void OnStateChanged(mojom::BluetoothSystem::State state) override { + on_state_changed_states_.push_back(state); + } + + protected: + mojom::BluetoothSystemPtr CreateBluetoothSystem() { + mojom::BluetoothSystemClientPtr client_ptr; + system_client_binding_.Bind(mojo::MakeRequest(&client_ptr)); + + mojom::BluetoothSystemPtr system_ptr; + system_factory_->Create(mojo::MakeRequest(&system_ptr), + std::move(client_ptr)); + return system_ptr; + } + + void ResetResults() { + get_state_result_.reset(); + on_state_changed_states_.clear(); + } + + // Saves the last state passed to StateCallback. + base::Optional<mojom::BluetoothSystem::State> get_state_result_; + + // Saves the states passed to OnStateChanged. + using StateVector = std::vector<mojom::BluetoothSystem::State>; + StateVector on_state_changed_states_; + + mojom::BluetoothSystemFactoryPtr system_factory_; + + TestBluetoothAdapterClient* test_bluetooth_adapter_client_; + + mojo::Binding<mojom::BluetoothSystemClient> system_client_binding_{this}; + + private: + DISALLOW_COPY_AND_ASSIGN(BluetoothSystemTest); +}; + +// Tests that the Create method for BluetoothSystemFactory works. +TEST_F(BluetoothSystemTest, FactoryCreate) { + mojom::BluetoothSystemPtr system_ptr; + mojo::Binding<mojom::BluetoothSystemClient> client_binding(this); + + mojom::BluetoothSystemClientPtr client_ptr; + client_binding.Bind(mojo::MakeRequest(&client_ptr)); + + EXPECT_FALSE(system_ptr.is_bound()); + + system_factory_->Create(mojo::MakeRequest(&system_ptr), + std::move(client_ptr)); + base::RunLoop run_loop; + system_ptr.FlushAsyncForTesting(run_loop.QuitClosure()); + run_loop.Run(); + + EXPECT_TRUE(system_ptr.is_bound()); +} + +// Tests that the state is 'Unavailable' when there is no Bluetooth adapter +// present. +TEST_F(BluetoothSystemTest, State_NoAdapter) { + auto system = CreateBluetoothSystem(); + + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kUnavailable, + get_state_result_.value()); + EXPECT_TRUE(on_state_changed_states_.empty()); +} + +// Tests that the state is "Off" when the Bluetooth adapter is powered off. +TEST_F(BluetoothSystemTest, State_PoweredOffAdapter) { + test_bluetooth_adapter_client_->SimulateAdapterAdded(kFooObjectPathStr); + // Added adapters are Off by default. + + auto system = CreateBluetoothSystem(); + + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOff, + get_state_result_.value()); + EXPECT_TRUE(on_state_changed_states_.empty()); +} + +// Tests that the state is "On" when the Bluetooth adapter is powered on. +TEST_F(BluetoothSystemTest, State_PoweredOnAdapter) { + test_bluetooth_adapter_client_->SimulateAdapterAdded(kFooObjectPathStr); + test_bluetooth_adapter_client_->SimulateAdapterPowerStateChanged( + kFooObjectPathStr, true); + + auto system = CreateBluetoothSystem(); + + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOn, + get_state_result_.value()); + EXPECT_TRUE(on_state_changed_states_.empty()); +} + +// Tests that the state changes to On when the adapter turns on and then changes +// to Off when the adapter turns off. +TEST_F(BluetoothSystemTest, State_PoweredOnThenOff) { + test_bluetooth_adapter_client_->SimulateAdapterAdded(kFooObjectPathStr); + + auto system = CreateBluetoothSystem(); + + { + // The adapter is initially powered off. + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOff, + get_state_result_.value()); + EXPECT_TRUE(on_state_changed_states_.empty()); + ResetResults(); + } + + { + // Turn adapter on. + test_bluetooth_adapter_client_->SimulateAdapterPowerStateChanged( + kFooObjectPathStr, true); + + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOn, + get_state_result_.value()); + EXPECT_EQ(StateVector({mojom::BluetoothSystem::State::kPoweredOn}), + on_state_changed_states_); + ResetResults(); + } + + { + // Turn adapter off. + test_bluetooth_adapter_client_->SimulateAdapterPowerStateChanged( + kFooObjectPathStr, false); + + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOff, + get_state_result_.value()); + EXPECT_EQ(StateVector({mojom::BluetoothSystem::State::kPoweredOff}), + on_state_changed_states_); + } +} + +// Tests that the state is updated as expected when removing and re-adding the +// same adapter. +TEST_F(BluetoothSystemTest, State_AdapterRemoved) { + test_bluetooth_adapter_client_->SimulateAdapterAdded(kFooObjectPathStr); + test_bluetooth_adapter_client_->SimulateAdapterPowerStateChanged( + kFooObjectPathStr, true); + + auto system = CreateBluetoothSystem(); + + { + // The adapter is initially powered on. + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOn, + get_state_result_.value()); + EXPECT_TRUE(on_state_changed_states_.empty()); + ResetResults(); + } + + { + // Remove the adapter. The state should change to Unavailable. + test_bluetooth_adapter_client_->SimulateAdapterRemoved(kFooObjectPathStr); + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kUnavailable, + get_state_result_.value()); + EXPECT_EQ(StateVector({mojom::BluetoothSystem::State::kPoweredOff, + mojom::BluetoothSystem::State::kUnavailable}), + on_state_changed_states_); + ResetResults(); + } + + { + // Add the adapter again; it's off by default. + test_bluetooth_adapter_client_->SimulateAdapterAdded(kFooObjectPathStr); + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOff, + get_state_result_.value()); + EXPECT_EQ(StateVector({mojom::BluetoothSystem::State::kPoweredOff}), + on_state_changed_states_); + } +} + +// Tests that the state is updated as expected when replacing the adapter with a +// different adapter. +TEST_F(BluetoothSystemTest, State_AdapterReplaced) { + // Start with a powered on adapter. + test_bluetooth_adapter_client_->SimulateAdapterAdded(kFooObjectPathStr); + test_bluetooth_adapter_client_->SimulateAdapterPowerStateChanged( + kFooObjectPathStr, true); + + auto system = CreateBluetoothSystem(); + + { + // The adapter is initially powered on. + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOn, + get_state_result_.value()); + EXPECT_TRUE(on_state_changed_states_.empty()); + ResetResults(); + } + + { + // Remove the adapter. The state should change to Unavailable. + test_bluetooth_adapter_client_->SimulateAdapterRemoved(kFooObjectPathStr); + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kUnavailable, + get_state_result_.value()); + EXPECT_EQ(StateVector({mojom::BluetoothSystem::State::kPoweredOff, + mojom::BluetoothSystem::State::kUnavailable}), + on_state_changed_states_); + ResetResults(); + } + + { + // Add a different adapter. it's off by default. + test_bluetooth_adapter_client_->SimulateAdapterAdded(kBarObjectPathStr); + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOff, + get_state_result_.value()); + EXPECT_EQ(StateVector({mojom::BluetoothSystem::State::kPoweredOff}), + on_state_changed_states_); + } +} + +// Tests that the state is correctly updated when adding and removing multiple +// adapters. +TEST_F(BluetoothSystemTest, State_AddAndRemoveMultipleAdapters) { + // Start with a powered on "foo" adapter. + test_bluetooth_adapter_client_->SimulateAdapterAdded(kFooObjectPathStr); + test_bluetooth_adapter_client_->SimulateAdapterPowerStateChanged( + kFooObjectPathStr, true); + + auto system = CreateBluetoothSystem(); + + { + // The "foo" adapter is initially powered on. + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOn, + get_state_result_.value()); + EXPECT_TRUE(on_state_changed_states_.empty()); + ResetResults(); + } + + { + // Add an extra "bar" adapter. The state should not change. + test_bluetooth_adapter_client_->SimulateAdapterAdded(kBarObjectPathStr); + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOn, + get_state_result_.value()); + EXPECT_TRUE(on_state_changed_states_.empty()); + ResetResults(); + } + + { + // Remove "foo". We should retrieve the state from "bar". + test_bluetooth_adapter_client_->SimulateAdapterRemoved(kFooObjectPathStr); + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOff, + get_state_result_.value()); + EXPECT_EQ(StateVector({mojom::BluetoothSystem::State::kPoweredOff}), + on_state_changed_states_); + ResetResults(); + } + + { + // Change "bar"'s state to On. + test_bluetooth_adapter_client_->SimulateAdapterPowerStateChanged( + kBarObjectPathStr, true); + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOn, + get_state_result_.value()); + EXPECT_EQ(StateVector({mojom::BluetoothSystem::State::kPoweredOn}), + on_state_changed_states_); + ResetResults(); + } + + { + // Add "foo" again. We should still retrieve the state from "bar". + test_bluetooth_adapter_client_->SimulateAdapterAdded(kFooObjectPathStr); + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOn, + get_state_result_.value()); + EXPECT_TRUE(on_state_changed_states_.empty()); + } +} + +// Tests that an extra adapter changing state does not interfer with the state. +TEST_F(BluetoothSystemTest, State_ChangeStateMultipleAdapters) { + // Start with a powered on "foo" adapter. + test_bluetooth_adapter_client_->SimulateAdapterAdded(kFooObjectPathStr); + test_bluetooth_adapter_client_->SimulateAdapterPowerStateChanged( + kFooObjectPathStr, true); + + auto system = CreateBluetoothSystem(); + + { + // The "foo" adapter is initially powered on. + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOn, + get_state_result_.value()); + EXPECT_TRUE(on_state_changed_states_.empty()); + ResetResults(); + } + + { + // Add an extra "bar" adapter. The state should not change. + test_bluetooth_adapter_client_->SimulateAdapterAdded(kBarObjectPathStr); + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOn, + get_state_result_.value()); + EXPECT_TRUE(on_state_changed_states_.empty()); + ResetResults(); + } + + { + // Turn "bar" on. The state should not change. + test_bluetooth_adapter_client_->SimulateAdapterPowerStateChanged( + kBarObjectPathStr, true); + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOn, + get_state_result_.value()); + EXPECT_TRUE(on_state_changed_states_.empty()); + ResetResults(); + } + + { + // Turn "bar" off. The state should not change. + test_bluetooth_adapter_client_->SimulateAdapterPowerStateChanged( + kBarObjectPathStr, false); + base::RunLoop run_loop; + system->GetState(base::BindOnce(&BluetoothSystemTest::StateCallback, + base::Unretained(this), + run_loop.QuitClosure())); + run_loop.Run(); + + EXPECT_EQ(mojom::BluetoothSystem::State::kPoweredOn, + get_state_result_.value()); + EXPECT_TRUE(on_state_changed_states_.empty()); + ResetResults(); + } +} + +} // namespace device diff --git a/chromium/services/device/device_service.cc b/chromium/services/device/device_service.cc index 7be23bcfc44..34416635ab3 100644 --- a/chromium/services/device/device_service.cc +++ b/chromium/services/device/device_service.cc @@ -11,7 +11,9 @@ #include "base/single_thread_task_runner.h" #include "base/threading/thread_task_runner_handle.h" #include "build/build_config.h" +#include "device/usb/mojo/device_manager_impl.h" #include "mojo/public/cpp/system/message_pipe.h" +#include "services/device/bluetooth/bluetooth_system_factory.h" #include "services/device/fingerprint/fingerprint.h" #include "services/device/generic_sensor/sensor_provider_impl.h" #include "services/device/geolocation/geolocation_config.h" @@ -140,6 +142,8 @@ void DeviceService::OnStart() { base::Unretained(this))); registry_.AddInterface<mojom::SerialIoHandler>(base::Bind( &DeviceService::BindSerialIoHandlerRequest, base::Unretained(this))); + registry_.AddInterface<mojom::UsbDeviceManager>(base::Bind( + &DeviceService::BindUsbDeviceManagerRequest, base::Unretained(this))); #if defined(OS_ANDROID) registry_.AddInterface(GetJavaInterfaceProvider() @@ -161,6 +165,9 @@ void DeviceService::OnStart() { #endif #if defined(OS_CHROMEOS) + registry_.AddInterface<mojom::BluetoothSystemFactory>( + base::BindRepeating(&DeviceService::BindBluetoothSystemFactoryRequest, + base::Unretained(this))); registry_.AddInterface<mojom::MtpManager>(base::BindRepeating( &DeviceService::BindMtpManagerRequest, base::Unretained(this))); #endif @@ -202,6 +209,11 @@ void DeviceService::BindVibrationManagerRequest( #endif #if defined(OS_CHROMEOS) +void DeviceService::BindBluetoothSystemFactoryRequest( + mojom::BluetoothSystemFactoryRequest request) { + BluetoothSystemFactory::CreateFactory(std::move(request)); +} + void DeviceService::BindMtpManagerRequest(mojom::MtpManagerRequest request) { if (!mtp_device_manager_) mtp_device_manager_ = MtpDeviceManager::Initialize(); @@ -311,6 +323,14 @@ void DeviceService::BindSerialIoHandlerRequest( #endif } +void DeviceService::BindUsbDeviceManagerRequest( + mojom::UsbDeviceManagerRequest request) { + if (!usb_device_manager_) + usb_device_manager_ = std::make_unique<usb::DeviceManagerImpl>(); + + usb_device_manager_->AddBinding(std::move(request)); +} + #if defined(OS_ANDROID) service_manager::InterfaceProvider* DeviceService::GetJavaInterfaceProvider() { if (!java_interface_provider_initialized_) { diff --git a/chromium/services/device/device_service.h b/chromium/services/device/device_service.h index 2eae9a431de..b7c9d260aeb 100644 --- a/chromium/services/device/device_service.h +++ b/chromium/services/device/device_service.h @@ -10,6 +10,8 @@ #include "base/memory/ref_counted.h" #include "build/build_config.h" +#include "device/usb/mojo/device_manager_impl.h" +#include "device/usb/public/mojom/device_manager.mojom.h" #include "mojo/public/cpp/bindings/binding_set.h" #include "services/device/geolocation/geolocation_provider.h" #include "services/device/geolocation/geolocation_provider_impl.h" @@ -41,6 +43,7 @@ #if defined(OS_CHROMEOS) #include "services/device/media_transfer_protocol/mtp_device_manager.h" +#include "services/device/public/mojom/bluetooth_system.mojom.h" #endif #if defined(OS_LINUX) && defined(USE_UDEV) @@ -130,6 +133,8 @@ class DeviceService : public service_manager::Service { #endif #if defined(OS_CHROMEOS) + void BindBluetoothSystemFactoryRequest( + mojom::BluetoothSystemFactoryRequest request); void BindMtpManagerRequest(mojom::MtpManagerRequest request); #endif @@ -152,11 +157,14 @@ class DeviceService : public service_manager::Service { void BindSerialIoHandlerRequest(mojom::SerialIoHandlerRequest request); + void BindUsbDeviceManagerRequest(mojom::UsbDeviceManagerRequest request); + std::unique_ptr<PowerMonitorMessageBroadcaster> power_monitor_message_broadcaster_; std::unique_ptr<PublicIpAddressGeolocationProvider> public_ip_address_geolocation_provider_; std::unique_ptr<TimeZoneMonitor> time_zone_monitor_; + std::unique_ptr<usb::DeviceManagerImpl> usb_device_manager_; scoped_refptr<base::SingleThreadTaskRunner> file_task_runner_; scoped_refptr<base::SingleThreadTaskRunner> io_task_runner_; scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_; diff --git a/chromium/services/device/fingerprint/fingerprint_chromeos.cc b/chromium/services/device/fingerprint/fingerprint_chromeos.cc index 3d13d43c6b3..1b085c7b8c5 100644 --- a/chromium/services/device/fingerprint/fingerprint_chromeos.cc +++ b/chromium/services/device/fingerprint/fingerprint_chromeos.cc @@ -14,8 +14,6 @@ namespace device { namespace { -constexpr int64_t kFingerprintSessionTimeoutMs = 150; - chromeos::BiodClient* GetBiodClient() { return chromeos::DBusThreadManager::Get()->GetBiodClient(); } @@ -75,13 +73,7 @@ void FingerprintChromeOS::OnCloseAuthSessionForEnroll( if (!result) return; - // TODO(xiaoyinh@): Timeout should be removed after we resolve - // crbug.com/715302. - base::ThreadTaskRunnerHandle::Get()->PostDelayedTask( - FROM_HERE, - base::Bind(&FingerprintChromeOS::ScheduleStartEnroll, - weak_ptr_factory_.GetWeakPtr(), user_id, label), - base::TimeDelta::FromMilliseconds(kFingerprintSessionTimeoutMs)); + ScheduleStartEnroll(user_id, label); } void FingerprintChromeOS::ScheduleStartEnroll(const std::string& user_id, @@ -136,13 +128,7 @@ void FingerprintChromeOS::OnCloseEnrollSessionForAuth(bool result) { if (!result) return; - // TODO(xiaoyinh@): Timeout should be removed after we resolve - // crbug.com/715302. - base::ThreadTaskRunnerHandle::Get()->PostDelayedTask( - FROM_HERE, - base::Bind(&FingerprintChromeOS::ScheduleStartAuth, - weak_ptr_factory_.GetWeakPtr()), - base::TimeDelta::FromMilliseconds(kFingerprintSessionTimeoutMs)); + ScheduleStartAuth(); } void FingerprintChromeOS::ScheduleStartAuth() { diff --git a/chromium/services/device/generic_sensor/OWNERS b/chromium/services/device/generic_sensor/OWNERS index db5ecbc8113..b083d335711 100644 --- a/chromium/services/device/generic_sensor/OWNERS +++ b/chromium/services/device/generic_sensor/OWNERS @@ -1,4 +1,3 @@ -alexander.shalamov@intel.com juncai@chromium.org timvolodine@chromium.org diff --git a/chromium/services/device/generic_sensor/android/java/src/org/chromium/device/sensors/PlatformSensor.java b/chromium/services/device/generic_sensor/android/java/src/org/chromium/device/sensors/PlatformSensor.java new file mode 100644 index 00000000000..85d7c7d058d --- /dev/null +++ b/chromium/services/device/generic_sensor/android/java/src/org/chromium/device/sensors/PlatformSensor.java @@ -0,0 +1,270 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.sensors; + +import android.hardware.Sensor; +import android.hardware.SensorEvent; +import android.hardware.SensorEventListener; +import android.os.Build; + +import org.chromium.base.Log; +import org.chromium.base.annotations.CalledByNative; +import org.chromium.base.annotations.JNINamespace; +import org.chromium.device.mojom.ReportingMode; + +import java.util.List; + +/** + * Implementation of PlatformSensor that uses Android Sensor Framework. Lifetime is controlled by + * the device::PlatformSensorAndroid. + */ +@JNINamespace("device") +public class PlatformSensor implements SensorEventListener { + private static final double MICROSECONDS_IN_SECOND = 1000000; + private static final double SECONDS_IN_MICROSECOND = 0.000001d; + private static final double SECONDS_IN_NANOSECOND = 0.000000001d; + private static final String TAG = "PlatformSensor"; + + /** + * The SENSOR_FREQUENCY_NORMAL is defined as 5Hz which corresponds to a polling delay + * @see android.hardware.SensorManager.SENSOR_DELAY_NORMAL value that is defined as 200000 + * microseconds. + */ + private static final double SENSOR_FREQUENCY_NORMAL = 5.0d; + + /** + * Identifier of device::PlatformSensorAndroid instance. + */ + private long mNativePlatformSensorAndroid; + + /** + * Used for fetching sensor reading values and obtaining information about the sensor. + * @see android.hardware.Sensor + */ + private final Sensor mSensor; + + /** + * The minimum delay between two readings in microseconds that is supported by the sensor. + */ + private final int mMinDelayUsec; + + /** + * The number of sensor reading values required from the sensor. + */ + private final int mReadingCount; + + /** + * Frequncy that is currently used by the sensor for polling. + */ + private double mCurrentPollingFrequency; + + /** + * Provides shared SensorManager and event processing thread Handler to PlatformSensor objects. + */ + private final PlatformSensorProvider mProvider; + + /** + * Creates new PlatformSensor. + * + * @param sensorType type of the sensor to be constructed. @see android.hardware.Sensor.TYPE_* + * @param readingCount number of sensor reading values required from the sensor. + * @param provider object that shares SensorManager and polling thread Handler with sensors. + */ + public static PlatformSensor create( + int sensorType, int readingCount, PlatformSensorProvider provider) { + List<Sensor> sensors = provider.getSensorManager().getSensorList(sensorType); + if (sensors.isEmpty()) return null; + return new PlatformSensor(sensors.get(0), readingCount, provider); + } + + /** + * Constructor. + */ + protected PlatformSensor(Sensor sensor, int readingCount, PlatformSensorProvider provider) { + mReadingCount = readingCount; + mProvider = provider; + mSensor = sensor; + mMinDelayUsec = mSensor.getMinDelay(); + } + + /** + * Initializes PlatformSensor, called by native code. + * + * @param nativePlatformSensorAndroid identifier of device::PlatformSensorAndroid instance. + * @param buffer shared buffer that is used to return data to the client. + */ + @CalledByNative + protected void initPlatformSensorAndroid(long nativePlatformSensorAndroid) { + assert nativePlatformSensorAndroid != 0; + mNativePlatformSensorAndroid = nativePlatformSensorAndroid; + } + + /** + * Returns reporting mode supported by the sensor. + * + * @return ReportingMode reporting mode. + */ + @CalledByNative + protected int getReportingMode() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.LOLLIPOP) { + return mSensor.getReportingMode() == Sensor.REPORTING_MODE_CONTINUOUS + ? ReportingMode.CONTINUOUS + : ReportingMode.ON_CHANGE; + } + return ReportingMode.CONTINUOUS; + } + + /** + * Returns default configuration supported by the sensor. Currently only frequency is supported. + * + * @return double frequency. + */ + @CalledByNative + protected double getDefaultConfiguration() { + return SENSOR_FREQUENCY_NORMAL; + } + + /** + * Returns maximum sampling frequency supported by the sensor. + * + * @return double frequency in Hz. + */ + @CalledByNative + protected double getMaximumSupportedFrequency() { + if (mMinDelayUsec == 0) return getDefaultConfiguration(); + return 1 / (mMinDelayUsec * SECONDS_IN_MICROSECOND); + } + + /** + * Requests sensor to start polling for data. + * + * @return boolean true if successful, false otherwise. + */ + @CalledByNative + protected boolean startSensor(double frequency) { + // If we already polling hw with same frequency, do not restart the sensor. + if (mCurrentPollingFrequency == frequency) return true; + + // Unregister old listener if polling frequency has changed. + unregisterListener(); + + mProvider.sensorStarted(this); + boolean sensorStarted; + try { + sensorStarted = mProvider.getSensorManager().registerListener( + this, mSensor, getSamplingPeriod(frequency), mProvider.getHandler()); + } catch (RuntimeException e) { + // This can fail due to internal framework errors. https://crbug.com/884190 + Log.w(TAG, "Failed to register sensor listener.", e); + sensorStarted = false; + } + + if (!sensorStarted) { + stopSensor(); + return sensorStarted; + } + + mCurrentPollingFrequency = frequency; + return sensorStarted; + } + + private void unregisterListener() { + // Do not unregister if current polling frequency is 0, not polling for data. + if (mCurrentPollingFrequency == 0) return; + mProvider.getSensorManager().unregisterListener(this, mSensor); + } + + /** + * Requests sensor to stop polling for data. + */ + @CalledByNative + protected void stopSensor() { + unregisterListener(); + mProvider.sensorStopped(this); + mCurrentPollingFrequency = 0; + } + + /** + * Checks whether configuration is supported by the sensor. Currently only frequency is + * supported. + * + * @return boolean true if configuration is supported, false otherwise. + */ + @CalledByNative + protected boolean checkSensorConfiguration(double frequency) { + return mMinDelayUsec <= getSamplingPeriod(frequency); + } + + /** + * Called from device::PlatformSensorAndroid destructor, so that this instance would be + * notified not to deliver any updates about new sensor readings or errors. + */ + @CalledByNative + protected void sensorDestroyed() { + stopSensor(); + mNativePlatformSensorAndroid = 0; + } + + /** + * Converts frequency to sampling period in microseconds. + */ + private int getSamplingPeriod(double frequency) { + return (int) ((1 / frequency) * MICROSECONDS_IN_SECOND); + } + + /** + * Notifies native device::PlatformSensorAndroid when there is an error. + */ + protected void sensorError() { + nativeNotifyPlatformSensorError(mNativePlatformSensorAndroid); + } + + /** + * Updates reading at native device::PlatformSensorAndroid. + */ + protected void updateSensorReading( + double timestamp, double value1, double value2, double value3, double value4) { + nativeUpdatePlatformSensorReading( + mNativePlatformSensorAndroid, timestamp, value1, value2, value3, value4); + } + + @Override + public void onAccuracyChanged(Sensor sensor, int accuracy) {} + + @Override + public void onSensorChanged(SensorEvent event) { + if (mNativePlatformSensorAndroid == 0) { + Log.w(TAG, "Should not get sensor events after PlatformSensorAndroid is destroyed."); + return; + } + + if (event.values.length < mReadingCount) { + sensorError(); + stopSensor(); + return; + } + + double timestamp = event.timestamp * SECONDS_IN_NANOSECOND; + switch (event.values.length) { + case 1: + updateSensorReading(timestamp, event.values[0], 0.0, 0.0, 0.0); + break; + case 2: + updateSensorReading(timestamp, event.values[0], event.values[1], 0.0, 0.0); + break; + case 3: + updateSensorReading( + timestamp, event.values[0], event.values[1], event.values[2], 0.0); + break; + default: + updateSensorReading(timestamp, event.values[0], event.values[1], event.values[2], + event.values[3]); + } + } + + private native void nativeNotifyPlatformSensorError(long nativePlatformSensorAndroid); + private native void nativeUpdatePlatformSensorReading(long nativePlatformSensorAndroid, + double timestamp, double value1, double value2, double value3, double value4); +} diff --git a/chromium/services/device/generic_sensor/android/java/src/org/chromium/device/sensors/PlatformSensorProvider.java b/chromium/services/device/generic_sensor/android/java/src/org/chromium/device/sensors/PlatformSensorProvider.java new file mode 100644 index 00000000000..cd730e14d50 --- /dev/null +++ b/chromium/services/device/generic_sensor/android/java/src/org/chromium/device/sensors/PlatformSensorProvider.java @@ -0,0 +1,223 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.sensors; + +import android.content.Context; +import android.hardware.Sensor; +import android.hardware.SensorManager; +import android.os.Build; +import android.os.Handler; +import android.os.HandlerThread; + +import org.chromium.base.ContextUtils; +import org.chromium.base.annotations.CalledByNative; +import org.chromium.base.annotations.JNINamespace; +import org.chromium.device.mojom.SensorType; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * Lifetime is controlled by device::PlatformSensorProviderAndroid. + */ +@JNINamespace("device") +class PlatformSensorProvider { + /** + * SensorManager that is shared among PlatformSensor objects. It is used for Sensor object + * creation and @see android.hardware.SensorEventListener registration. + * @see android.hardware.SensorManager + */ + private SensorManager mSensorManager; + + /** + * Thread that is handling all sensor events. + */ + private HandlerThread mSensorsThread; + + /** + * Processes messages on #mSensorsThread message queue. Provided to #mSensorManager when + * sensor should start polling for data. + */ + private Handler mHandler; + + /** + * Set of currently active PlatformSensor objects. + */ + private final Set<PlatformSensor> mActiveSensors = new HashSet<PlatformSensor>(); + + /** + * Returns shared thread Handler. + * + * @return Handler thread handler. + */ + public Handler getHandler() { + return mHandler; + } + + /** + * Returns shared SensorManager. + * + * @return SensorManager sensor manager. + */ + public SensorManager getSensorManager() { + return mSensorManager; + } + + /** + * Notifies PlatformSensorProvider that sensor started polling for data. Adds sensor to + * a set of active sensors, creates and starts new thread if needed. + */ + public void sensorStarted(PlatformSensor sensor) { + synchronized (mActiveSensors) { + if (mActiveSensors.isEmpty()) startSensorThread(); + mActiveSensors.add(sensor); + } + } + + /** + * Notifies PlatformSensorProvider that sensor is no longer polling for data. When + * #mActiveSensors becomes empty thread is stopped. + */ + public void sensorStopped(PlatformSensor sensor) { + synchronized (mActiveSensors) { + mActiveSensors.remove(sensor); + if (mActiveSensors.isEmpty()) stopSensorThread(); + } + } + + /** + * Starts sensor handler thread. + */ + protected void startSensorThread() { + if (mSensorsThread == null) { + mSensorsThread = new HandlerThread("SensorsHandlerThread"); + mSensorsThread.start(); + mHandler = new Handler(mSensorsThread.getLooper()); + } + } + + /** + * Stops sensor handler thread. + */ + protected void stopSensorThread() { + if (mSensorsThread != null) { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.JELLY_BEAN_MR2) { + mSensorsThread.quitSafely(); + } else { + mSensorsThread.quit(); + } + mSensorsThread = null; + mHandler = null; + } + } + + /** + * Constructor. + */ + protected PlatformSensorProvider(Context context) { + mSensorManager = (SensorManager) context.getSystemService(Context.SENSOR_SERVICE); + } + + /** + * Creates PlatformSensorProvider instance. + * + * @return PlatformSensorProvider new PlatformSensorProvider instance. + */ + protected static PlatformSensorProvider createForTest(Context context) { + return new PlatformSensorProvider(context); + } + + /** + * Creates PlatformSensorProvider instance. + * + * @return PlatformSensorProvider new PlatformSensorProvider instance. + */ + @CalledByNative + protected static PlatformSensorProvider create() { + return new PlatformSensorProvider(ContextUtils.getApplicationContext()); + } + + /** + * Sets |mSensorManager| to null for testing purposes. + */ + @CalledByNative + protected void setSensorManagerToNullForTesting() { + mSensorManager = null; + } + + /** + * Checks if |type| sensor is available. + * + * @param type type of a sensor. + * @return If |type| sensor is available, returns true; otherwise returns false. + */ + @CalledByNative + protected boolean hasSensorType(int type) { + if (mSensorManager == null) return false; + + // Type of the sensor to be constructed. @see android.hardware.Sensor.TYPE_* + int sensorType; + + switch (type) { + case SensorType.AMBIENT_LIGHT: + sensorType = Sensor.TYPE_LIGHT; + break; + case SensorType.ACCELEROMETER: + sensorType = Sensor.TYPE_ACCELEROMETER; + break; + case SensorType.LINEAR_ACCELERATION: + sensorType = Sensor.TYPE_LINEAR_ACCELERATION; + break; + case SensorType.GYROSCOPE: + sensorType = Sensor.TYPE_GYROSCOPE; + break; + case SensorType.MAGNETOMETER: + sensorType = Sensor.TYPE_MAGNETIC_FIELD; + break; + case SensorType.ABSOLUTE_ORIENTATION_QUATERNION: + sensorType = Sensor.TYPE_ROTATION_VECTOR; + break; + case SensorType.RELATIVE_ORIENTATION_QUATERNION: + sensorType = Sensor.TYPE_GAME_ROTATION_VECTOR; + break; + default: + return false; + } + + List<Sensor> sensors = mSensorManager.getSensorList(sensorType); + return !sensors.isEmpty(); + } + + /** + * Creates PlatformSensor instance. + * + * @param type type of a sensor. + * @return PlatformSensor new PlatformSensor instance or null if sensor cannot be created. + */ + @CalledByNative + protected PlatformSensor createSensor(int type) { + if (mSensorManager == null) return null; + + switch (type) { + case SensorType.AMBIENT_LIGHT: + return PlatformSensor.create(Sensor.TYPE_LIGHT, 1, this); + case SensorType.ACCELEROMETER: + return PlatformSensor.create(Sensor.TYPE_ACCELEROMETER, 3, this); + case SensorType.LINEAR_ACCELERATION: + return PlatformSensor.create(Sensor.TYPE_LINEAR_ACCELERATION, 3, this); + case SensorType.GYROSCOPE: + return PlatformSensor.create(Sensor.TYPE_GYROSCOPE, 3, this); + case SensorType.MAGNETOMETER: + return PlatformSensor.create(Sensor.TYPE_MAGNETIC_FIELD, 3, this); + case SensorType.ABSOLUTE_ORIENTATION_QUATERNION: + return PlatformSensor.create(Sensor.TYPE_ROTATION_VECTOR, 4, this); + case SensorType.RELATIVE_ORIENTATION_QUATERNION: + return PlatformSensor.create(Sensor.TYPE_GAME_ROTATION_VECTOR, 4, this); + default: + return null; + } + } +} diff --git a/chromium/services/device/generic_sensor/android/junit/src/org/chromium/device/sensors/PlatformSensorAndProviderTest.java b/chromium/services/device/generic_sensor/android/junit/src/org/chromium/device/sensors/PlatformSensorAndProviderTest.java new file mode 100644 index 00000000000..f70c7a3d772 --- /dev/null +++ b/chromium/services/device/generic_sensor/android/junit/src/org/chromium/device/sensors/PlatformSensorAndProviderTest.java @@ -0,0 +1,464 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.sensors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import android.content.Context; +import android.hardware.Sensor; +import android.hardware.SensorEvent; +import android.hardware.SensorEventListener; +import android.hardware.SensorManager; +import android.os.Handler; +import android.util.SparseArray; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.robolectric.annotation.Config; + +import org.chromium.base.test.BaseRobolectricTestRunner; +import org.chromium.base.test.util.Feature; +import org.chromium.device.mojom.ReportingMode; +import org.chromium.device.mojom.SensorType; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.util.ArrayList; +import java.util.List; + +/** + * Unit tests for PlatformSensor and PlatformSensorProvider. + */ +@RunWith(BaseRobolectricTestRunner.class) +@Config(manifest = Config.NONE) +public class PlatformSensorAndProviderTest { + @Mock + private Context mContext; + @Mock + private SensorManager mSensorManager; + @Mock + private PlatformSensorProvider mPlatformSensorProvider; + private final SparseArray<List<Sensor>> mMockSensors = new SparseArray<>(); + private static final long PLATFORM_SENSOR_ANDROID = 123456789L; + private static final long PLATFORM_SENSOR_TIMESTAMP = 314159265358979L; + private static final double SECONDS_IN_NANOSECOND = 0.000000001d; + + /** + * Class that overrides thread management callbacks for testing purposes. + */ + private static class TestPlatformSensorProvider extends PlatformSensorProvider { + public TestPlatformSensorProvider(Context context) { + super(context); + } + + @Override + public Handler getHandler() { + return new Handler(); + } + + @Override + protected void startSensorThread() {} + + @Override + protected void stopSensorThread() {} + } + + /** + * Class that overrides native callbacks for testing purposes. + */ + private static class TestPlatformSensor extends PlatformSensor { + public TestPlatformSensor( + Sensor sensor, int readingCount, PlatformSensorProvider provider) { + super(sensor, readingCount, provider); + } + + @Override + protected void updateSensorReading( + double timestamp, double value1, double value2, double value3, double value4) {} + @Override + protected void sensorError() {} + } + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + // Remove all mock sensors before the test. + mMockSensors.clear(); + doReturn(mSensorManager).when(mContext).getSystemService(Context.SENSOR_SERVICE); + doAnswer(new Answer<List<Sensor>>() { + @Override + public List<Sensor> answer(final InvocationOnMock invocation) throws Throwable { + return getMockSensors((int) (Integer) (invocation.getArguments())[0]); + } + }) + .when(mSensorManager) + .getSensorList(anyInt()); + doReturn(mSensorManager).when(mPlatformSensorProvider).getSensorManager(); + doReturn(new Handler()).when(mPlatformSensorProvider).getHandler(); + // By default, allow successful registration of SensorEventListeners. + doReturn(true) + .when(mSensorManager) + .registerListener(any(SensorEventListener.class), any(Sensor.class), anyInt(), + any(Handler.class)); + } + + /** + * Test that PlatformSensorProvider cannot create sensors if sensor manager is null. + */ + @Test + @Feature({"PlatformSensorProvider"}) + public void testNullSensorManager() { + doReturn(null).when(mContext).getSystemService(Context.SENSOR_SERVICE); + PlatformSensorProvider provider = PlatformSensorProvider.createForTest(mContext); + PlatformSensor sensor = provider.createSensor(SensorType.AMBIENT_LIGHT); + assertNull(sensor); + } + + /** + * Test that PlatformSensorProvider cannot create sensors that are not supported. + */ + @Test + @Feature({"PlatformSensorProvider"}) + public void testSensorNotSupported() { + PlatformSensorProvider provider = PlatformSensorProvider.createForTest(mContext); + PlatformSensor sensor = provider.createSensor(SensorType.AMBIENT_LIGHT); + assertNull(sensor); + } + + /** + * Test that PlatformSensorProvider maps device::SensorType to android.hardware.Sensor.TYPE_*. + */ + @Test + @Feature({"PlatformSensorProvider"}) + public void testSensorTypeMappings() { + PlatformSensorProvider provider = PlatformSensorProvider.createForTest(mContext); + provider.createSensor(SensorType.AMBIENT_LIGHT); + verify(mSensorManager).getSensorList(Sensor.TYPE_LIGHT); + provider.createSensor(SensorType.ACCELEROMETER); + verify(mSensorManager).getSensorList(Sensor.TYPE_ACCELEROMETER); + provider.createSensor(SensorType.LINEAR_ACCELERATION); + verify(mSensorManager).getSensorList(Sensor.TYPE_LINEAR_ACCELERATION); + provider.createSensor(SensorType.GYROSCOPE); + verify(mSensorManager).getSensorList(Sensor.TYPE_GYROSCOPE); + provider.createSensor(SensorType.MAGNETOMETER); + verify(mSensorManager).getSensorList(Sensor.TYPE_MAGNETIC_FIELD); + provider.createSensor(SensorType.ABSOLUTE_ORIENTATION_QUATERNION); + verify(mSensorManager).getSensorList(Sensor.TYPE_ROTATION_VECTOR); + provider.createSensor(SensorType.RELATIVE_ORIENTATION_QUATERNION); + verify(mSensorManager).getSensorList(Sensor.TYPE_GAME_ROTATION_VECTOR); + } + + /** + * Test that PlatformSensorProvider can create sensors that are supported. + */ + @Test + @Feature({"PlatformSensorProvider"}) + public void testSensorSupported() { + PlatformSensor sensor = createPlatformSensor(50000, Sensor.TYPE_LIGHT, + SensorType.AMBIENT_LIGHT, Sensor.REPORTING_MODE_ON_CHANGE); + assertNotNull(sensor); + } + + /** + * Test that PlatformSensor notifies PlatformSensorProvider when it starts (stops) polling, + * and SensorEventListener is registered (unregistered) to sensor manager. + */ + @Test + @Feature({"PlatformSensor"}) + public void testSensorStartStop() { + addMockSensor(50000, Sensor.TYPE_ACCELEROMETER, Sensor.REPORTING_MODE_CONTINUOUS); + PlatformSensor sensor = + PlatformSensor.create(Sensor.TYPE_ACCELEROMETER, 3, mPlatformSensorProvider); + assertNotNull(sensor); + + sensor.startSensor(5); + sensor.stopSensor(); + + // Multiple start invocations. + sensor.startSensor(1); + sensor.startSensor(2); + sensor.startSensor(3); + // Same frequency, should not restart sensor + sensor.startSensor(3); + + // Started polling with 5, 1, 2 and 3 Hz frequency. + verify(mPlatformSensorProvider, times(4)).getHandler(); + verify(mPlatformSensorProvider, times(4)).sensorStarted(sensor); + verify(mSensorManager, times(4)) + .registerListener(any(SensorEventListener.class), any(Sensor.class), anyInt(), + any(Handler.class)); + + sensor.stopSensor(); + sensor.stopSensor(); + verify(mPlatformSensorProvider, times(3)).sensorStopped(sensor); + verify(mSensorManager, times(4)) + .unregisterListener(any(SensorEventListener.class), any(Sensor.class)); + } + + /** + * Test that PlatformSensorProvider is notified when PlatformSensor starts and in case of + * failure, tells PlatformSensorProvider that the sensor is stopped, so that polling thread + * can be stopped. + */ + @Test + @Feature({"PlatformSensor"}) + public void testSensorStartFails() { + addMockSensor(50000, Sensor.TYPE_ACCELEROMETER, Sensor.REPORTING_MODE_CONTINUOUS); + PlatformSensor sensor = + PlatformSensor.create(Sensor.TYPE_ACCELEROMETER, 3, mPlatformSensorProvider); + assertNotNull(sensor); + + doReturn(false) + .when(mSensorManager) + .registerListener(any(SensorEventListener.class), any(Sensor.class), anyInt(), + any(Handler.class)); + + sensor.startSensor(5); + verify(mPlatformSensorProvider, times(1)).sensorStarted(sensor); + verify(mPlatformSensorProvider, times(1)).sensorStopped(sensor); + verify(mPlatformSensorProvider, times(1)).getHandler(); + } + + /** + * Same as the above except instead of a clean failure an exception is thrown. + */ + @Test + @Feature({"PlatformSensor"}) + public void testSensorStartFailsWithException() { + addMockSensor(50000, Sensor.TYPE_ACCELEROMETER, Sensor.REPORTING_MODE_CONTINUOUS); + PlatformSensor sensor = + PlatformSensor.create(Sensor.TYPE_ACCELEROMETER, 3, mPlatformSensorProvider); + assertNotNull(sensor); + + when(mSensorManager.registerListener(any(SensorEventListener.class), any(Sensor.class), + anyInt(), any(Handler.class))) + .thenThrow(RuntimeException.class); + + sensor.startSensor(5); + verify(mPlatformSensorProvider, times(1)).sensorStarted(sensor); + verify(mPlatformSensorProvider, times(1)).sensorStopped(sensor); + verify(mPlatformSensorProvider, times(1)).getHandler(); + } + + /** + * Test that PlatformSensor correctly checks supported configuration. + */ + @Test + @Feature({"PlatformSensor"}) + public void testSensorConfiguration() { + // 5Hz min delay + PlatformSensor sensor = createPlatformSensor(200000, Sensor.TYPE_ACCELEROMETER, + SensorType.ACCELEROMETER, Sensor.REPORTING_MODE_CONTINUOUS); + assertTrue(sensor.checkSensorConfiguration(5)); + assertFalse(sensor.checkSensorConfiguration(6)); + } + + /** + * Test that PlatformSensor correctly returns its reporting mode. + */ + @Test + @Feature({"PlatformSensor"}) + public void testSensorOnChangeReportingMode() { + PlatformSensor sensor = createPlatformSensor(50000, Sensor.TYPE_LIGHT, + SensorType.AMBIENT_LIGHT, Sensor.REPORTING_MODE_ON_CHANGE); + assertEquals(ReportingMode.ON_CHANGE, sensor.getReportingMode()); + } + + /** + * Test that PlatformSensor correctly returns its maximum supported frequency. + */ + @Test + @Feature({"PlatformSensor"}) + public void testSensorMaximumSupportedFrequency() { + PlatformSensor sensor = createPlatformSensor(50000, Sensor.TYPE_LIGHT, + SensorType.AMBIENT_LIGHT, Sensor.REPORTING_MODE_ON_CHANGE); + assertEquals(20, sensor.getMaximumSupportedFrequency(), 0.001); + + sensor = createPlatformSensor( + 0, Sensor.TYPE_LIGHT, SensorType.AMBIENT_LIGHT, Sensor.REPORTING_MODE_ON_CHANGE); + assertEquals( + sensor.getDefaultConfiguration(), sensor.getMaximumSupportedFrequency(), 0.001); + } + + /** + * Test that shared buffer is correctly populated from SensorEvent. + */ + @Test + @Feature({"PlatformSensor"}) + public void testSensorReadingFromEvent() { + TestPlatformSensor sensor = createTestPlatformSensor( + 50000, Sensor.TYPE_LIGHT, 1, Sensor.REPORTING_MODE_ON_CHANGE); + initPlatformSensor(sensor); + TestPlatformSensor spySensor = spy(sensor); + SensorEvent event = createFakeEvent(1); + assertNotNull(event); + spySensor.onSensorChanged(event); + + double timestamp = PLATFORM_SENSOR_TIMESTAMP * SECONDS_IN_NANOSECOND; + + verify(spySensor, times(1)) + .updateSensorReading(timestamp, getFakeReadingValue(1), 0.0, 0.0, 0.0); + } + + /** + * Test that shared buffer is correctly populated from SensorEvent for sensors with more + * than one value. + */ + @Test + @Feature({"PlatformSensor"}) + public void testSensorReadingFromEventMoreValues() { + TestPlatformSensor sensor = createTestPlatformSensor( + 50000, Sensor.TYPE_ROTATION_VECTOR, 4, Sensor.REPORTING_MODE_ON_CHANGE); + initPlatformSensor(sensor); + TestPlatformSensor spySensor = spy(sensor); + SensorEvent event = createFakeEvent(4); + assertNotNull(event); + spySensor.onSensorChanged(event); + + double timestamp = PLATFORM_SENSOR_TIMESTAMP * SECONDS_IN_NANOSECOND; + + verify(spySensor, times(1)) + .updateSensorReading(timestamp, getFakeReadingValue(1), getFakeReadingValue(2), + getFakeReadingValue(3), getFakeReadingValue(4)); + } + + /** + * Test that PlatformSensor notifies client when there is an error. + */ + @Test + @Feature({"PlatformSensor"}) + public void testSensorInvalidReadingSize() { + TestPlatformSensor sensor = createTestPlatformSensor( + 50000, Sensor.TYPE_ACCELEROMETER, 3, Sensor.REPORTING_MODE_CONTINUOUS); + initPlatformSensor(sensor); + TestPlatformSensor spySensor = spy(sensor); + // Accelerometer requires 3 reading values x,y and z, create fake event with 1 reading + // value. + SensorEvent event = createFakeEvent(1); + assertNotNull(event); + spySensor.onSensorChanged(event); + verify(spySensor, times(1)).sensorError(); + } + + /** + * Test that multiple PlatformSensor instances correctly register (unregister) to + * sensor manager and notify PlatformSensorProvider when they start (stop) polling for data. + */ + @Test + @Feature({"PlatformSensor"}) + public void testMultipleSensorTypeInstances() { + addMockSensor(200000, Sensor.TYPE_LIGHT, Sensor.REPORTING_MODE_ON_CHANGE); + addMockSensor(50000, Sensor.TYPE_ACCELEROMETER, Sensor.REPORTING_MODE_CONTINUOUS); + + TestPlatformSensorProvider spyProvider = spy(new TestPlatformSensorProvider(mContext)); + PlatformSensor lightSensor = PlatformSensor.create(Sensor.TYPE_LIGHT, 1, spyProvider); + assertNotNull(lightSensor); + + PlatformSensor accelerometerSensor = + PlatformSensor.create(Sensor.TYPE_ACCELEROMETER, 3, spyProvider); + assertNotNull(accelerometerSensor); + + lightSensor.startSensor(3); + accelerometerSensor.startSensor(10); + lightSensor.stopSensor(); + accelerometerSensor.stopSensor(); + + verify(spyProvider, times(2)).getHandler(); + verify(spyProvider, times(1)).sensorStarted(lightSensor); + verify(spyProvider, times(1)).sensorStarted(accelerometerSensor); + verify(spyProvider, times(1)).sensorStopped(lightSensor); + verify(spyProvider, times(1)).sensorStopped(accelerometerSensor); + verify(spyProvider, times(1)).startSensorThread(); + verify(spyProvider, times(1)).stopSensorThread(); + verify(mSensorManager, times(2)) + .registerListener(any(SensorEventListener.class), any(Sensor.class), anyInt(), + any(Handler.class)); + verify(mSensorManager, times(2)) + .unregisterListener(any(SensorEventListener.class), any(Sensor.class)); + } + + /** + * Creates fake event. The SensorEvent constructor is not accessible outside android.hardware + * package, therefore, java reflection is used to make constructor accessible to construct + * SensorEvent instance. + */ + private SensorEvent createFakeEvent(int readingValuesNum) { + try { + Constructor<SensorEvent> sensorEventConstructor = + SensorEvent.class.getDeclaredConstructor(Integer.TYPE); + sensorEventConstructor.setAccessible(true); + SensorEvent event = sensorEventConstructor.newInstance(readingValuesNum); + event.timestamp = PLATFORM_SENSOR_TIMESTAMP; + for (int i = 0; i < readingValuesNum; ++i) { + event.values[i] = getFakeReadingValue(i + 1); + } + return event; + } catch (InvocationTargetException | NoSuchMethodException | InstantiationException + | IllegalAccessException e) { + return null; + } + } + + private void initPlatformSensor(PlatformSensor sensor) { + sensor.initPlatformSensorAndroid(PLATFORM_SENSOR_ANDROID); + } + + private void addMockSensor(long minDelayUsec, int sensorType, int reportingMode) { + List<Sensor> mockSensorList = new ArrayList<Sensor>(); + mockSensorList.add(createMockSensor(minDelayUsec, sensorType, reportingMode)); + mMockSensors.put(sensorType, mockSensorList); + } + + private Sensor createMockSensor(long minDelayUsec, int sensorType, int reportingMode) { + Sensor mockSensor = mock(Sensor.class); + doReturn((int) minDelayUsec).when(mockSensor).getMinDelay(); + doReturn(reportingMode).when(mockSensor).getReportingMode(); + doReturn(sensorType).when(mockSensor).getType(); + return mockSensor; + } + + private List<Sensor> getMockSensors(int sensorType) { + if (mMockSensors.indexOfKey(sensorType) >= 0) { + return mMockSensors.get(sensorType); + } + return new ArrayList<Sensor>(); + } + + private PlatformSensor createPlatformSensor( + long minDelayUsec, int androidSensorType, int mojoSensorType, int reportingMode) { + addMockSensor(minDelayUsec, androidSensorType, reportingMode); + PlatformSensorProvider provider = PlatformSensorProvider.createForTest(mContext); + return provider.createSensor(mojoSensorType); + } + + private TestPlatformSensor createTestPlatformSensor( + long minDelayUsec, int androidSensorType, int readingCount, int reportingMode) { + return new TestPlatformSensor( + createMockSensor(minDelayUsec, androidSensorType, reportingMode), readingCount, + mPlatformSensorProvider); + } + + private float getFakeReadingValue(int valueNum) { + return (float) (valueNum + SECONDS_IN_NANOSECOND); + } +} diff --git a/chromium/services/device/generic_sensor/fake_platform_sensor_and_provider.cc b/chromium/services/device/generic_sensor/fake_platform_sensor_and_provider.cc index 0d53e3bc906..a9badf7ad8f 100644 --- a/chromium/services/device/generic_sensor/fake_platform_sensor_and_provider.cc +++ b/chromium/services/device/generic_sensor/fake_platform_sensor_and_provider.cc @@ -80,7 +80,8 @@ void FakePlatformSensorProvider::CreateSensorInternal( mojom::SensorType type, SensorReadingSharedBuffer* reading_buffer, const CreateSensorCallback& callback) { - DCHECK(type >= mojom::SensorType::FIRST && type <= mojom::SensorType::LAST); + DCHECK(type >= mojom::SensorType::kMinValue && + type <= mojom::SensorType::kMaxValue); auto sensor = base::MakeRefCounted<FakePlatformSensor>(type, reading_buffer, this); DoCreateSensorInternal(type, std::move(sensor), callback); diff --git a/chromium/services/device/generic_sensor/linux/sensor_device_manager.cc b/chromium/services/device/generic_sensor/linux/sensor_device_manager.cc index 526b4b341eb..b13c8a830e0 100644 --- a/chromium/services/device/generic_sensor/linux/sensor_device_manager.cc +++ b/chromium/services/device/generic_sensor/linux/sensor_device_manager.cc @@ -81,9 +81,9 @@ void SensorDeviceManager::OnDeviceAdded(udev_device* dev) { if (device_node.empty()) return; - const uint32_t first = static_cast<uint32_t>(mojom::SensorType::FIRST); - const uint32_t last = static_cast<uint32_t>(mojom::SensorType::LAST); - for (uint32_t i = first; i < last; ++i) { + const uint32_t first = static_cast<uint32_t>(mojom::SensorType::kMinValue); + const uint32_t last = static_cast<uint32_t>(mojom::SensorType::kMaxValue); + for (uint32_t i = first; i <= last; ++i) { SensorPathsLinux data; mojom::SensorType type = static_cast<mojom::SensorType>(i); if (!InitSensorData(type, &data)) diff --git a/chromium/services/device/generic_sensor/platform_sensor_and_provider_unittest_win.cc b/chromium/services/device/generic_sensor/platform_sensor_and_provider_unittest_win.cc index e9ff66361f6..2554eb4eb35 100644 --- a/chromium/services/device/generic_sensor/platform_sensor_and_provider_unittest_win.cc +++ b/chromium/services/device/generic_sensor/platform_sensor_and_provider_unittest_win.cc @@ -9,6 +9,7 @@ #include "base/bind.h" #include "base/message_loop/message_loop.h" #include "base/run_loop.h" +#include "base/test/scoped_task_environment.h" #include "base/win/iunknown_impl.h" #include "base/win/propvarutil.h" #include "base/win/scoped_propvariant.h" @@ -189,7 +190,12 @@ class MockISensorDataReport : public MockCOMInterface<ISensorDataReport> { // data in OnDataUpdated event. class PlatformSensorAndProviderTestWin : public ::testing::Test { public: + PlatformSensorAndProviderTestWin() + : scoped_task_environment_( + base::test::ScopedTaskEnvironment::MainThreadType::IO) {} + void SetUp() override { + EXPECT_EQ(S_OK, CoInitialize(nullptr)); sensor_ = new NiceMock<MockISensor>(); sensor_collection_ = new NiceMock<MockISensorCollection>(); sensor_manager_ = new NiceMock<MockISensorManager>(); @@ -244,7 +250,7 @@ class PlatformSensorAndProviderTestWin : public ::testing::Test { void SetUnsupportedSensor(REFSENSOR_TYPE_ID sensor) { EXPECT_CALL(*sensor_manager_, GetSensorsByType(sensor, _)) - .WillOnce(Invoke( + .WillRepeatedly(Invoke( [this](REFSENSOR_TYPE_ID type, ISensorCollection** collection) { return HRESULT_FROM_WIN32(ERROR_NOT_FOUND); })); @@ -288,7 +294,7 @@ class PlatformSensorAndProviderTestWin : public ::testing::Test { events->AddRef(); sensor_events_.Attach(events); if (this->run_loop_) { - message_loop_.task_runner()->PostTask( + scoped_task_environment_.GetMainThreadTaskRunner()->PostTask( FROM_HERE, base::Bind(&PlatformSensorAndProviderTestWin::QuitInnerLoop, base::Unretained(this))); @@ -302,7 +308,7 @@ class PlatformSensorAndProviderTestWin : public ::testing::Test { .WillByDefault(Invoke([this](ISensorEvents* events) { sensor_events_.Reset(); if (this->run_loop_) { - message_loop_.task_runner()->PostTask( + scoped_task_environment_.GetMainThreadTaskRunner()->PostTask( FROM_HERE, base::Bind(&PlatformSensorAndProviderTestWin::QuitInnerLoop, base::Unretained(this))); @@ -382,11 +388,11 @@ class PlatformSensorAndProviderTestWin : public ::testing::Test { sensor_events_->OnDataUpdated(sensor_.get(), data_report.Get()); } + base::test::ScopedTaskEnvironment scoped_task_environment_; scoped_refptr<MockISensorManager> sensor_manager_; scoped_refptr<MockISensorCollection> sensor_collection_; scoped_refptr<MockISensor> sensor_; Microsoft::WRL::ComPtr<ISensorEvents> sensor_events_; - base::MessageLoop message_loop_; scoped_refptr<PlatformSensor> platform_sensor_; // Inner run loop used to wait for async sensor creation callback. std::unique_ptr<base::RunLoop> run_loop_; diff --git a/chromium/services/device/generic_sensor/platform_sensor_provider_base.cc b/chromium/services/device/generic_sensor/platform_sensor_provider_base.cc index 011e6e70810..f5cb25cf916 100644 --- a/chromium/services/device/generic_sensor/platform_sensor_provider_base.cc +++ b/chromium/services/device/generic_sensor/platform_sensor_provider_base.cc @@ -15,7 +15,8 @@ namespace { const uint64_t kReadingBufferSize = sizeof(SensorReadingSharedBuffer); const uint64_t kSharedBufferSizeInBytes = - kReadingBufferSize * static_cast<uint64_t>(mojom::SensorType::LAST); + kReadingBufferSize * + (static_cast<uint64_t>(mojom::SensorType::kMaxValue) + 1); } // namespace diff --git a/chromium/services/device/generic_sensor/platform_sensor_provider_win.cc b/chromium/services/device/generic_sensor/platform_sensor_provider_win.cc index 346ba7d417a..6aab3386fdc 100644 --- a/chromium/services/device/generic_sensor/platform_sensor_provider_win.cc +++ b/chromium/services/device/generic_sensor/platform_sensor_provider_win.cc @@ -10,6 +10,8 @@ #include <iomanip> #include "base/memory/singleton.h" +#include "base/task/post_task.h" +#include "base/task/task_traits.h" #include "base/task_runner_util.h" #include "base/threading/thread.h" #include "services/device/generic_sensor/linear_acceleration_fusion_algorithm_using_accelerometer.h" @@ -19,44 +21,6 @@ namespace device { -class PlatformSensorProviderWin::SensorThread final : public base::Thread { - public: - SensorThread() : base::Thread("Sensor thread") { init_com_with_mta(true); } - - void SetSensorManagerForTesting( - Microsoft::WRL::ComPtr<ISensorManager> sensor_manager) { - sensor_manager_ = sensor_manager; - } - - const Microsoft::WRL::ComPtr<ISensorManager>& sensor_manager() const { - return sensor_manager_; - } - - protected: - void Init() override { - if (sensor_manager_) - return; - HRESULT hr = ::CoCreateInstance(CLSID_SensorManager, nullptr, CLSCTX_ALL, - IID_PPV_ARGS(&sensor_manager_)); - if (FAILED(hr)) { - // Only log this error the first time. - static bool logged_failure = false; - if (!logged_failure) { - LOG(ERROR) << "Unable to create instance of SensorManager: " - << _com_error(hr).ErrorMessage() << " (0x" << std::hex - << std::uppercase << std::setfill('0') << std::setw(8) << hr - << ")"; - logged_failure = true; - } - } - } - - void CleanUp() override { sensor_manager_.Reset(); } - - private: - Microsoft::WRL::ComPtr<ISensorManager> sensor_manager_; -}; - // static PlatformSensorProviderWin* PlatformSensorProviderWin::GetInstance() { return base::Singleton< @@ -66,11 +30,13 @@ PlatformSensorProviderWin* PlatformSensorProviderWin::GetInstance() { void PlatformSensorProviderWin::SetSensorManagerForTesting( Microsoft::WRL::ComPtr<ISensorManager> sensor_manager) { - CreateSensorThread(); - sensor_thread_->SetSensorManagerForTesting(sensor_manager); + sensor_manager_ = sensor_manager; } -PlatformSensorProviderWin::PlatformSensorProviderWin() = default; +PlatformSensorProviderWin::PlatformSensorProviderWin() + : com_sta_task_runner_(base::CreateCOMSTATaskRunnerWithTraits( + base::TaskPriority::USER_VISIBLE)) {} + PlatformSensorProviderWin::~PlatformSensorProviderWin() = default; void PlatformSensorProviderWin::CreateSensorInternal( @@ -78,7 +44,43 @@ void PlatformSensorProviderWin::CreateSensorInternal( SensorReadingSharedBuffer* reading_buffer, const CreateSensorCallback& callback) { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); - if (!StartSensorThread()) { + if (sensor_manager_) { + OnInitSensorManager(type, reading_buffer, callback); + } else { + com_sta_task_runner_->PostTaskAndReply( + FROM_HERE, + base::Bind(&PlatformSensorProviderWin::InitSensorManager, + base::Unretained(this)), + base::Bind(&PlatformSensorProviderWin::OnInitSensorManager, + base::Unretained(this), type, reading_buffer, callback)); + } +} + +void PlatformSensorProviderWin::InitSensorManager() { + DCHECK(com_sta_task_runner_->RunsTasksInCurrentSequence()); + + HRESULT hr = ::CoCreateInstance(CLSID_SensorManager, nullptr, CLSCTX_ALL, + IID_PPV_ARGS(&sensor_manager_)); + if (FAILED(hr)) { + // Only log this error the first time. + static bool logged_failure = false; + if (!logged_failure) { + LOG(ERROR) << "Unable to create instance of SensorManager: " + << _com_error(hr).ErrorMessage() << " (0x" << std::hex + << std::uppercase << std::setfill('0') << std::setw(8) << hr + << ")"; + logged_failure = true; + } + } +} + +void PlatformSensorProviderWin::OnInitSensorManager( + mojom::SensorType type, + SensorReadingSharedBuffer* reading_buffer, + const CreateSensorCallback& callback) { + DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); + + if (!sensor_manager_) { callback.Run(nullptr); return; } @@ -99,7 +101,7 @@ void PlatformSensorProviderWin::CreateSensorInternal( // Try to create low-level sensors by default. default: { base::PostTaskAndReplyWithResult( - sensor_thread_->task_runner().get(), FROM_HERE, + com_sta_task_runner_.get(), FROM_HERE, base::Bind(&PlatformSensorProviderWin::CreateSensorReader, base::Unretained(this), type), base::Bind(&PlatformSensorProviderWin::SensorReaderCreated, @@ -109,27 +111,6 @@ void PlatformSensorProviderWin::CreateSensorInternal( } } -void PlatformSensorProviderWin::FreeResources() { - StopSensorThread(); -} - -void PlatformSensorProviderWin::CreateSensorThread() { - if (!sensor_thread_) - sensor_thread_ = std::make_unique<SensorThread>(); -} - -bool PlatformSensorProviderWin::StartSensorThread() { - CreateSensorThread(); - if (!sensor_thread_->IsRunning()) - return sensor_thread_->Start(); - return true; -} - -void PlatformSensorProviderWin::StopSensorThread() { - if (sensor_thread_ && sensor_thread_->IsRunning()) - sensor_thread_->Stop(); -} - void PlatformSensorProviderWin::SensorReaderCreated( mojom::SensorType type, SensorReadingSharedBuffer* reading_buffer, @@ -155,19 +136,18 @@ void PlatformSensorProviderWin::SensorReaderCreated( } } - scoped_refptr<PlatformSensor> sensor = new PlatformSensorWin( - type, reading_buffer, this, sensor_thread_->task_runner(), - std::move(sensor_reader)); + scoped_refptr<PlatformSensor> sensor = + new PlatformSensorWin(type, reading_buffer, this, com_sta_task_runner_, + std::move(sensor_reader)); callback.Run(sensor); } std::unique_ptr<PlatformSensorReaderWin> PlatformSensorProviderWin::CreateSensorReader(mojom::SensorType type) { - DCHECK(sensor_thread_->task_runner()->BelongsToCurrentThread()); - if (!sensor_thread_->sensor_manager()) + DCHECK(com_sta_task_runner_->RunsTasksInCurrentSequence()); + if (!sensor_manager_) return nullptr; - return PlatformSensorReaderWin::Create(type, - sensor_thread_->sensor_manager()); + return PlatformSensorReaderWin::Create(type, sensor_manager_); } } // namespace device diff --git a/chromium/services/device/generic_sensor/platform_sensor_provider_win.h b/chromium/services/device/generic_sensor/platform_sensor_provider_win.h index 4b7e1bd2344..9958ceba547 100644 --- a/chromium/services/device/generic_sensor/platform_sensor_provider_win.h +++ b/chromium/services/device/generic_sensor/platform_sensor_provider_win.h @@ -37,7 +37,6 @@ class PlatformSensorProviderWin final : public PlatformSensorProvider { ~PlatformSensorProviderWin() override; // PlatformSensorProvider interface implementation. - void FreeResources() override; void CreateSensorInternal(mojom::SensorType type, SensorReadingSharedBuffer* reading_buffer, const CreateSensorCallback& callback) override; @@ -45,13 +44,12 @@ class PlatformSensorProviderWin final : public PlatformSensorProvider { private: friend struct base::DefaultSingletonTraits<PlatformSensorProviderWin>; - class SensorThread; - PlatformSensorProviderWin(); - void CreateSensorThread(); - bool StartSensorThread(); - void StopSensorThread(); + void InitSensorManager(); + void OnInitSensorManager(mojom::SensorType type, + SensorReadingSharedBuffer* reading_buffer, + const CreateSensorCallback& callback); std::unique_ptr<PlatformSensorReaderWin> CreateSensorReader( mojom::SensorType type); void SensorReaderCreated( @@ -60,7 +58,8 @@ class PlatformSensorProviderWin final : public PlatformSensorProvider { const CreateSensorCallback& callback, std::unique_ptr<PlatformSensorReaderWin> sensor_reader); - std::unique_ptr<SensorThread> sensor_thread_; + scoped_refptr<base::SingleThreadTaskRunner> com_sta_task_runner_; + Microsoft::WRL::ComPtr<ISensorManager> sensor_manager_; DISALLOW_COPY_AND_ASSIGN(PlatformSensorProviderWin); }; diff --git a/chromium/services/device/geolocation/BUILD.gn b/chromium/services/device/geolocation/BUILD.gn index 0a139b7b01c..fb39939f7fd 100644 --- a/chromium/services/device/geolocation/BUILD.gn +++ b/chromium/services/device/geolocation/BUILD.gn @@ -141,6 +141,7 @@ if (is_android) { "$google_play_services_package:google_play_services_basement_java", "$google_play_services_package:google_play_services_location_java", "//base:base_java", + "//components/location/android:location_java", "//services/device/public/java:geolocation_java", ] } diff --git a/chromium/services/device/geolocation/DEPS b/chromium/services/device/geolocation/DEPS index 0687f47a704..5bcd2371210 100644 --- a/chromium/services/device/geolocation/DEPS +++ b/chromium/services/device/geolocation/DEPS @@ -1,5 +1,6 @@ include_rules = [ "+chromeos", + "+components/location", "+dbus", "+jni", "+net/base", diff --git a/chromium/services/device/geolocation/android/java/src/org/chromium/device/geolocation/LocationProviderAdapter.java b/chromium/services/device/geolocation/android/java/src/org/chromium/device/geolocation/LocationProviderAdapter.java new file mode 100644 index 00000000000..442ec6412a9 --- /dev/null +++ b/chromium/services/device/geolocation/android/java/src/org/chromium/device/geolocation/LocationProviderAdapter.java @@ -0,0 +1,95 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.geolocation; + +import android.location.Location; + +import org.chromium.base.Log; +import org.chromium.base.ThreadUtils; +import org.chromium.base.VisibleForTesting; +import org.chromium.base.annotations.CalledByNative; + +import java.util.concurrent.FutureTask; + +/** + * Implements the Java side of LocationProviderAndroid. + * Delegates all real functionality to the implementation + * returned from LocationProviderFactory. + * See detailed documentation on + * content/browser/geolocation/location_api_adapter_android.h. + * Based on android.webkit.GeolocationService.java + */ +@VisibleForTesting +public class LocationProviderAdapter { + private static final String TAG = "cr_LocationProvider"; + + // Delegate handling the real work in the main thread. + private LocationProvider mImpl; + + private LocationProviderAdapter() { + mImpl = LocationProviderFactory.create(); + } + + @CalledByNative + public static LocationProviderAdapter create() { + return new LocationProviderAdapter(); + } + + /** + * Start listening for location updates until we're told to quit. May be called in any thread. + * @param enableHighAccuracy Whether or not to enable high accuracy location providers. + */ + @CalledByNative + public void start(final boolean enableHighAccuracy) { + FutureTask<Void> task = new FutureTask<Void>(new Runnable() { + @Override + public void run() { + mImpl.start(enableHighAccuracy); + } + }, null); + ThreadUtils.runOnUiThread(task); + } + + /** + * Stop listening for location updates. May be called in any thread. + */ + @CalledByNative + public void stop() { + FutureTask<Void> task = new FutureTask<Void>(new Runnable() { + @Override + public void run() { + mImpl.stop(); + } + }, null); + ThreadUtils.runOnUiThread(task); + } + + /** + * Returns true if we are currently listening for location updates, false if not. + * Must be called only in the UI thread. + */ + public boolean isRunning() { + assert ThreadUtils.runningOnUiThread(); + return mImpl.isRunning(); + } + + public static void onNewLocationAvailable(Location location) { + nativeNewLocationAvailable(location.getLatitude(), location.getLongitude(), + location.getTime() / 1000.0, location.hasAltitude(), location.getAltitude(), + location.hasAccuracy(), location.getAccuracy(), location.hasBearing(), + location.getBearing(), location.hasSpeed(), location.getSpeed()); + } + + public static void newErrorAvailable(String message) { + Log.e(TAG, "newErrorAvailable %s", message); + nativeNewErrorAvailable(message); + } + + // Native functions + private static native void nativeNewLocationAvailable(double latitude, double longitude, + double timeStamp, boolean hasAltitude, double altitude, boolean hasAccuracy, + double accuracy, boolean hasHeading, double heading, boolean hasSpeed, double speed); + private static native void nativeNewErrorAvailable(String message); +} diff --git a/chromium/services/device/geolocation/android/java/src/org/chromium/device/geolocation/LocationProviderAndroid.java b/chromium/services/device/geolocation/android/java/src/org/chromium/device/geolocation/LocationProviderAndroid.java new file mode 100644 index 00000000000..01fac25d842 --- /dev/null +++ b/chromium/services/device/geolocation/android/java/src/org/chromium/device/geolocation/LocationProviderAndroid.java @@ -0,0 +1,161 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.geolocation; + +import android.content.Context; +import android.location.Criteria; +import android.location.Location; +import android.location.LocationListener; +import android.location.LocationManager; +import android.os.Bundle; + +import org.chromium.base.ContextUtils; +import org.chromium.base.Log; +import org.chromium.base.ThreadUtils; +import org.chromium.base.VisibleForTesting; + +import java.util.List; + +/** + * This is a LocationProvider using Android APIs [1]. It is a separate class for clarity + * so that it can manage all processing completely on the UI thread. The container class + * ensures that the start/stop calls into this class are done on the UI thread. + * + * [1] https://developer.android.com/reference/android/location/package-summary.html + */ +public class LocationProviderAndroid implements LocationListener, LocationProvider { + private static final String TAG = "cr_LocationProvider"; + + private LocationManager mLocationManager; + private boolean mIsRunning; + + LocationProviderAndroid() {} + + @Override + public void start(boolean enableHighAccuracy) { + ThreadUtils.assertOnUiThread(); + unregisterFromLocationUpdates(); + registerForLocationUpdates(enableHighAccuracy); + } + + @Override + public void stop() { + ThreadUtils.assertOnUiThread(); + unregisterFromLocationUpdates(); + } + + @Override + public boolean isRunning() { + ThreadUtils.assertOnUiThread(); + return mIsRunning; + } + + @Override + public void onLocationChanged(Location location) { + // Callbacks from the system location service are queued to this thread, so it's + // possible that we receive callbacks after unregistering. At this point, the + // native object will no longer exist. + if (mIsRunning) { + LocationProviderAdapter.onNewLocationAvailable(location); + } + } + + @Override + public void onStatusChanged(String provider, int status, Bundle extras) {} + + @Override + public void onProviderEnabled(String provider) {} + + @Override + public void onProviderDisabled(String provider) {} + + @VisibleForTesting + public void setLocationManagerForTesting(LocationManager manager) { + mLocationManager = manager; + } + + private void createLocationManagerIfNeeded() { + if (mLocationManager != null) return; + mLocationManager = (LocationManager) ContextUtils.getApplicationContext().getSystemService( + Context.LOCATION_SERVICE); + if (mLocationManager == null) { + Log.e(TAG, "Could not get location manager."); + } + } + + /** + * Registers this object with the location service. + */ + private void registerForLocationUpdates(boolean enableHighAccuracy) { + createLocationManagerIfNeeded(); + if (usePassiveOneShotLocation()) return; + + assert !mIsRunning; + mIsRunning = true; + + // We're running on the main thread. The C++ side is responsible to + // bounce notifications to the Geolocation thread as they arrive in the mainLooper. + try { + Criteria criteria = new Criteria(); + if (enableHighAccuracy) criteria.setAccuracy(Criteria.ACCURACY_FINE); + mLocationManager.requestLocationUpdates( + 0, 0, criteria, this, ThreadUtils.getUiThreadLooper()); + } catch (SecurityException e) { + Log.e(TAG, + "Caught security exception while registering for location updates " + + "from the system. The application does not have sufficient " + + "geolocation permissions."); + unregisterFromLocationUpdates(); + // Propagate an error to JavaScript, this can happen in case of WebView + // when the embedding app does not have sufficient permissions. + LocationProviderAdapter.newErrorAvailable( + "application does not have sufficient geolocation permissions."); + } catch (IllegalArgumentException e) { + Log.e(TAG, "Caught IllegalArgumentException registering for location updates."); + unregisterFromLocationUpdates(); + assert false; + } + } + + /** + * Unregisters this object from the location service. + */ + private void unregisterFromLocationUpdates() { + if (!mIsRunning) return; + mIsRunning = false; + mLocationManager.removeUpdates(this); + } + + private boolean usePassiveOneShotLocation() { + if (!isOnlyPassiveLocationProviderEnabled()) { + return false; + } + + // Do not request a location update if the only available location provider is + // the passive one. Make use of the last known location and call + // onNewLocationAvailable directly. + final Location location = + mLocationManager.getLastKnownLocation(LocationManager.PASSIVE_PROVIDER); + if (location != null) { + ThreadUtils.runOnUiThread(new Runnable() { + @Override + public void run() { + LocationProviderAdapter.onNewLocationAvailable(location); + } + }); + } + return true; + } + + /* + * Checks if the passive location provider is the only provider available + * in the system. + */ + private boolean isOnlyPassiveLocationProviderEnabled() { + final List<String> providers = mLocationManager.getProviders(true); + return providers != null && providers.size() == 1 + && providers.get(0).equals(LocationManager.PASSIVE_PROVIDER); + } +} diff --git a/chromium/services/device/geolocation/android/java/src/org/chromium/device/geolocation/LocationProviderFactory.java b/chromium/services/device/geolocation/android/java/src/org/chromium/device/geolocation/LocationProviderFactory.java new file mode 100644 index 00000000000..c3d58656625 --- /dev/null +++ b/chromium/services/device/geolocation/android/java/src/org/chromium/device/geolocation/LocationProviderFactory.java @@ -0,0 +1,44 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.geolocation; + +import org.chromium.base.ContextUtils; +import org.chromium.base.VisibleForTesting; +import org.chromium.base.annotations.CalledByNative; +import org.chromium.base.annotations.JNINamespace; + +/** + * Factory to create a LocationProvider to allow us to inject a mock for tests. + */ +@JNINamespace("device") +public class LocationProviderFactory { + private static LocationProvider sProviderImpl; + private static boolean sUseGmsCoreLocationProvider; + + private LocationProviderFactory() {} + + @VisibleForTesting + public static void setLocationProviderImpl(LocationProvider provider) { + sProviderImpl = provider; + } + + @CalledByNative + public static void useGmsCoreLocationProvider() { + sUseGmsCoreLocationProvider = true; + } + + public static LocationProvider create() { + if (sProviderImpl != null) return sProviderImpl; + + if (sUseGmsCoreLocationProvider + && LocationProviderGmsCore.isGooglePlayServicesAvailable( + ContextUtils.getApplicationContext())) { + sProviderImpl = new LocationProviderGmsCore(ContextUtils.getApplicationContext()); + } else { + sProviderImpl = new LocationProviderAndroid(); + } + return sProviderImpl; + } +} diff --git a/chromium/services/device/geolocation/android/java/src/org/chromium/device/geolocation/LocationProviderGmsCore.java b/chromium/services/device/geolocation/android/java/src/org/chromium/device/geolocation/LocationProviderGmsCore.java new file mode 100644 index 00000000000..9f88da9d421 --- /dev/null +++ b/chromium/services/device/geolocation/android/java/src/org/chromium/device/geolocation/LocationProviderGmsCore.java @@ -0,0 +1,152 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.geolocation; + +import android.content.Context; +import android.location.Location; +import android.os.Bundle; + +import com.google.android.gms.common.ConnectionResult; +import com.google.android.gms.common.GoogleApiAvailability; +import com.google.android.gms.common.api.GoogleApiClient; +import com.google.android.gms.common.api.GoogleApiClient.ConnectionCallbacks; +import com.google.android.gms.common.api.GoogleApiClient.OnConnectionFailedListener; +import com.google.android.gms.location.FusedLocationProviderApi; +import com.google.android.gms.location.LocationListener; +import com.google.android.gms.location.LocationRequest; +import com.google.android.gms.location.LocationServices; + +import org.chromium.base.Log; +import org.chromium.base.ThreadUtils; +import org.chromium.components.location.LocationUtils; + +/** + * This is a LocationProvider using Google Play Services. + * + * https://developers.google.com/android/reference/com/google/android/gms/location/package-summary + */ +public class LocationProviderGmsCore implements ConnectionCallbacks, OnConnectionFailedListener, + LocationListener, LocationProvider { + private static final String TAG = "cr_LocationProvider"; + + // Values for the LocationRequest's setInterval for normal and high accuracy, respectively. + private static final long UPDATE_INTERVAL_MS = 1000; + private static final long UPDATE_INTERVAL_FAST_MS = 500; + + private final GoogleApiClient mGoogleApiClient; + private FusedLocationProviderApi mLocationProviderApi = LocationServices.FusedLocationApi; + + private boolean mEnablehighAccuracy; + private LocationRequest mLocationRequest; + + public static boolean isGooglePlayServicesAvailable(Context context) { + return GoogleApiAvailability.getInstance().isGooglePlayServicesAvailable(context) + == ConnectionResult.SUCCESS; + } + + LocationProviderGmsCore(Context context) { + Log.i(TAG, "Google Play Services"); + mGoogleApiClient = new GoogleApiClient.Builder(context) + .addApi(LocationServices.API) + .addConnectionCallbacks(this) + .addOnConnectionFailedListener(this) + .build(); + assert mGoogleApiClient != null; + } + + LocationProviderGmsCore(GoogleApiClient client, FusedLocationProviderApi locationApi) { + mGoogleApiClient = client; + mLocationProviderApi = locationApi; + } + + // ConnectionCallbacks implementation + @Override + public void onConnected(Bundle connectionHint) { + ThreadUtils.assertOnUiThread(); + + mLocationRequest = LocationRequest.create(); + if (mEnablehighAccuracy) { + // With enableHighAccuracy, request a faster update interval and configure the provider + // for high accuracy mode. + mLocationRequest.setPriority(LocationRequest.PRIORITY_HIGH_ACCURACY) + .setInterval(UPDATE_INTERVAL_FAST_MS); + } else { + // Use balanced mode by default. In this mode, the API will prefer the network provider + // but may use sensor data (for instance, GPS) if high accuracy is requested by another + // app. + // + // If location is configured for sensors-only then elevate the priority to ensure GPS + // and other sensors are used. + if (LocationUtils.getInstance().isSystemLocationSettingSensorsOnly()) { + mLocationRequest.setPriority(LocationRequest.PRIORITY_HIGH_ACCURACY); + } else { + mLocationRequest.setPriority(LocationRequest.PRIORITY_BALANCED_POWER_ACCURACY); + } + mLocationRequest.setInterval(UPDATE_INTERVAL_MS); + } + + final Location location = mLocationProviderApi.getLastLocation(mGoogleApiClient); + if (location != null) { + LocationProviderAdapter.onNewLocationAvailable(location); + } + + try { + // Request updates on UI Thread replicating LocationProviderAndroid's behaviour. + mLocationProviderApi.requestLocationUpdates( + mGoogleApiClient, mLocationRequest, this, ThreadUtils.getUiThreadLooper()); + } catch (IllegalStateException | SecurityException e) { + // IllegalStateException is thrown "If this method is executed in a thread that has not + // called Looper.prepare()". SecurityException is thrown if there is no permission, see + // https://crbug.com/731271. + Log.e(TAG, " mLocationProviderApi.requestLocationUpdates() " + e); + LocationProviderAdapter.newErrorAvailable( + "Failed to request location updates: " + e.toString()); + assert false; + } + } + + @Override + public void onConnectionSuspended(int cause) {} + + // OnConnectionFailedListener implementation + @Override + public void onConnectionFailed(ConnectionResult result) { + LocationProviderAdapter.newErrorAvailable( + "Failed to connect to Google Play Services: " + result.toString()); + } + + // LocationProvider implementation + @Override + public void start(boolean enableHighAccuracy) { + ThreadUtils.assertOnUiThread(); + if (mGoogleApiClient.isConnected()) mGoogleApiClient.disconnect(); + + mEnablehighAccuracy = enableHighAccuracy; + mGoogleApiClient.connect(); // Should return via onConnected(). + } + + @Override + public void stop() { + ThreadUtils.assertOnUiThread(); + if (!mGoogleApiClient.isConnected()) return; + + mLocationProviderApi.removeLocationUpdates(mGoogleApiClient, this); + + mGoogleApiClient.disconnect(); + } + + @Override + public boolean isRunning() { + assert ThreadUtils.runningOnUiThread(); + if (mGoogleApiClient == null) return false; + return mGoogleApiClient.isConnecting() || mGoogleApiClient.isConnected(); + } + + // LocationListener implementation + @Override + public void onLocationChanged(Location location) { + LocationProviderAdapter.onNewLocationAvailable(location); + } +} diff --git a/chromium/services/device/geolocation/android/junit/src/org/chromium/device/geolocation/LocationProviderTest.java b/chromium/services/device/geolocation/android/junit/src/org/chromium/device/geolocation/LocationProviderTest.java new file mode 100644 index 00000000000..035f83681e5 --- /dev/null +++ b/chromium/services/device/geolocation/android/junit/src/org/chromium/device/geolocation/LocationProviderTest.java @@ -0,0 +1,197 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.geolocation; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.doAnswer; + +import android.content.Context; +import android.location.LocationManager; + +import com.google.android.gms.common.api.GoogleApiClient; +import com.google.android.gms.location.FusedLocationProviderApi; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.robolectric.ParameterizedRobolectricTestRunner; +import org.robolectric.ParameterizedRobolectricTestRunner.Parameters; +import org.robolectric.Shadows; +import org.robolectric.annotation.Config; +import org.robolectric.shadow.api.Shadow; +import org.robolectric.shadows.ShadowLocationManager; +import org.robolectric.shadows.ShadowLog; // remove me ? + +import org.chromium.base.ThreadUtils; +import org.chromium.base.test.util.Feature; + +import java.util.Arrays; +import java.util.Collection; + +/** + * Test suite for Java Geolocation. + */ +@RunWith(ParameterizedRobolectricTestRunner.class) +@Config(sdk = 21, manifest = Config.NONE) +public class LocationProviderTest { + public static enum LocationProviderType { MOCK, ANDROID, GMS_CORE } + + @Parameters + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] {{LocationProviderType.MOCK}, + {LocationProviderType.ANDROID}, {LocationProviderType.GMS_CORE}}); + } + + @Mock + private Context mContext; + + // Member variables for LocationProviderType.GMS_CORE case. + @Mock + private GoogleApiClient mGoogleApiClient; + private boolean mGoogleApiClientIsConnected; + + // Member variables for LocationProviderType.ANDROID case. + private LocationManager mLocationManager; + private ShadowLocationManager mShadowLocationManager; + + private LocationProviderAdapter mLocationProviderAdapter; + + private final LocationProviderType mApi; + + public LocationProviderTest(LocationProviderType api) { + mApi = api; + } + + @Before + public void setUp() { + ShadowLog.stream = System.out; + MockitoAnnotations.initMocks(this); + + mContext = Mockito.mock(Context.class); + } + + /** + * Verify a normal start/stop call pair with the given LocationProvider. + */ + @Test + @Feature({"Location"}) + public void testStartStop() { + assertTrue("Should be on UI thread", ThreadUtils.runningOnUiThread()); + + setLocationProvider(); + + createLocationProviderAdapter(); + startLocationProviderAdapter(false); + stopLocationProviderAdapter(); + } + + /** + * Verify a start/upgrade/stop call sequencewith the given LocationProvider. + */ + @Test + @Feature({"Location"}) + public void testStartUpgradeStop() { + assertTrue("Should be on UI thread", ThreadUtils.runningOnUiThread()); + + setLocationProvider(); + + createLocationProviderAdapter(); + startLocationProviderAdapter(false); + startLocationProviderAdapter(true); + stopLocationProviderAdapter(); + } + + private void createLocationProviderAdapter() { + mLocationProviderAdapter = LocationProviderAdapter.create(); + assertNotNull("LocationProvider", mLocationProviderAdapter); + } + + private void setLocationProvider() { + if (mApi == LocationProviderType.MOCK) { + setLocationProviderMock(); + } else if (mApi == LocationProviderType.ANDROID) { + setLocationProviderAndroid(); + } else if (mApi == LocationProviderType.GMS_CORE) { + setLocationProviderGmsCore(); + } else { + assert false; + } + } + + private void setLocationProviderMock() { + LocationProviderFactory.setLocationProviderImpl(new MockLocationProvider()); + } + + private void setLocationProviderAndroid() { + LocationProviderAndroid locationProviderAndroid = new LocationProviderAndroid(); + + // Robolectric has a ShadowLocationManager class that mocks the behaviour of the real + // class very closely. Use it here. + mLocationManager = Shadow.newInstanceOf(LocationManager.class); + mShadowLocationManager = Shadows.shadowOf(mLocationManager); + locationProviderAndroid.setLocationManagerForTesting(mLocationManager); + LocationProviderFactory.setLocationProviderImpl(locationProviderAndroid); + } + + private void setLocationProviderGmsCore() { + // Robolectric has a ShadowGoogleApiClientBuilder class that mocks the behaviour of the real + // class very closely, but it's not available in our build + mGoogleApiClient = Mockito.mock(GoogleApiClient.class); + mGoogleApiClientIsConnected = false; + doAnswer(new Answer<Boolean>() { + @Override + public Boolean answer(InvocationOnMock invocation) { + return mGoogleApiClientIsConnected; + } + }) + .when(mGoogleApiClient) + .isConnected(); + + doAnswer(new Answer<Void>() { + @Override + public Void answer(InvocationOnMock invocation) { + mGoogleApiClientIsConnected = true; + return null; + } + }) + .when(mGoogleApiClient) + .connect(); + + doAnswer(new Answer<Void>() { + @Override + public Void answer(InvocationOnMock invocation) { + mGoogleApiClientIsConnected = false; + return null; + } + }) + .when(mGoogleApiClient) + .disconnect(); + + FusedLocationProviderApi fusedLocationProviderApi = + Mockito.mock(FusedLocationProviderApi.class); + + LocationProviderGmsCore locationProviderGmsCore = + new LocationProviderGmsCore(mGoogleApiClient, fusedLocationProviderApi); + + LocationProviderFactory.setLocationProviderImpl(locationProviderGmsCore); + } + + private void startLocationProviderAdapter(boolean highResolution) { + mLocationProviderAdapter.start(highResolution); + assertTrue("Should be running", mLocationProviderAdapter.isRunning()); + } + + private void stopLocationProviderAdapter() { + mLocationProviderAdapter.stop(); + assertFalse("Should have stopped", mLocationProviderAdapter.isRunning()); + } +} diff --git a/chromium/services/device/manifest.json b/chromium/services/device/manifest.json index 254f9ef5725..beac967a3b5 100644 --- a/chromium/services/device/manifest.json +++ b/chromium/services/device/manifest.json @@ -8,6 +8,7 @@ "service_manager:connector": { "provides": { "device:battery_monitor": [ "device.mojom.BatteryMonitor" ], + "device:bluetooth_system": [ "device.mojom.BluetoothSystemFactory" ], "device:fingerprint": [ "device.mojom.Fingerprint" ], "device:generic_sensor": [ "device.mojom.SensorProvider" ], "device:geolocation": [ "device.mojom.GeolocationContext" ], @@ -25,6 +26,7 @@ "device.mojom.SerialIoHandler" ], "device:time_zone_monitor": [ "device.mojom.TimeZoneMonitor" ], + "device:usb": [ "device.mojom.UsbDeviceManager" ], "device:vibration": [ "device.mojom.VibrationManager" ], "device:wake_lock": [ "device.mojom.WakeLockProvider" ] } diff --git a/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/InvalidNfcMessageException.java b/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/InvalidNfcMessageException.java new file mode 100644 index 00000000000..b1289122c35 --- /dev/null +++ b/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/InvalidNfcMessageException.java @@ -0,0 +1,10 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.nfc; + +/** + * Exception that raised when NfcMessage is found to be invalid during conversion to NdefMessage. + */ +public final class InvalidNfcMessageException extends Exception {}
\ No newline at end of file diff --git a/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/NfcImpl.java b/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/NfcImpl.java new file mode 100644 index 00000000000..90958063226 --- /dev/null +++ b/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/NfcImpl.java @@ -0,0 +1,698 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.nfc; + +import android.Manifest; +import android.annotation.TargetApi; +import android.app.Activity; +import android.content.Context; +import android.content.pm.PackageManager; +import android.nfc.FormatException; +import android.nfc.NdefMessage; +import android.nfc.NfcAdapter; +import android.nfc.NfcAdapter.ReaderCallback; +import android.nfc.NfcManager; +import android.nfc.Tag; +import android.nfc.TagLostException; +import android.os.Build; +import android.os.Handler; +import android.os.Process; +import android.util.SparseArray; + +import org.chromium.base.Callback; +import org.chromium.base.ContextUtils; +import org.chromium.base.Log; +import org.chromium.device.mojom.Nfc; +import org.chromium.device.mojom.NfcClient; +import org.chromium.device.mojom.NfcError; +import org.chromium.device.mojom.NfcErrorType; +import org.chromium.device.mojom.NfcMessage; +import org.chromium.device.mojom.NfcPushOptions; +import org.chromium.device.mojom.NfcPushTarget; +import org.chromium.device.mojom.NfcWatchMode; +import org.chromium.device.mojom.NfcWatchOptions; +import org.chromium.mojo.bindings.Callbacks; +import org.chromium.mojo.system.MojoException; + +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.net.MalformedURLException; +import java.net.URL; +import java.util.ArrayList; +import java.util.List; + +/** Android implementation of the NFC mojo service defined in device/nfc/nfc.mojom. + */ +public class NfcImpl implements Nfc { + private static final String TAG = "NfcImpl"; + private static final String ANY_PATH = "/*"; + + private final int mHostId; + + private final NfcDelegate mDelegate; + + /** + * Used to get instance of NFC adapter, @see android.nfc.NfcManager + */ + private final NfcManager mNfcManager; + + /** + * NFC adapter. @see android.nfc.NfcAdapter + */ + private final NfcAdapter mNfcAdapter; + + /** + * Activity that is in foreground and is used to enable / disable NFC reader mode operations. + * Can be updated when activity associated with web page is changed. @see #setActivity + */ + private Activity mActivity; + + /** + * Flag that indicates whether NFC permission is granted. + */ + private final boolean mHasPermission; + + /** + * Implementation of android.nfc.NfcAdapter.ReaderCallback. @see ReaderCallbackHandler + */ + private ReaderCallbackHandler mReaderCallbackHandler; + + /** + * Object that contains data that was passed to method + * #push(NfcMessage message, NfcPushOptions options, PushResponse callback) + * @see PendingPushOperation + */ + private PendingPushOperation mPendingPushOperation; + + /** + * Utility that provides I/O operations for a Tag. Created on demand when Tag is found. + * @see NfcTagHandler + */ + private NfcTagHandler mTagHandler; + + /** + * Client interface used to deliver NFCMessages for registered watch operations. + * @see #watch + */ + private NfcClient mClient; + + /** + * Watcher id that is incremented for each #watch call. + */ + private int mWatcherId; + + /** + * Map of watchId <-> NfcWatchOptions. All NfcWatchOptions are matched against tag that is in + * proximity, when match algorithm (@see #matchesWatchOptions) returns true, watcher with + * corresponding ID would be notified using NfcClient interface. + * @see NfcClient#onWatch(int[] id, NfcMessage message) + */ + private final SparseArray<NfcWatchOptions> mWatchers = new SparseArray<>(); + + /** + * Handler that runs delayed push timeout task. + */ + private final Handler mPushTimeoutHandler = new Handler(); + + /** + * Runnable responsible for cancelling push operation after specified timeout. + */ + private Runnable mPushTimeoutRunnable; + + public NfcImpl(int hostId, NfcDelegate delegate) { + mHostId = hostId; + mDelegate = delegate; + int permission = ContextUtils.getApplicationContext().checkPermission( + Manifest.permission.NFC, Process.myPid(), Process.myUid()); + mHasPermission = permission == PackageManager.PERMISSION_GRANTED; + Callback<Activity> onActivityUpdatedCallback = new Callback<Activity>() { + @Override + public void onResult(Activity activity) { + setActivity(activity); + } + }; + + mDelegate.trackActivityForHost(mHostId, onActivityUpdatedCallback); + + if (!mHasPermission || Build.VERSION.SDK_INT < Build.VERSION_CODES.KITKAT) { + Log.w(TAG, "NFC operations are not permitted."); + mNfcAdapter = null; + mNfcManager = null; + } else { + mNfcManager = (NfcManager) ContextUtils.getApplicationContext().getSystemService( + Context.NFC_SERVICE); + if (mNfcManager == null) { + Log.w(TAG, "NFC is not supported."); + mNfcAdapter = null; + } else { + mNfcAdapter = mNfcManager.getDefaultAdapter(); + } + } + } + + /** + * Sets Activity that is used to enable / disable NFC reader mode. When Activity is set, + * reader mode is disabled for old Activity and enabled for the new Activity. + */ + protected void setActivity(Activity activity) { + disableReaderMode(); + mActivity = activity; + enableReaderModeIfNeeded(); + } + + /** + * Sets NfcClient. NfcClient interface is used to notify mojo NFC service client when NFC + * device is in proximity and has NfcMessage that matches NfcWatchOptions criteria. + * @see Nfc#watch(NfcWatchOptions options, WatchResponse callback) + * + * @param client @see NfcClient + */ + @Override + public void setClient(NfcClient client) { + mClient = client; + } + + /** + * Pushes NfcMessage to Tag or Peer, whenever NFC device is in proximity. At the moment, only + * passive NFC devices are supported (NfcPushTarget.TAG). + * + * @param message that should be pushed to NFC device. + * @param options that contain information about timeout and target device type. + * @param callback that is used to notify when push operation is completed. + */ + @Override + public void push(NfcMessage message, NfcPushOptions options, PushResponse callback) { + if (!checkIfReady(callback)) return; + + if (!NfcMessageValidator.isValid(message)) { + callback.call(createError(NfcErrorType.INVALID_MESSAGE)); + return; + } + + // Check NfcPushOptions that are not supported by Android platform. + if (options.target == NfcPushTarget.PEER || options.timeout < 0 + || (options.timeout > Long.MAX_VALUE && !Double.isInfinite(options.timeout))) { + callback.call(createError(NfcErrorType.NOT_SUPPORTED)); + return; + } + + // If previous pending push operation is not completed, cancel it. + if (mPendingPushOperation != null) { + mPendingPushOperation.complete(createError(NfcErrorType.OPERATION_CANCELLED)); + cancelPushTimeoutTask(); + } + + mPendingPushOperation = new PendingPushOperation(message, options, callback); + + // Schedule push timeout task for new #mPendingPushOperation. + schedulePushTimeoutTask(options); + enableReaderModeIfNeeded(); + processPendingPushOperation(); + } + + /** + * Cancels pending push operation. + * At the moment, only passive NFC devices are supported (NfcPushTarget.TAG). + * + * @param target @see NfcPushTarget + * @param callback that is used to notify caller when cancelPush() is completed. + */ + @Override + public void cancelPush(int target, CancelPushResponse callback) { + if (!checkIfReady(callback)) return; + + if (target == NfcPushTarget.PEER) { + callback.call(createError(NfcErrorType.NOT_SUPPORTED)); + return; + } + + if (mPendingPushOperation == null) { + callback.call(createError(NfcErrorType.NOT_FOUND)); + } else { + completePendingPushOperation(createError(NfcErrorType.OPERATION_CANCELLED)); + callback.call(null); + } + } + + /** + * Watch method allows to set filtering criteria for NfcMessages that are found when NFC device + * is within proximity. On success, watch ID is returned to caller through WatchResponse + * callback. When NfcMessage that matches NfcWatchOptions is found, it is passed to NfcClient + * interface together with corresponding watch ID. + * @see NfcClient#onWatch(int[] id, NfcMessage message) + * + * @param options used to filter NfcMessages, @see NfcWatchOptions. + * @param callback that is used to notify caller when watch() is completed and return watch ID. + */ + @Override + public void watch(NfcWatchOptions options, WatchResponse callback) { + if (!checkIfReady(callback)) return; + int watcherId = ++mWatcherId; + mWatchers.put(watcherId, options); + callback.call(watcherId, null); + enableReaderModeIfNeeded(); + processPendingWatchOperations(); + } + + /** + * Cancels NFC watch operation. + * + * @param id of watch operation. + * @param callback that is used to notify caller when cancelWatch() is completed. + */ + @Override + public void cancelWatch(int id, CancelWatchResponse callback) { + if (!checkIfReady(callback)) return; + + if (mWatchers.indexOfKey(id) < 0) { + callback.call(createError(NfcErrorType.NOT_FOUND)); + } else { + mWatchers.remove(id); + callback.call(null); + disableReaderModeIfNeeded(); + } + } + + /** + * Cancels all NFC watch operations. + * + * @param callback that is used to notify caller when cancelAllWatches() is completed. + */ + @Override + public void cancelAllWatches(CancelAllWatchesResponse callback) { + if (!checkIfReady(callback)) return; + + if (mWatchers.size() == 0) { + callback.call(createError(NfcErrorType.NOT_FOUND)); + } else { + mWatchers.clear(); + callback.call(null); + disableReaderModeIfNeeded(); + } + } + + /** + * Suspends all pending operations. Should be called when web page visibility is lost. + */ + @Override + public void suspendNfcOperations() { + disableReaderMode(); + } + + /** + * Resumes all pending watch / push operations. Should be called when web page becomes visible. + */ + @Override + public void resumeNfcOperations() { + enableReaderModeIfNeeded(); + } + + @Override + public void close() { + mDelegate.stopTrackingActivityForHost(mHostId); + disableReaderMode(); + } + + @Override + public void onConnectionError(MojoException e) { + close(); + } + + /** + * Holds information about pending push operation. + */ + private static class PendingPushOperation { + public final NfcMessage nfcMessage; + public final NfcPushOptions nfcPushOptions; + private final PushResponse mPushResponseCallback; + + public PendingPushOperation( + NfcMessage message, NfcPushOptions options, PushResponse callback) { + nfcMessage = message; + nfcPushOptions = options; + mPushResponseCallback = callback; + } + + /** + * Completes pending push operation. + * + * @param error should be null when operation is completed successfully, otherwise, + * error object with corresponding NfcErrorType should be provided. + */ + public void complete(NfcError error) { + if (mPushResponseCallback != null) mPushResponseCallback.call(error); + } + } + + /** + * Helper method that creates NfcError object from NfcErrorType. + */ + private NfcError createError(int errorType) { + NfcError error = new NfcError(); + error.errorType = errorType; + return error; + } + + /** + * Checks if NFC funcionality can be used by the mojo service. If permission to use NFC is + * granted and hardware is enabled, returns null. + */ + private NfcError checkIfReady() { + if (!mHasPermission || mActivity == null) { + return createError(NfcErrorType.SECURITY); + } else if (mNfcManager == null || mNfcAdapter == null) { + return createError(NfcErrorType.NOT_SUPPORTED); + } else if (!mNfcAdapter.isEnabled()) { + return createError(NfcErrorType.DEVICE_DISABLED); + } + return null; + } + + /** + * Uses checkIfReady() method and if NFC cannot be used, calls mojo callback with NfcError. + * + * @param WatchResponse Callback that is provided to watch() method. + * @return boolean true if NFC functionality can be used, false otherwise. + */ + private boolean checkIfReady(WatchResponse callback) { + NfcError error = checkIfReady(); + if (error == null) return true; + + callback.call(0, error); + return false; + } + + /** + * Uses checkIfReady() method and if NFC cannot be used, calls mojo callback with NfcError. + * + * @param callback Generic callback that is provided to push(), cancelPush(), + * cancelWatch() and cancelAllWatches() methods. + * @return boolean true if NFC functionality can be used, false otherwise. + */ + private boolean checkIfReady(Callbacks.Callback1<NfcError> callback) { + NfcError error = checkIfReady(); + if (error == null) return true; + + callback.call(error); + return false; + } + + /** + * Implementation of android.nfc.NfcAdapter.ReaderCallback. Callback is called when NFC tag is + * discovered, Tag object is delegated to mojo service implementation method + * NfcImpl.onTagDiscovered(). + */ + @TargetApi(Build.VERSION_CODES.KITKAT) + private static class ReaderCallbackHandler implements ReaderCallback { + private final NfcImpl mNfcImpl; + + public ReaderCallbackHandler(NfcImpl impl) { + mNfcImpl = impl; + } + + @Override + public void onTagDiscovered(Tag tag) { + mNfcImpl.onTagDiscovered(tag); + } + } + + /** + * Enables reader mode, allowing NFC device to read / write NFC tags. + * @see android.nfc.NfcAdapter#enableReaderMode + */ + private void enableReaderModeIfNeeded() { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.KITKAT) return; + + if (mReaderCallbackHandler != null || mActivity == null || mNfcAdapter == null) return; + + // Do not enable reader mode, if there are no active push / watch operations. + if (mPendingPushOperation == null && mWatchers.size() == 0) return; + + mReaderCallbackHandler = new ReaderCallbackHandler(this); + mNfcAdapter.enableReaderMode(mActivity, mReaderCallbackHandler, + NfcAdapter.FLAG_READER_NFC_A | NfcAdapter.FLAG_READER_NFC_B + | NfcAdapter.FLAG_READER_NFC_F | NfcAdapter.FLAG_READER_NFC_V, + null); + } + + /** + * Disables reader mode. + * @see android.nfc.NfcAdapter#disableReaderMode + */ + @TargetApi(Build.VERSION_CODES.KITKAT) + private void disableReaderMode() { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.KITKAT) return; + + // There is no API that could query whether reader mode is enabled for adapter. + // If mReaderCallbackHandler is null, reader mode is not enabled. + if (mReaderCallbackHandler == null) return; + + mReaderCallbackHandler = null; + if (mActivity != null && mNfcAdapter != null && !mActivity.isDestroyed()) { + mNfcAdapter.disableReaderMode(mActivity); + } + } + + /** + * Checks if there are pending push / watch operations and disables readre mode + * whenever necessary. + */ + private void disableReaderModeIfNeeded() { + if (mPendingPushOperation == null && mWatchers.size() == 0) { + disableReaderMode(); + } + } + + /** + * Handles completion of pending push operation, cancels timeout task and completes push + * operation. On error, invalidates #mTagHandler. + */ + private void pendingPushOperationCompleted(NfcError error) { + completePendingPushOperation(error); + if (error != null) mTagHandler = null; + } + + /** + * Completes pending push operation and disables reader mode if needed. + */ + private void completePendingPushOperation(NfcError error) { + if (mPendingPushOperation == null) return; + + cancelPushTimeoutTask(); + mPendingPushOperation.complete(error); + mPendingPushOperation = null; + disableReaderModeIfNeeded(); + } + + /** + * Checks whether there is a #mPendingPushOperation and writes data to NFC tag. In case of + * exception calls pendingPushOperationCompleted() with appropriate error object. + */ + private void processPendingPushOperation() { + if (mTagHandler == null || mPendingPushOperation == null) return; + + if (mTagHandler.isTagOutOfRange()) { + mTagHandler = null; + return; + } + + try { + mTagHandler.connect(); + mTagHandler.write(NfcTypeConverter.toNdefMessage(mPendingPushOperation.nfcMessage)); + pendingPushOperationCompleted(null); + } catch (InvalidNfcMessageException e) { + Log.w(TAG, "Cannot write data to NFC tag. Invalid NfcMessage."); + pendingPushOperationCompleted(createError(NfcErrorType.INVALID_MESSAGE)); + } catch (TagLostException e) { + Log.w(TAG, "Cannot write data to NFC tag. Tag is lost."); + pendingPushOperationCompleted(createError(NfcErrorType.IO_ERROR)); + } catch (FormatException | IllegalStateException | IOException e) { + Log.w(TAG, "Cannot write data to NFC tag. IO_ERROR."); + pendingPushOperationCompleted(createError(NfcErrorType.IO_ERROR)); + } + } + + /** + * Reads NfcMessage from a tag and forwards message to matching method. + */ + private void processPendingWatchOperations() { + if (mTagHandler == null || mClient == null || mWatchers.size() == 0) return; + + // Skip reading if there is a pending push operation and ignoreRead flag is set. + if (mPendingPushOperation != null && mPendingPushOperation.nfcPushOptions.ignoreRead) { + return; + } + + if (mTagHandler.isTagOutOfRange()) { + mTagHandler = null; + return; + } + + NdefMessage message = null; + + try { + mTagHandler.connect(); + message = mTagHandler.read(); + if (message.getByteArrayLength() > NfcMessage.MAX_SIZE) { + Log.w(TAG, "Cannot read data from NFC tag. NfcMessage exceeds allowed size."); + return; + } + } catch (TagLostException e) { + Log.w(TAG, "Cannot read data from NFC tag. Tag is lost."); + } catch (FormatException | IllegalStateException | IOException e) { + Log.w(TAG, "Cannot read data from NFC tag. IO_ERROR."); + } + + if (message != null) notifyMatchingWatchers(message); + } + + /** + * Iterates through active watchers and if any of those match NfcWatchOptions criteria, + * delivers NfcMessage to the client. + */ + private void notifyMatchingWatchers(NdefMessage message) { + try { + NfcMessage nfcMessage = NfcTypeConverter.toNfcMessage(message); + List<Integer> watchIds = new ArrayList<Integer>(); + for (int i = 0; i < mWatchers.size(); i++) { + NfcWatchOptions options = mWatchers.valueAt(i); + if (matchesWatchOptions(nfcMessage, options)) watchIds.add(mWatchers.keyAt(i)); + } + + if (watchIds.size() != 0) { + int[] ids = new int[watchIds.size()]; + for (int i = 0; i < watchIds.size(); ++i) { + ids[i] = watchIds.get(i).intValue(); + } + mClient.onWatch(ids, nfcMessage); + } + } catch (UnsupportedEncodingException e) { + Log.w(TAG, "Cannot convert NdefMessage to NfcMessage."); + } + } + + /** + * Implements matching algorithm. + */ + private boolean matchesWatchOptions(NfcMessage message, NfcWatchOptions options) { + // Valid WebNFC message must have non-empty url. + if (options.mode == NfcWatchMode.WEBNFC_ONLY + && (message.url == null || message.url.isEmpty())) { + return false; + } + + // Filter by WebNfc watch Id. + if (!matchesWebNfcId(message.url, options.url)) return false; + + // Matches any record / media type. + if ((options.mediaType == null || options.mediaType.isEmpty()) + && options.recordFilter == null) { + return true; + } + + // Filter by mediaType and recordType + for (int i = 0; i < message.data.length; i++) { + boolean matchedMediaType; + boolean matchedRecordType; + + if (options.mediaType == null || options.mediaType.isEmpty()) { + // If media type for the watch options is empty, match all media types. + matchedMediaType = true; + } else { + matchedMediaType = options.mediaType.equals(message.data[i].mediaType); + } + + if (options.recordFilter == null) { + // If record type filter for the watch options is null, match all record types. + matchedRecordType = true; + } else { + matchedRecordType = options.recordFilter.recordType == message.data[i].recordType; + } + + if (matchedMediaType && matchedRecordType) return true; + } + + return false; + } + + /** + * WebNfc Id match algorithm. + * https://w3c.github.io/web-nfc/#url-pattern-match-algorithm + */ + private boolean matchesWebNfcId(String id, String pattern) { + if (id != null && !id.isEmpty() && pattern != null && !pattern.isEmpty()) { + try { + URL id_url = new URL(id); + URL pattern_url = new URL(pattern); + + if (!id_url.getProtocol().equals(pattern_url.getProtocol())) return false; + if (!id_url.getHost().endsWith("." + pattern_url.getHost()) + && !id_url.getHost().equals(pattern_url.getHost())) { + return false; + } + if (pattern_url.getPath().equals(ANY_PATH)) return true; + if (id_url.getPath().startsWith(pattern_url.getPath())) return true; + return false; + + } catch (MalformedURLException e) { + return false; + } + } + + return true; + } + + /** + * Called by ReaderCallbackHandler when NFC tag is in proximity. + */ + public void onTagDiscovered(Tag tag) { + processPendingOperations(NfcTagHandler.create(tag)); + } + + /** + * Processes pending operation when NFC tag is in proximity. + */ + protected void processPendingOperations(NfcTagHandler tagHandler) { + mTagHandler = tagHandler; + processPendingWatchOperations(); + processPendingPushOperation(); + if (mTagHandler != null && mTagHandler.isConnected()) { + try { + mTagHandler.close(); + } catch (IOException e) { + Log.w(TAG, "Cannot close NFC tag connection."); + } + } + } + + /** + * Schedules task that is executed after timeout and cancels pending push operation. + */ + private void schedulePushTimeoutTask(NfcPushOptions options) { + assert mPushTimeoutRunnable == null; + // Default timeout value. + if (Double.isInfinite(options.timeout)) return; + + // Create and schedule timeout. + mPushTimeoutRunnable = new Runnable() { + @Override + public void run() { + completePendingPushOperation(createError(NfcErrorType.TIMER_EXPIRED)); + } + }; + + mPushTimeoutHandler.postDelayed(mPushTimeoutRunnable, (long) options.timeout); + } + + /** + * Cancels push timeout task. + */ + void cancelPushTimeoutTask() { + if (mPushTimeoutRunnable == null) return; + + mPushTimeoutHandler.removeCallbacks(mPushTimeoutRunnable); + mPushTimeoutRunnable = null; + } +} diff --git a/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/NfcMessageValidator.java b/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/NfcMessageValidator.java new file mode 100644 index 00000000000..bb73d38c5ec --- /dev/null +++ b/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/NfcMessageValidator.java @@ -0,0 +1,41 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.nfc; + +import org.chromium.device.mojom.NfcMessage; +import org.chromium.device.mojom.NfcRecord; +import org.chromium.device.mojom.NfcRecordType; + +/** + * Utility class that provides validation of NfcMessage. + */ +public final class NfcMessageValidator { + /** + * Validates NfcMessage. + * + * @param message to be validated. + * @return true if message is valid, false otherwise. + */ + public static boolean isValid(NfcMessage message) { + if (message == null || message.data == null || message.data.length == 0) { + return false; + } + + for (int i = 0; i < message.data.length; ++i) { + if (!isValid(message.data[i])) return false; + } + return true; + } + + /** + * Checks that NfcRecord#data and NfcRecord#mediaType fields are valid. NfcRecord#data and + * NfcRecord#mediaType fields are omitted for the record with EMPTY type. + */ + private static boolean isValid(NfcRecord record) { + if (record == null) return false; + if (record.recordType == NfcRecordType.EMPTY) return true; + return record.data != null && record.mediaType != null && !record.mediaType.isEmpty(); + } +} diff --git a/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/NfcProviderImpl.java b/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/NfcProviderImpl.java new file mode 100644 index 00000000000..f7f9e3ec684 --- /dev/null +++ b/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/NfcProviderImpl.java @@ -0,0 +1,50 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.nfc; + +import org.chromium.device.mojom.Nfc; +import org.chromium.device.mojom.NfcProvider; +import org.chromium.mojo.bindings.InterfaceRequest; +import org.chromium.mojo.system.MojoException; +import org.chromium.services.service_manager.InterfaceFactory; + +/** + * Android implementation of the NfcProvider Mojo interface. + */ +public class NfcProviderImpl implements NfcProvider { + private static final String TAG = "NfcProviderImpl"; + private NfcDelegate mDelegate; + + public NfcProviderImpl(NfcDelegate delegate) { + mDelegate = delegate; + } + + @Override + public void close() {} + + @Override + public void onConnectionError(MojoException e) {} + + @Override + public void getNfcForHost(int hostId, InterfaceRequest<Nfc> request) { + Nfc.MANAGER.bind(new NfcImpl(hostId, mDelegate), request); + } + + /** + * A factory for implementations of the NfcProvider interface. + */ + public static class Factory implements InterfaceFactory<NfcProvider> { + private NfcDelegate mDelegate; + + public Factory(NfcDelegate delegate) { + mDelegate = delegate; + } + + @Override + public NfcProvider createImpl() { + return new NfcProviderImpl(mDelegate); + } + } +} diff --git a/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/NfcTagHandler.java b/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/NfcTagHandler.java new file mode 100644 index 00000000000..9e8b9c465a9 --- /dev/null +++ b/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/NfcTagHandler.java @@ -0,0 +1,157 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.nfc; + +import android.nfc.FormatException; +import android.nfc.NdefMessage; +import android.nfc.Tag; +import android.nfc.TagLostException; +import android.nfc.tech.Ndef; +import android.nfc.tech.NdefFormatable; +import android.nfc.tech.TagTechnology; + +import java.io.IOException; + +/** + * Utility class that provides I/O operations for NFC tags. + */ +public class NfcTagHandler { + private final TagTechnology mTech; + private final TagTechnologyHandler mTechHandler; + private boolean mWasConnected; + + /** + * Factory method that creates NfcTagHandler for a given NFC Tag. + * + * @param tag @see android.nfc.Tag + * @return NfcTagHandler or null when unsupported Tag is provided. + */ + public static NfcTagHandler create(Tag tag) { + if (tag == null) return null; + + Ndef ndef = Ndef.get(tag); + if (ndef != null) return new NfcTagHandler(ndef, new NdefHandler(ndef)); + + NdefFormatable formattable = NdefFormatable.get(tag); + if (formattable != null) { + return new NfcTagHandler(formattable, new NdefFormattableHandler(formattable)); + } + + return null; + } + + /** + * NdefFormatable and Ndef interfaces have different signatures for operating with NFC tags. + * This interface provides generic methods. + */ + private interface TagTechnologyHandler { + public void write(NdefMessage message) + throws IOException, TagLostException, FormatException, IllegalStateException; + public NdefMessage read() + throws IOException, TagLostException, FormatException, IllegalStateException; + } + + /** + * Implementation of TagTechnologyHandler that uses Ndef tag technology. + * @see android.nfc.tech.Ndef + */ + private static class NdefHandler implements TagTechnologyHandler { + private final Ndef mNdef; + + NdefHandler(Ndef ndef) { + mNdef = ndef; + } + + @Override + public void write(NdefMessage message) + throws IOException, TagLostException, FormatException, IllegalStateException { + mNdef.writeNdefMessage(message); + } + + @Override + public NdefMessage read() + throws IOException, TagLostException, FormatException, IllegalStateException { + return mNdef.getNdefMessage(); + } + } + + /** + * Implementation of TagTechnologyHandler that uses NdefFormatable tag technology. + * @see android.nfc.tech.NdefFormatable + */ + private static class NdefFormattableHandler implements TagTechnologyHandler { + private final NdefFormatable mNdefFormattable; + + NdefFormattableHandler(NdefFormatable ndefFormattable) { + mNdefFormattable = ndefFormattable; + } + + @Override + public void write(NdefMessage message) + throws IOException, TagLostException, FormatException, IllegalStateException { + mNdefFormattable.format(message); + } + + @Override + public NdefMessage read() throws FormatException { + return NfcTypeConverter.emptyNdefMessage(); + } + } + + protected NfcTagHandler(TagTechnology tech, TagTechnologyHandler handler) { + mTech = tech; + mTechHandler = handler; + } + + /** + * Connects to NFC tag. + */ + public void connect() throws IOException, TagLostException { + if (!mTech.isConnected()) { + mTech.connect(); + mWasConnected = true; + } + } + + /** + * Checks if NFC tag is connected. + */ + public boolean isConnected() { + return mTech.isConnected(); + } + + /** + * Closes connection. + */ + public void close() throws IOException { + mTech.close(); + } + + /** + * Writes NdefMessage to NFC tag. + */ + public void write(NdefMessage message) + throws IOException, TagLostException, FormatException, IllegalStateException { + mTechHandler.write(message); + } + + public NdefMessage read() + throws IOException, TagLostException, FormatException, IllegalStateException { + return mTechHandler.read(); + } + + /** + * If tag was previously connected and subsequent connection to the same tag fails, consider + * tag to be out of range. + */ + public boolean isTagOutOfRange() { + try { + connect(); + } catch (IOException e) { + return mWasConnected; + } + return false; + } +} diff --git a/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/NfcTypeConverter.java b/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/NfcTypeConverter.java new file mode 100644 index 00000000000..f1001d82fe2 --- /dev/null +++ b/chromium/services/device/nfc/android/java/src/org/chromium/device/nfc/NfcTypeConverter.java @@ -0,0 +1,230 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.nfc; + +import android.net.Uri; +import android.nfc.NdefMessage; +import android.nfc.NdefRecord; +import android.os.Build; + +import org.chromium.base.ApiCompatibilityUtils; +import org.chromium.base.Log; +import org.chromium.device.mojom.NfcMessage; +import org.chromium.device.mojom.NfcRecord; +import org.chromium.device.mojom.NfcRecordType; + +import java.io.UnsupportedEncodingException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Utility class that provides convesion between Android NdefMessage + * and mojo NfcMessage data structures. + */ +public final class NfcTypeConverter { + private static final String TAG = "NfcTypeConverter"; + private static final String DOMAIN = "w3.org"; + private static final String TYPE = "webnfc"; + private static final String WEBNFC_URN = DOMAIN + ":" + TYPE; + private static final String TEXT_MIME = "text/plain"; + private static final String JSON_MIME = "application/json"; + private static final String CHARSET_UTF8 = ";charset=UTF-8"; + private static final String CHARSET_UTF16 = ";charset=UTF-16"; + + /** + * Converts mojo NfcMessage to android.nfc.NdefMessage + */ + public static NdefMessage toNdefMessage(NfcMessage message) throws InvalidNfcMessageException { + try { + List<NdefRecord> records = new ArrayList<NdefRecord>(); + for (int i = 0; i < message.data.length; ++i) { + records.add(toNdefRecord(message.data[i])); + } + records.add(NdefRecord.createExternal( + DOMAIN, TYPE, ApiCompatibilityUtils.getBytesUtf8(message.url))); + NdefRecord[] ndefRecords = new NdefRecord[records.size()]; + records.toArray(ndefRecords); + return new NdefMessage(ndefRecords); + } catch (UnsupportedEncodingException | InvalidNfcMessageException + | IllegalArgumentException e) { + throw new InvalidNfcMessageException(); + } + } + + /** + * Converts android.nfc.NdefMessage to mojo NfcMessage + */ + public static NfcMessage toNfcMessage(NdefMessage ndefMessage) + throws UnsupportedEncodingException { + NdefRecord[] ndefRecords = ndefMessage.getRecords(); + NfcMessage nfcMessage = new NfcMessage(); + List<NfcRecord> nfcRecords = new ArrayList<NfcRecord>(); + + for (int i = 0; i < ndefRecords.length; i++) { + if ((ndefRecords[i].getTnf() == NdefRecord.TNF_EXTERNAL_TYPE) + && (Arrays.equals(ndefRecords[i].getType(), + ApiCompatibilityUtils.getBytesUtf8(WEBNFC_URN)))) { + nfcMessage.url = new String(ndefRecords[i].getPayload(), "UTF-8"); + continue; + } + + NfcRecord nfcRecord = toNfcRecord(ndefRecords[i]); + if (nfcRecord != null) nfcRecords.add(nfcRecord); + } + + nfcMessage.data = new NfcRecord[nfcRecords.size()]; + nfcRecords.toArray(nfcMessage.data); + return nfcMessage; + } + + /** + * Returns charset of mojo NfcRecord. Only applicable for URL and TEXT records. + * If charset cannot be determined, UTF-8 charset is used by default. + */ + private static String getCharset(NfcRecord record) { + if (record.mediaType.endsWith(CHARSET_UTF8)) return "UTF-8"; + + // When 16bit WTF::String data is converted to bytearray, it is in LE byte order, without + // BOM. By default, Android interprets UTF-16 charset without BOM as UTF-16BE, thus, use + // UTF-16LE as encoding for text data. + + if (record.mediaType.endsWith(CHARSET_UTF16)) return "UTF-16LE"; + + Log.w(TAG, "Unknown charset, defaulting to UTF-8."); + return "UTF-8"; + } + + /** + * Converts mojo NfcRecord to android.nfc.NdefRecord + */ + private static NdefRecord toNdefRecord(NfcRecord record) throws InvalidNfcMessageException, + IllegalArgumentException, + UnsupportedEncodingException { + switch (record.recordType) { + case NfcRecordType.URL: + return NdefRecord.createUri(new String(record.data, getCharset(record))); + case NfcRecordType.TEXT: + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.LOLLIPOP) { + return NdefRecord.createTextRecord( + "en-US", new String(record.data, getCharset(record))); + } else { + return NdefRecord.createMime(TEXT_MIME, record.data); + } + case NfcRecordType.JSON: + case NfcRecordType.OPAQUE_RECORD: + return NdefRecord.createMime(record.mediaType, record.data); + case NfcRecordType.EMPTY: + return new NdefRecord(NdefRecord.TNF_EMPTY, null, null, null); + default: + throw new InvalidNfcMessageException(); + } + } + + /** + * Converts android.nfc.NdefRecord to mojo NfcRecord + */ + private static NfcRecord toNfcRecord(NdefRecord ndefRecord) + throws UnsupportedEncodingException { + switch (ndefRecord.getTnf()) { + case NdefRecord.TNF_EMPTY: + return createEmptyRecord(); + case NdefRecord.TNF_MIME_MEDIA: + return createMIMERecord( + new String(ndefRecord.getType(), "UTF-8"), ndefRecord.getPayload()); + case NdefRecord.TNF_ABSOLUTE_URI: + return createURLRecord(ndefRecord.toUri()); + case NdefRecord.TNF_WELL_KNOWN: + return createWellKnownRecord(ndefRecord); + } + return null; + } + + /** + * Constructs empty NdefMessage + */ + public static NdefMessage emptyNdefMessage() { + return new NdefMessage(new NdefRecord(NdefRecord.TNF_EMPTY, null, null, null)); + } + + /** + * Constructs empty NfcRecord + */ + private static NfcRecord createEmptyRecord() { + NfcRecord nfcRecord = new NfcRecord(); + nfcRecord.recordType = NfcRecordType.EMPTY; + nfcRecord.mediaType = ""; + nfcRecord.data = new byte[0]; + return nfcRecord; + } + + /** + * Constructs URL NfcRecord + */ + private static NfcRecord createURLRecord(Uri uri) { + if (uri == null) return null; + NfcRecord nfcRecord = new NfcRecord(); + nfcRecord.recordType = NfcRecordType.URL; + nfcRecord.mediaType = TEXT_MIME; + nfcRecord.data = ApiCompatibilityUtils.getBytesUtf8(uri.toString()); + return nfcRecord; + } + + /** + * Constructs MIME or JSON NfcRecord + */ + private static NfcRecord createMIMERecord(String mediaType, byte[] payload) { + NfcRecord nfcRecord = new NfcRecord(); + if (mediaType.equals(JSON_MIME)) { + nfcRecord.recordType = NfcRecordType.JSON; + } else { + nfcRecord.recordType = NfcRecordType.OPAQUE_RECORD; + } + nfcRecord.mediaType = mediaType; + nfcRecord.data = payload; + return nfcRecord; + } + + /** + * Constructs TEXT NfcRecord + */ + private static NfcRecord createTextRecord(byte[] text) { + // Check that text byte array is not empty. + if (text.length == 0) { + return null; + } + + NfcRecord nfcRecord = new NfcRecord(); + nfcRecord.recordType = NfcRecordType.TEXT; + nfcRecord.mediaType = TEXT_MIME; + // According to NFCForum-TS-RTD_Text_1.0 specification, section 3.2.1 Syntax. + // First byte of the payload is status byte, defined in Table 3: Status Byte Encodings. + // 0-5: lang code length + // 6 : must be zero + // 8 : 0 - text is in UTF-8 encoding, 1 - text is in UTF-16 encoding. + int langCodeLength = (text[0] & (byte) 0x3F); + int textBodyStartPos = langCodeLength + 1; + if (textBodyStartPos > text.length) { + return null; + } + nfcRecord.data = Arrays.copyOfRange(text, textBodyStartPos, text.length); + return nfcRecord; + } + + /** + * Constructs well known type (TEXT or URI) NfcRecord + */ + private static NfcRecord createWellKnownRecord(NdefRecord record) { + if (Arrays.equals(record.getType(), NdefRecord.RTD_URI)) { + return createURLRecord(record.toUri()); + } + + if (Arrays.equals(record.getType(), NdefRecord.RTD_TEXT)) { + return createTextRecord(record.getPayload()); + } + + return null; + } +} diff --git a/chromium/services/device/nfc/android/junit/src/org/chromium/device/nfc/NFCTest.java b/chromium/services/device/nfc/android/junit/src/org/chromium/device/nfc/NFCTest.java new file mode 100644 index 00000000000..d2104c688f6 --- /dev/null +++ b/chromium/services/device/nfc/android/junit/src/org/chromium/device/nfc/NFCTest.java @@ -0,0 +1,1124 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.nfc; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import android.app.Activity; +import android.content.Context; +import android.content.pm.PackageManager; +import android.nfc.FormatException; +import android.nfc.NdefMessage; +import android.nfc.NdefRecord; +import android.nfc.NfcAdapter; +import android.nfc.NfcAdapter.ReaderCallback; +import android.nfc.NfcManager; +import android.os.Bundle; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.robolectric.annotation.Config; +import org.robolectric.shadows.ShadowLooper; + +import org.chromium.base.ApiCompatibilityUtils; +import org.chromium.base.Callback; +import org.chromium.base.ContextUtils; +import org.chromium.base.test.util.Feature; +import org.chromium.device.mojom.Nfc.CancelAllWatchesResponse; +import org.chromium.device.mojom.Nfc.CancelPushResponse; +import org.chromium.device.mojom.Nfc.CancelWatchResponse; +import org.chromium.device.mojom.Nfc.PushResponse; +import org.chromium.device.mojom.Nfc.WatchResponse; +import org.chromium.device.mojom.NfcClient; +import org.chromium.device.mojom.NfcError; +import org.chromium.device.mojom.NfcErrorType; +import org.chromium.device.mojom.NfcMessage; +import org.chromium.device.mojom.NfcPushOptions; +import org.chromium.device.mojom.NfcPushTarget; +import org.chromium.device.mojom.NfcRecord; +import org.chromium.device.mojom.NfcRecordType; +import org.chromium.device.mojom.NfcRecordTypeFilter; +import org.chromium.device.mojom.NfcWatchMode; +import org.chromium.device.mojom.NfcWatchOptions; +import org.chromium.testing.local.LocalRobolectricTestRunner; + +import java.io.IOException; +import java.io.UnsupportedEncodingException; + +/** + * Unit tests for NfcImpl and NfcTypeConverter classes. + */ +@RunWith(LocalRobolectricTestRunner.class) +@Config(manifest = Config.NONE) +public class NFCTest { + private TestNfcDelegate mDelegate; + @Mock + private Context mContext; + @Mock + private NfcManager mNfcManager; + @Mock + private NfcAdapter mNfcAdapter; + @Mock + private Activity mActivity; + @Mock + private NfcClient mNfcClient; + @Mock + private NfcTagHandler mNfcTagHandler; + @Captor + private ArgumentCaptor<NfcError> mErrorCaptor; + @Captor + private ArgumentCaptor<Integer> mWatchCaptor; + @Captor + private ArgumentCaptor<int[]> mOnWatchCallbackCaptor; + + // Constants used for the test. + private static final String TEST_TEXT = "test"; + private static final String TEST_URL = "https://google.com"; + private static final String TEST_JSON = "{\"key1\":\"value1\",\"key2\":2}"; + private static final String DOMAIN = "w3.org"; + private static final String TYPE = "webnfc"; + private static final String TEXT_MIME = "text/plain"; + private static final String JSON_MIME = "application/json"; + private static final String CHARSET_UTF8 = ";charset=UTF-8"; + private static final String CHARSET_UTF16 = ";charset=UTF-16"; + private static final String LANG_EN_US = "en-US"; + + /** + * Class that is used test NfcImpl implementation + */ + private static class TestNfcImpl extends NfcImpl { + public TestNfcImpl(Context context, NfcDelegate delegate) { + super(0, delegate); + } + + public void processPendingOperationsForTesting(NfcTagHandler handler) { + super.processPendingOperations(handler); + } + } + + private static class TestNfcDelegate implements NfcDelegate { + Activity mActivity; + Callback<Activity> mCallback; + + public TestNfcDelegate(Activity activity) { + mActivity = activity; + } + @Override + public void trackActivityForHost(int hostId, Callback<Activity> callback) { + mCallback = callback; + } + + public void invokeCallback() { + mCallback.onResult(mActivity); + } + + @Override + public void stopTrackingActivityForHost(int hostId) {} + } + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + mDelegate = new TestNfcDelegate(mActivity); + doReturn(mNfcManager).when(mContext).getSystemService(Context.NFC_SERVICE); + doReturn(mNfcAdapter).when(mNfcManager).getDefaultAdapter(); + doReturn(true).when(mNfcAdapter).isEnabled(); + doReturn(PackageManager.PERMISSION_GRANTED) + .when(mContext) + .checkPermission(anyString(), anyInt(), anyInt()); + doNothing() + .when(mNfcAdapter) + .enableReaderMode(any(Activity.class), any(ReaderCallback.class), anyInt(), + (Bundle) isNull()); + doNothing().when(mNfcAdapter).disableReaderMode(any(Activity.class)); + // Tag handler overrides used to mock connected tag. + doReturn(true).when(mNfcTagHandler).isConnected(); + doReturn(false).when(mNfcTagHandler).isTagOutOfRange(); + try { + doNothing().when(mNfcTagHandler).connect(); + doNothing().when(mNfcTagHandler).write(any(NdefMessage.class)); + doReturn(createUrlWebNFCNdefMessage(TEST_URL)).when(mNfcTagHandler).read(); + doNothing().when(mNfcTagHandler).close(); + } catch (IOException | FormatException e) { + } + ContextUtils.initApplicationContextForTests(mContext); + } + + /** + * Test that error with type NOT_SUPPORTED is returned if NFC is not supported. + */ + @Test + @Feature({"NFCTest"}) + public void testNFCNotSupported() { + doReturn(null).when(mNfcManager).getDefaultAdapter(); + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + CancelAllWatchesResponse mockCallback = mock(CancelAllWatchesResponse.class); + nfc.cancelAllWatches(mockCallback); + verify(mockCallback).call(mErrorCaptor.capture()); + assertEquals(NfcErrorType.NOT_SUPPORTED, mErrorCaptor.getValue().errorType); + } + + /** + * Test that error with type SECURITY is returned if permission to use NFC is not granted. + */ + @Test + @Feature({"NFCTest"}) + public void testNFCIsNotPermitted() { + doReturn(PackageManager.PERMISSION_DENIED) + .when(mContext) + .checkPermission(anyString(), anyInt(), anyInt()); + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + CancelAllWatchesResponse mockCallback = mock(CancelAllWatchesResponse.class); + nfc.cancelAllWatches(mockCallback); + verify(mockCallback).call(mErrorCaptor.capture()); + assertEquals(NfcErrorType.SECURITY, mErrorCaptor.getValue().errorType); + } + + /** + * Test that method can be invoked successfully if NFC is supported and adapter is enabled. + */ + @Test + @Feature({"NFCTest"}) + public void testNFCIsSupported() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + WatchResponse mockCallback = mock(WatchResponse.class); + nfc.watch(createNfcWatchOptions(), mockCallback); + verify(mockCallback).call(anyInt(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + } + + /** + * Test conversion from NdefMessage to mojo NfcMessage. + */ + @Test + @Feature({"NFCTest"}) + public void testNdefToMojoConversion() throws UnsupportedEncodingException { + // Test EMPTY record conversion. + NdefMessage emptyNdefMessage = + new NdefMessage(new NdefRecord(NdefRecord.TNF_EMPTY, null, null, null)); + NfcMessage emptyNfcMessage = NfcTypeConverter.toNfcMessage(emptyNdefMessage); + assertNull(emptyNfcMessage.url); + assertEquals(1, emptyNfcMessage.data.length); + assertEquals(NfcRecordType.EMPTY, emptyNfcMessage.data[0].recordType); + assertEquals(true, emptyNfcMessage.data[0].mediaType.isEmpty()); + assertEquals(0, emptyNfcMessage.data[0].data.length); + + // Test URL record conversion. + NdefMessage urlNdefMessage = new NdefMessage(NdefRecord.createUri(TEST_URL)); + NfcMessage urlNfcMessage = NfcTypeConverter.toNfcMessage(urlNdefMessage); + assertNull(urlNfcMessage.url); + assertEquals(1, urlNfcMessage.data.length); + assertEquals(NfcRecordType.URL, urlNfcMessage.data[0].recordType); + assertEquals(TEXT_MIME, urlNfcMessage.data[0].mediaType); + assertEquals(TEST_URL, new String(urlNfcMessage.data[0].data)); + + // Test TEXT record conversion. + NdefMessage textNdefMessage = + new NdefMessage(NdefRecord.createTextRecord(LANG_EN_US, TEST_TEXT)); + NfcMessage textNfcMessage = NfcTypeConverter.toNfcMessage(textNdefMessage); + assertNull(textNfcMessage.url); + assertEquals(1, textNfcMessage.data.length); + assertEquals(NfcRecordType.TEXT, textNfcMessage.data[0].recordType); + assertEquals(TEXT_MIME, textNfcMessage.data[0].mediaType); + assertEquals(TEST_TEXT, new String(textNfcMessage.data[0].data)); + + // Test MIME record conversion. + NdefMessage mimeNdefMessage = new NdefMessage( + NdefRecord.createMime(TEXT_MIME, ApiCompatibilityUtils.getBytesUtf8(TEST_TEXT))); + NfcMessage mimeNfcMessage = NfcTypeConverter.toNfcMessage(mimeNdefMessage); + assertNull(mimeNfcMessage.url); + assertEquals(1, mimeNfcMessage.data.length); + assertEquals(NfcRecordType.OPAQUE_RECORD, mimeNfcMessage.data[0].recordType); + assertEquals(TEXT_MIME, textNfcMessage.data[0].mediaType); + assertEquals(TEST_TEXT, new String(textNfcMessage.data[0].data)); + + // Test JSON record conversion. + NdefMessage jsonNdefMessage = new NdefMessage( + NdefRecord.createMime(JSON_MIME, ApiCompatibilityUtils.getBytesUtf8(TEST_JSON))); + NfcMessage jsonNfcMessage = NfcTypeConverter.toNfcMessage(jsonNdefMessage); + assertNull(jsonNfcMessage.url); + assertEquals(1, jsonNfcMessage.data.length); + assertEquals(NfcRecordType.JSON, jsonNfcMessage.data[0].recordType); + assertEquals(JSON_MIME, jsonNfcMessage.data[0].mediaType); + assertEquals(TEST_JSON, new String(jsonNfcMessage.data[0].data)); + + // Test NfcMessage with WebNFC external type. + NdefRecord jsonNdefRecord = + NdefRecord.createMime(JSON_MIME, ApiCompatibilityUtils.getBytesUtf8(TEST_JSON)); + NdefRecord extNdefRecord = NdefRecord.createExternal( + DOMAIN, TYPE, ApiCompatibilityUtils.getBytesUtf8(TEST_URL)); + NdefMessage webNdefMessage = new NdefMessage(jsonNdefRecord, extNdefRecord); + NfcMessage webNfcMessage = NfcTypeConverter.toNfcMessage(webNdefMessage); + assertEquals(TEST_URL, webNfcMessage.url); + assertEquals(1, webNfcMessage.data.length); + assertEquals(NfcRecordType.JSON, webNfcMessage.data[0].recordType); + assertEquals(JSON_MIME, webNfcMessage.data[0].mediaType); + assertEquals(TEST_JSON, new String(webNfcMessage.data[0].data)); + } + + /** + * Test conversion from mojo NfcMessage to android NdefMessage. + */ + @Test + @Feature({"NFCTest"}) + public void testMojoToNdefConversion() throws InvalidNfcMessageException { + // Test URL record conversion. + NdefMessage urlNdefMessage = createUrlWebNFCNdefMessage(TEST_URL); + assertEquals(2, urlNdefMessage.getRecords().length); + assertEquals(NdefRecord.TNF_WELL_KNOWN, urlNdefMessage.getRecords()[0].getTnf()); + assertEquals(TEST_URL, urlNdefMessage.getRecords()[0].toUri().toString()); + assertEquals(NdefRecord.TNF_EXTERNAL_TYPE, urlNdefMessage.getRecords()[1].getTnf()); + assertEquals(DOMAIN + ":" + TYPE, new String(urlNdefMessage.getRecords()[1].getType())); + + // Test TEXT record conversion. + NfcRecord textNfcRecord = new NfcRecord(); + textNfcRecord.recordType = NfcRecordType.TEXT; + textNfcRecord.mediaType = TEXT_MIME; + textNfcRecord.data = ApiCompatibilityUtils.getBytesUtf8(TEST_TEXT); + NfcMessage textNfcMessage = createNfcMessage(TEST_URL, textNfcRecord); + NdefMessage textNdefMessage = NfcTypeConverter.toNdefMessage(textNfcMessage); + assertEquals(2, textNdefMessage.getRecords().length); + short tnf = textNdefMessage.getRecords()[0].getTnf(); + boolean isWellKnownOrMime = + (tnf == NdefRecord.TNF_WELL_KNOWN || tnf == NdefRecord.TNF_MIME_MEDIA); + assertEquals(true, isWellKnownOrMime); + assertEquals(NdefRecord.TNF_EXTERNAL_TYPE, textNdefMessage.getRecords()[1].getTnf()); + + // Test MIME record conversion. + NfcRecord mimeNfcRecord = new NfcRecord(); + mimeNfcRecord.recordType = NfcRecordType.OPAQUE_RECORD; + mimeNfcRecord.mediaType = TEXT_MIME; + mimeNfcRecord.data = ApiCompatibilityUtils.getBytesUtf8(TEST_TEXT); + NfcMessage mimeNfcMessage = createNfcMessage(TEST_URL, mimeNfcRecord); + NdefMessage mimeNdefMessage = NfcTypeConverter.toNdefMessage(mimeNfcMessage); + assertEquals(2, mimeNdefMessage.getRecords().length); + assertEquals(NdefRecord.TNF_MIME_MEDIA, mimeNdefMessage.getRecords()[0].getTnf()); + assertEquals(TEXT_MIME, mimeNdefMessage.getRecords()[0].toMimeType()); + assertEquals(TEST_TEXT, new String(mimeNdefMessage.getRecords()[0].getPayload())); + assertEquals(NdefRecord.TNF_EXTERNAL_TYPE, mimeNdefMessage.getRecords()[1].getTnf()); + + // Test JSON record conversion. + NfcRecord jsonNfcRecord = new NfcRecord(); + jsonNfcRecord.recordType = NfcRecordType.OPAQUE_RECORD; + jsonNfcRecord.mediaType = JSON_MIME; + jsonNfcRecord.data = ApiCompatibilityUtils.getBytesUtf8(TEST_JSON); + NfcMessage jsonNfcMessage = createNfcMessage(TEST_URL, jsonNfcRecord); + NdefMessage jsonNdefMessage = NfcTypeConverter.toNdefMessage(jsonNfcMessage); + assertEquals(2, jsonNdefMessage.getRecords().length); + assertEquals(NdefRecord.TNF_MIME_MEDIA, jsonNdefMessage.getRecords()[0].getTnf()); + assertEquals(JSON_MIME, jsonNdefMessage.getRecords()[0].toMimeType()); + assertEquals(TEST_JSON, new String(jsonNdefMessage.getRecords()[0].getPayload())); + assertEquals(NdefRecord.TNF_EXTERNAL_TYPE, jsonNdefMessage.getRecords()[1].getTnf()); + + // Test EMPTY record conversion. + NfcRecord emptyNfcRecord = new NfcRecord(); + emptyNfcRecord.recordType = NfcRecordType.EMPTY; + NfcMessage emptyNfcMessage = createNfcMessage(TEST_URL, emptyNfcRecord); + NdefMessage emptyNdefMessage = NfcTypeConverter.toNdefMessage(emptyNfcMessage); + assertEquals(2, emptyNdefMessage.getRecords().length); + assertEquals(NdefRecord.TNF_EMPTY, emptyNdefMessage.getRecords()[0].getTnf()); + assertEquals(NdefRecord.TNF_EXTERNAL_TYPE, emptyNdefMessage.getRecords()[1].getTnf()); + } + + /** + * Test that invalid NfcMessage is rejected with INVALID_MESSAGE error code. + */ + @Test + @Feature({"NFCTest"}) + public void testInvalidNfcMessage() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + PushResponse mockCallback = mock(PushResponse.class); + nfc.push(new NfcMessage(), createNfcPushOptions(), mockCallback); + nfc.processPendingOperationsForTesting(mNfcTagHandler); + verify(mockCallback).call(mErrorCaptor.capture()); + assertEquals(NfcErrorType.INVALID_MESSAGE, mErrorCaptor.getValue().errorType); + } + + /** + * Test that Nfc.suspendNfcOperations() and Nfc.resumeNfcOperations() work correctly. + */ + @Test + @Feature({"NFCTest"}) + public void testResumeSuspend() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + // No activity / client or active pending operations + nfc.suspendNfcOperations(); + nfc.resumeNfcOperations(); + + mDelegate.invokeCallback(); + nfc.setClient(mNfcClient); + WatchResponse mockCallback = mock(WatchResponse.class); + nfc.watch(createNfcWatchOptions(), mockCallback); + nfc.suspendNfcOperations(); + verify(mNfcAdapter, times(1)).disableReaderMode(mActivity); + nfc.resumeNfcOperations(); + // 1. Enable after watch is called, 2. after resumeNfcOperations is called. + verify(mNfcAdapter, times(2)) + .enableReaderMode(any(Activity.class), any(ReaderCallback.class), anyInt(), + (Bundle) isNull()); + + nfc.processPendingOperationsForTesting(mNfcTagHandler); + // Check that watch request was completed successfully. + verify(mockCallback).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + + // Check that client was notified and watch with correct id was triggered. + verify(mNfcClient, times(1)) + .onWatch(mOnWatchCallbackCaptor.capture(), any(NfcMessage.class)); + assertEquals(mWatchCaptor.getValue().intValue(), mOnWatchCallbackCaptor.getValue()[0]); + } + + /** + * Test that Nfc.push() successful when NFC tag is connected. + */ + @Test + @Feature({"NFCTest"}) + public void testPush() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + PushResponse mockCallback = mock(PushResponse.class); + nfc.push(createNfcMessage(), createNfcPushOptions(), mockCallback); + nfc.processPendingOperationsForTesting(mNfcTagHandler); + verify(mockCallback).call(mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + } + + /** + * Test that Nfc.cancelPush() cancels pending push opration and completes successfully. + */ + @Test + @Feature({"NFCTest"}) + public void testCancelPush() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + PushResponse mockPushCallback = mock(PushResponse.class); + CancelPushResponse mockCancelPushCallback = mock(CancelPushResponse.class); + nfc.push(createNfcMessage(), createNfcPushOptions(), mockPushCallback); + nfc.cancelPush(NfcPushTarget.ANY, mockCancelPushCallback); + + // Check that push request was cancelled with OPERATION_CANCELLED. + verify(mockPushCallback).call(mErrorCaptor.capture()); + assertEquals(NfcErrorType.OPERATION_CANCELLED, mErrorCaptor.getValue().errorType); + + // Check that cancel request was successfuly completed. + verify(mockCancelPushCallback).call(mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + } + + /** + * Test that Nfc.watch() works correctly and client is notified. + */ + @Test + @Feature({"NFCTest"}) + public void testWatch() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + nfc.setClient(mNfcClient); + WatchResponse mockWatchCallback1 = mock(WatchResponse.class); + nfc.watch(createNfcWatchOptions(), mockWatchCallback1); + + // Check that watch requests were completed successfully. + verify(mockWatchCallback1).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + int watchId1 = mWatchCaptor.getValue().intValue(); + + WatchResponse mockWatchCallback2 = mock(WatchResponse.class); + nfc.watch(createNfcWatchOptions(), mockWatchCallback2); + verify(mockWatchCallback2).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + int watchId2 = mWatchCaptor.getValue().intValue(); + // Check that each watch operation is associated with unique id. + assertNotEquals(watchId1, watchId2); + + // Mocks 'NFC tag found' event. + nfc.processPendingOperationsForTesting(mNfcTagHandler); + + // Check that client was notified and correct watch ids were provided. + verify(mNfcClient, times(1)) + .onWatch(mOnWatchCallbackCaptor.capture(), any(NfcMessage.class)); + assertEquals(watchId1, mOnWatchCallbackCaptor.getValue()[0]); + assertEquals(watchId2, mOnWatchCallbackCaptor.getValue()[1]); + } + + /** + * Test that Nfc.watch() matching function works correctly. + */ + @Test + @Feature({"NFCTest"}) + public void testWatchMatching() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + nfc.setClient(mNfcClient); + + // Should match by WebNFC Id (exact match). + NfcWatchOptions options1 = createNfcWatchOptions(); + options1.mode = NfcWatchMode.WEBNFC_ONLY; + options1.url = TEST_URL; + WatchResponse mockWatchCallback1 = mock(WatchResponse.class); + nfc.watch(options1, mockWatchCallback1); + verify(mockWatchCallback1).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + int watchId1 = mWatchCaptor.getValue().intValue(); + + // Should match by media type. + NfcWatchOptions options2 = createNfcWatchOptions(); + options2.mode = NfcWatchMode.ANY; + options2.mediaType = TEXT_MIME; + WatchResponse mockWatchCallback2 = mock(WatchResponse.class); + nfc.watch(options2, mockWatchCallback2); + verify(mockWatchCallback2).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + int watchId2 = mWatchCaptor.getValue().intValue(); + + // Should match by record type. + NfcWatchOptions options3 = createNfcWatchOptions(); + options3.mode = NfcWatchMode.ANY; + NfcRecordTypeFilter typeFilter = new NfcRecordTypeFilter(); + typeFilter.recordType = NfcRecordType.URL; + options3.recordFilter = typeFilter; + WatchResponse mockWatchCallback3 = mock(WatchResponse.class); + nfc.watch(options3, mockWatchCallback3); + verify(mockWatchCallback3).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + int watchId3 = mWatchCaptor.getValue().intValue(); + + // Should not match + NfcWatchOptions options4 = createNfcWatchOptions(); + options4.mode = NfcWatchMode.WEBNFC_ONLY; + options4.url = DOMAIN; + WatchResponse mockWatchCallback4 = mock(WatchResponse.class); + nfc.watch(options4, mockWatchCallback4); + verify(mockWatchCallback4).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + int watchId4 = mWatchCaptor.getValue().intValue(); + + nfc.processPendingOperationsForTesting(mNfcTagHandler); + + // Check that client was notified and watch with correct id was triggered. + verify(mNfcClient, times(1)) + .onWatch(mOnWatchCallbackCaptor.capture(), any(NfcMessage.class)); + assertEquals(3, mOnWatchCallbackCaptor.getValue().length); + assertEquals(watchId1, mOnWatchCallbackCaptor.getValue()[0]); + assertEquals(watchId2, mOnWatchCallbackCaptor.getValue()[1]); + assertEquals(watchId3, mOnWatchCallbackCaptor.getValue()[2]); + } + + /** + * Test that Nfc.watch() can be cancelled with Nfc.cancelWatch(). + */ + @Test + @Feature({"NFCTest"}) + public void testCancelWatch() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + WatchResponse mockWatchCallback = mock(WatchResponse.class); + nfc.watch(createNfcWatchOptions(), mockWatchCallback); + + verify(mockWatchCallback).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + + CancelWatchResponse mockCancelWatchCallback = mock(CancelWatchResponse.class); + nfc.cancelWatch(mWatchCaptor.getValue().intValue(), mockCancelWatchCallback); + + // Check that cancel request was successfuly completed. + verify(mockCancelWatchCallback).call(mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + + // Check that watch is not triggered when NFC tag is in proximity. + nfc.processPendingOperationsForTesting(mNfcTagHandler); + verify(mNfcClient, times(0)).onWatch(any(int[].class), any(NfcMessage.class)); + } + + /** + * Test that Nfc.cancelAllWatches() cancels all pending watch operations. + */ + @Test + @Feature({"NFCTest"}) + public void testCancelAllWatches() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + WatchResponse mockWatchCallback1 = mock(WatchResponse.class); + WatchResponse mockWatchCallback2 = mock(WatchResponse.class); + nfc.watch(createNfcWatchOptions(), mockWatchCallback1); + verify(mockWatchCallback1).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + + nfc.watch(createNfcWatchOptions(), mockWatchCallback2); + verify(mockWatchCallback2).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + + CancelAllWatchesResponse mockCallback = mock(CancelAllWatchesResponse.class); + nfc.cancelAllWatches(mockCallback); + + // Check that cancel request was successfuly completed. + verify(mockCallback).call(mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + } + + /** + * Test that Nfc.cancelWatch() with invalid id is failing with NOT_FOUND error. + */ + @Test + @Feature({"NFCTest"}) + public void testCancelWatchInvalidId() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + WatchResponse mockWatchCallback = mock(WatchResponse.class); + nfc.watch(createNfcWatchOptions(), mockWatchCallback); + + verify(mockWatchCallback).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + + CancelWatchResponse mockCancelWatchCallback = mock(CancelWatchResponse.class); + nfc.cancelWatch(mWatchCaptor.getValue().intValue() + 1, mockCancelWatchCallback); + + verify(mockCancelWatchCallback).call(mErrorCaptor.capture()); + assertEquals(NfcErrorType.NOT_FOUND, mErrorCaptor.getValue().errorType); + } + + /** + * Test that Nfc.cancelAllWatches() is failing with NOT_FOUND error if there are no active + * watch opeartions. + */ + @Test + @Feature({"NFCTest"}) + public void testCancelAllWatchesWithNoWathcers() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + CancelAllWatchesResponse mockCallback = mock(CancelAllWatchesResponse.class); + nfc.cancelAllWatches(mockCallback); + + verify(mockCallback).call(mErrorCaptor.capture()); + assertEquals(NfcErrorType.NOT_FOUND, mErrorCaptor.getValue().errorType); + } + + /** + * Test that when tag is disconnected during read operation, IllegalStateException is handled. + */ + @Test + @Feature({"NFCTest"}) + public void testTagDisconnectedDuringRead() throws IOException, FormatException { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + nfc.setClient(mNfcClient); + WatchResponse mockWatchCallback = mock(WatchResponse.class); + nfc.watch(createNfcWatchOptions(), mockWatchCallback); + + // Force read operation to fail + doThrow(IllegalStateException.class).when(mNfcTagHandler).read(); + + // Mocks 'NFC tag found' event. + nfc.processPendingOperationsForTesting(mNfcTagHandler); + + // Check that client was not notified. + verify(mNfcClient, times(0)) + .onWatch(mOnWatchCallbackCaptor.capture(), any(NfcMessage.class)); + } + + /** + * Test that when tag is disconnected during write operation, IllegalStateException is handled. + */ + @Test + @Feature({"NFCTest"}) + public void testTagDisconnectedDuringWrite() throws IOException, FormatException { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + PushResponse mockCallback = mock(PushResponse.class); + + // Force write operation to fail + doThrow(IllegalStateException.class).when(mNfcTagHandler).write(any(NdefMessage.class)); + nfc.push(createNfcMessage(), createNfcPushOptions(), mockCallback); + nfc.processPendingOperationsForTesting(mNfcTagHandler); + verify(mockCallback).call(mErrorCaptor.capture()); + + // Test that correct error is returned. + assertNotNull(mErrorCaptor.getValue()); + assertEquals(NfcErrorType.IO_ERROR, mErrorCaptor.getValue().errorType); + } + + /** + * Test that push operation completes with TIMER_EXPIRED error when operation times-out. + */ + @Test + @Feature({"NFCTest"}) + public void testPushTimeout() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + PushResponse mockCallback = mock(PushResponse.class); + + // Set 1 millisecond timeout. + nfc.push(createNfcMessage(), createNfcPushOptions(1), mockCallback); + + // Wait for timeout. + ShadowLooper.runUiThreadTasksIncludingDelayedTasks(); + + // Test that correct error is returned. + verify(mockCallback).call(mErrorCaptor.capture()); + assertNotNull(mErrorCaptor.getValue()); + assertEquals(NfcErrorType.TIMER_EXPIRED, mErrorCaptor.getValue().errorType); + } + + /** + * Test that multiple Nfc.push() invocations do not disable reader mode. + */ + @Test + @Feature({"NFCTest"}) + public void testPushMultipleInvocations() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + + PushResponse mockCallback1 = mock(PushResponse.class); + PushResponse mockCallback2 = mock(PushResponse.class); + nfc.push(createNfcMessage(), createNfcPushOptions(), mockCallback1); + nfc.push(createNfcMessage(), createNfcPushOptions(), mockCallback2); + + verify(mNfcAdapter, times(1)) + .enableReaderMode(any(Activity.class), any(ReaderCallback.class), anyInt(), + (Bundle) isNull()); + verify(mNfcAdapter, times(0)).disableReaderMode(mActivity); + + verify(mockCallback1).call(mErrorCaptor.capture()); + assertNotNull(mErrorCaptor.getValue()); + assertEquals(NfcErrorType.OPERATION_CANCELLED, mErrorCaptor.getValue().errorType); + } + + /** + * Test that reader mode is disabled after push operation timeout is expired. + */ + @Test + @Feature({"NFCTest"}) + public void testReaderModeIsDisabledAfterTimeout() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + PushResponse mockCallback = mock(PushResponse.class); + + // Set 1 millisecond timeout. + nfc.push(createNfcMessage(), createNfcPushOptions(1), mockCallback); + + verify(mNfcAdapter, times(1)) + .enableReaderMode(any(Activity.class), any(ReaderCallback.class), anyInt(), + (Bundle) isNull()); + + // Wait for timeout. + ShadowLooper.runUiThreadTasksIncludingDelayedTasks(); + + // Reader mode is disabled + verify(mNfcAdapter, times(1)).disableReaderMode(mActivity); + + // Test that correct error is returned. + verify(mockCallback).call(mErrorCaptor.capture()); + assertNotNull(mErrorCaptor.getValue()); + assertEquals(NfcErrorType.TIMER_EXPIRED, mErrorCaptor.getValue().errorType); + } + + /** + * Test that reader mode is disabled and two push operations are cancelled with correct + * error code. + */ + @Test + @Feature({"NFCTest"}) + public void testTwoPushInvocationsWithCancel() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + + PushResponse mockCallback1 = mock(PushResponse.class); + + // First push without timeout, must be completed with OPERATION_CANCELLED. + nfc.push(createNfcMessage(), createNfcPushOptions(), mockCallback1); + + PushResponse mockCallback2 = mock(PushResponse.class); + + // Second push with 1 millisecond timeout, should be cancelled before timer expires, + // thus, operation must be completed with OPERATION_CANCELLED. + nfc.push(createNfcMessage(), createNfcPushOptions(1), mockCallback2); + + verify(mNfcAdapter, times(1)) + .enableReaderMode(any(Activity.class), any(ReaderCallback.class), anyInt(), + (Bundle) isNull()); + verify(mockCallback1).call(mErrorCaptor.capture()); + assertNotNull(mErrorCaptor.getValue()); + assertEquals(NfcErrorType.OPERATION_CANCELLED, mErrorCaptor.getValue().errorType); + + CancelPushResponse mockCancelPushCallback = mock(CancelPushResponse.class); + nfc.cancelPush(NfcPushTarget.ANY, mockCancelPushCallback); + + // Reader mode is disabled after cancelPush is invoked. + verify(mNfcAdapter, times(1)).disableReaderMode(mActivity); + + // Wait for delayed tasks to complete. + ShadowLooper.runUiThreadTasksIncludingDelayedTasks(); + + // The disableReaderMode is not called after delayed tasks are processed. + verify(mNfcAdapter, times(1)).disableReaderMode(mActivity); + + // Test that correct error is returned. + verify(mockCallback2).call(mErrorCaptor.capture()); + assertNotNull(mErrorCaptor.getValue()); + assertEquals(NfcErrorType.OPERATION_CANCELLED, mErrorCaptor.getValue().errorType); + } + + /** + * Test that reader mode is disabled and two push operations with timeout are cancelled + * with correct error codes. + */ + @Test + @Feature({"NFCTest"}) + public void testTwoPushInvocationsWithTimeout() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + + PushResponse mockCallback1 = mock(PushResponse.class); + + // First push without timeout, must be completed with OPERATION_CANCELLED. + nfc.push(createNfcMessage(), createNfcPushOptions(1), mockCallback1); + + PushResponse mockCallback2 = mock(PushResponse.class); + + // Second push with 1 millisecond timeout, should be cancelled with TIMER_EXPIRED. + nfc.push(createNfcMessage(), createNfcPushOptions(1), mockCallback2); + + verify(mNfcAdapter, times(1)) + .enableReaderMode(any(Activity.class), any(ReaderCallback.class), anyInt(), + (Bundle) isNull()); + + // Reader mode enabled for the duration of timeout. + verify(mNfcAdapter, times(0)).disableReaderMode(mActivity); + + verify(mockCallback1).call(mErrorCaptor.capture()); + assertNotNull(mErrorCaptor.getValue()); + assertEquals(NfcErrorType.OPERATION_CANCELLED, mErrorCaptor.getValue().errorType); + + // Wait for delayed tasks to complete. + ShadowLooper.runUiThreadTasksIncludingDelayedTasks(); + + // Reader mode is disabled + verify(mNfcAdapter, times(1)).disableReaderMode(mActivity); + + // Test that correct error is returned for second push. + verify(mockCallback2).call(mErrorCaptor.capture()); + assertNotNull(mErrorCaptor.getValue()); + assertEquals(NfcErrorType.TIMER_EXPIRED, mErrorCaptor.getValue().errorType); + } + + /** + * Test that reader mode is not disabled when there is an active watch operation and push + * operation timer is expired. + */ + @Test + @Feature({"NFCTest"}) + public void testTimeoutDontDisableReaderMode() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + WatchResponse mockWatchCallback = mock(WatchResponse.class); + nfc.watch(createNfcWatchOptions(), mockWatchCallback); + + PushResponse mockPushCallback = mock(PushResponse.class); + // Should be cancelled with TIMER_EXPIRED. + nfc.push(createNfcMessage(), createNfcPushOptions(1), mockPushCallback); + + verify(mNfcAdapter, times(1)) + .enableReaderMode(any(Activity.class), any(ReaderCallback.class), anyInt(), + (Bundle) isNull()); + + // Wait for delayed tasks to complete. + ShadowLooper.runUiThreadTasksIncludingDelayedTasks(); + + // Push was cancelled with TIMER_EXPIRED. + verify(mockPushCallback).call(mErrorCaptor.capture()); + assertNotNull(mErrorCaptor.getValue()); + assertEquals(NfcErrorType.TIMER_EXPIRED, mErrorCaptor.getValue().errorType); + + verify(mNfcAdapter, times(0)).disableReaderMode(mActivity); + + CancelAllWatchesResponse mockCancelCallback = mock(CancelAllWatchesResponse.class); + nfc.cancelAllWatches(mockCancelCallback); + + // Check that cancel request was successfuly completed. + verify(mockCancelCallback).call(mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + + // Reader mode is disabled when there are no pending push / watch operations. + verify(mNfcAdapter, times(1)).disableReaderMode(mActivity); + } + + /** + * Test timeout overflow and negative timeout + */ + @Test + @Feature({"NFCTest"}) + public void testInvalidPushOptions() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + PushResponse mockCallback = mock(PushResponse.class); + + // Long overflow + nfc.push(createNfcMessage(), createNfcPushOptions(Long.MAX_VALUE + 1), mockCallback); + + verify(mockCallback).call(mErrorCaptor.capture()); + assertNotNull(mErrorCaptor.getValue()); + assertEquals(NfcErrorType.NOT_SUPPORTED, mErrorCaptor.getValue().errorType); + + // Test negative timeout + PushResponse mockCallback2 = mock(PushResponse.class); + nfc.push(createNfcMessage(), createNfcPushOptions(-1), mockCallback2); + + verify(mockCallback2).call(mErrorCaptor.capture()); + assertNotNull(mErrorCaptor.getValue()); + assertEquals(NfcErrorType.NOT_SUPPORTED, mErrorCaptor.getValue().errorType); + } + + /** + * Test that Nfc.watch() WebNFC Id pattern matching works correctly. + */ + @Test + @Feature({"NFCTest"}) + public void testWatchPatternMatching() throws IOException, FormatException { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + nfc.setClient(mNfcClient); + + // Should match. + { + NfcWatchOptions options = createNfcWatchOptions(); + options.mode = NfcWatchMode.WEBNFC_ONLY; + options.url = "https://test.com/*"; + WatchResponse mockWatchCallback = mock(WatchResponse.class); + nfc.watch(options, mockWatchCallback); + verify(mockWatchCallback).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + } + int watchId1 = mWatchCaptor.getValue().intValue(); + + // Should match. + { + NfcWatchOptions options = createNfcWatchOptions(); + options.mode = NfcWatchMode.WEBNFC_ONLY; + options.url = "https://test.com/contact/42"; + WatchResponse mockWatchCallback = mock(WatchResponse.class); + nfc.watch(options, mockWatchCallback); + verify(mockWatchCallback).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + } + int watchId2 = mWatchCaptor.getValue().intValue(); + + // Should match. + { + NfcWatchOptions options = createNfcWatchOptions(); + options.mode = NfcWatchMode.WEBNFC_ONLY; + options.url = "https://subdomain.test.com/*"; + WatchResponse mockWatchCallback = mock(WatchResponse.class); + nfc.watch(options, mockWatchCallback); + verify(mockWatchCallback).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + } + int watchId3 = mWatchCaptor.getValue().intValue(); + + // Should match. + { + NfcWatchOptions options = createNfcWatchOptions(); + options.mode = NfcWatchMode.WEBNFC_ONLY; + options.url = "https://subdomain.test.com/contact"; + WatchResponse mockWatchCallback = mock(WatchResponse.class); + nfc.watch(options, mockWatchCallback); + verify(mockWatchCallback).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + } + int watchId4 = mWatchCaptor.getValue().intValue(); + + // Should not match. + { + NfcWatchOptions options = createNfcWatchOptions(); + options.mode = NfcWatchMode.WEBNFC_ONLY; + options.url = "https://www.test.com/*"; + WatchResponse mockWatchCallback = mock(WatchResponse.class); + nfc.watch(options, mockWatchCallback); + verify(mockWatchCallback).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + } + + // Should not match. + { + NfcWatchOptions options = createNfcWatchOptions(); + options.mode = NfcWatchMode.WEBNFC_ONLY; + options.url = "http://test.com/*"; + WatchResponse mockWatchCallback = mock(WatchResponse.class); + nfc.watch(options, mockWatchCallback); + verify(mockWatchCallback).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + } + + // Should not match. + { + NfcWatchOptions options = createNfcWatchOptions(); + options.mode = NfcWatchMode.WEBNFC_ONLY; + options.url = "invalid pattern url"; + WatchResponse mockWatchCallback = mock(WatchResponse.class); + nfc.watch(options, mockWatchCallback); + verify(mockWatchCallback).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + } + + doReturn(createUrlWebNFCNdefMessage("https://subdomain.test.com/contact/42")) + .when(mNfcTagHandler) + .read(); + nfc.processPendingOperationsForTesting(mNfcTagHandler); + + // None of the watches should match NFCMessage with this WebNFC Id. + doReturn(createUrlWebNFCNdefMessage("https://notest.com/foo")).when(mNfcTagHandler).read(); + nfc.processPendingOperationsForTesting(mNfcTagHandler); + + // Check that client was notified and watch with correct id was triggered. + verify(mNfcClient, times(1)) + .onWatch(mOnWatchCallbackCaptor.capture(), any(NfcMessage.class)); + assertEquals(4, mOnWatchCallbackCaptor.getValue().length); + assertEquals(watchId1, mOnWatchCallbackCaptor.getValue()[0]); + assertEquals(watchId2, mOnWatchCallbackCaptor.getValue()[1]); + assertEquals(watchId3, mOnWatchCallbackCaptor.getValue()[2]); + assertEquals(watchId4, mOnWatchCallbackCaptor.getValue()[3]); + } + + /** + * Test that Nfc.watch() WebNFC Id pattern matching works correctly for invalid WebNFC Ids. + */ + @Test + @Feature({"NFCTest"}) + public void testWatchPatternMatchingInvalidId() throws IOException, FormatException { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + nfc.setClient(mNfcClient); + + // Should not match when invalid WebNFC Id is received. + NfcWatchOptions options = createNfcWatchOptions(); + options.mode = NfcWatchMode.WEBNFC_ONLY; + options.url = "https://test.com/*"; + WatchResponse mockWatchCallback = mock(WatchResponse.class); + nfc.watch(options, mockWatchCallback); + verify(mockWatchCallback).call(mWatchCaptor.capture(), mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + + doReturn(createUrlWebNFCNdefMessage("http://subdomain.test.com/contact/42")) + .when(mNfcTagHandler) + .read(); + nfc.processPendingOperationsForTesting(mNfcTagHandler); + + doReturn(createUrlWebNFCNdefMessage("ftp://subdomain.test.com/contact/42")) + .when(mNfcTagHandler) + .read(); + nfc.processPendingOperationsForTesting(mNfcTagHandler); + + doReturn(createUrlWebNFCNdefMessage("invalid url")).when(mNfcTagHandler).read(); + nfc.processPendingOperationsForTesting(mNfcTagHandler); + + verify(mNfcClient, times(0)) + .onWatch(mOnWatchCallbackCaptor.capture(), any(NfcMessage.class)); + } + + /** + * Test that Nfc.push() succeeds for NFC messages with EMPTY records. + */ + @Test + @Feature({"NFCTest"}) + public void testPushWithEmptyRecord() { + TestNfcImpl nfc = new TestNfcImpl(mContext, mDelegate); + mDelegate.invokeCallback(); + PushResponse mockCallback = mock(PushResponse.class); + + // Create message with empty record. + NfcRecord emptyNfcRecord = new NfcRecord(); + emptyNfcRecord.recordType = NfcRecordType.EMPTY; + NfcMessage nfcMessage = createNfcMessage(TEST_URL, emptyNfcRecord); + + nfc.push(nfcMessage, createNfcPushOptions(), mockCallback); + nfc.processPendingOperationsForTesting(mNfcTagHandler); + verify(mockCallback).call(mErrorCaptor.capture()); + assertNull(mErrorCaptor.getValue()); + } + + /** + * Creates NfcPushOptions with default values. + */ + private NfcPushOptions createNfcPushOptions() { + NfcPushOptions pushOptions = new NfcPushOptions(); + pushOptions.target = NfcPushTarget.ANY; + pushOptions.timeout = Double.POSITIVE_INFINITY; + pushOptions.ignoreRead = false; + return pushOptions; + } + + /** + * Creates NfcPushOptions with specified timeout. + */ + private NfcPushOptions createNfcPushOptions(double timeout) { + NfcPushOptions pushOptions = new NfcPushOptions(); + pushOptions.target = NfcPushTarget.ANY; + pushOptions.timeout = timeout; + pushOptions.ignoreRead = false; + return pushOptions; + } + + private NfcWatchOptions createNfcWatchOptions() { + NfcWatchOptions options = new NfcWatchOptions(); + options.url = ""; + options.mediaType = ""; + options.mode = NfcWatchMode.ANY; + options.recordFilter = null; + return options; + } + + private NfcMessage createNfcMessage() { + NfcMessage message = new NfcMessage(); + message.url = ""; + message.data = new NfcRecord[1]; + + NfcRecord nfcRecord = new NfcRecord(); + nfcRecord.recordType = NfcRecordType.TEXT; + nfcRecord.mediaType = TEXT_MIME; + nfcRecord.data = ApiCompatibilityUtils.getBytesUtf8(TEST_TEXT); + message.data[0] = nfcRecord; + return message; + } + + private NfcMessage createNfcMessage(String url, NfcRecord record) { + NfcMessage message = new NfcMessage(); + message.url = url; + message.data = new NfcRecord[1]; + message.data[0] = record; + return message; + } + + private NdefMessage createUrlWebNFCNdefMessage(String webNfcId) { + NfcRecord urlNfcRecord = new NfcRecord(); + urlNfcRecord.recordType = NfcRecordType.URL; + urlNfcRecord.mediaType = TEXT_MIME; + urlNfcRecord.data = ApiCompatibilityUtils.getBytesUtf8(TEST_URL); + NfcMessage urlNfcMessage = createNfcMessage(webNfcId, urlNfcRecord); + try { + return NfcTypeConverter.toNdefMessage(urlNfcMessage); + } catch (InvalidNfcMessageException e) { + return null; + } + } +} diff --git a/chromium/services/device/public/cpp/BUILD.gn b/chromium/services/device/public/cpp/BUILD.gn index 22453cbe98d..abc3fdd6d5f 100644 --- a/chromium/services/device/public/cpp/BUILD.gn +++ b/chromium/services/device/public/cpp/BUILD.gn @@ -12,6 +12,7 @@ component("device_features") { "device_features.cc", "device_features_export.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] public_deps = [ "//base", ] diff --git a/chromium/services/device/public/cpp/generic_sensor/sensor_reading.cc b/chromium/services/device/public/cpp/generic_sensor/sensor_reading.cc index 9d9b5cf676e..bc98c33ccf7 100644 --- a/chromium/services/device/public/cpp/generic_sensor/sensor_reading.cc +++ b/chromium/services/device/public/cpp/generic_sensor/sensor_reading.cc @@ -40,9 +40,7 @@ SensorReading& SensorReading::operator=(const SensorReading& other) { // static uint64_t SensorReadingSharedBuffer::GetOffset(mojom::SensorType type) { - return (static_cast<uint64_t>(mojom::SensorType::LAST) - - static_cast<uint64_t>(type)) * - sizeof(SensorReadingSharedBuffer); + return static_cast<uint64_t>(type) * sizeof(SensorReadingSharedBuffer); } } // namespace device diff --git a/chromium/services/device/public/cpp/generic_sensor/sensor_traits.cc b/chromium/services/device/public/cpp/generic_sensor/sensor_traits.cc index 849b1dcdbe7..a191142c47b 100644 --- a/chromium/services/device/public/cpp/generic_sensor/sensor_traits.cc +++ b/chromium/services/device/public/cpp/generic_sensor/sensor_traits.cc @@ -40,7 +40,7 @@ double GetSensorMaxAllowedFrequency(SensorType type) { // No default so the compiler will warn us if a new type is added. } NOTREACHED() << "Unknown sensor type " << type; - return SensorTraits<SensorType::LAST>::kMaxAllowedFrequency; + return SensorTraits<SensorType::kMaxValue>::kMaxAllowedFrequency; } double GetSensorDefaultFrequency(mojom::SensorType type) { @@ -74,7 +74,7 @@ double GetSensorDefaultFrequency(mojom::SensorType type) { // No default so the compiler will warn us if a new type is added. } NOTREACHED() << "Unknown sensor type " << type; - return SensorTraits<SensorType::LAST>::kDefaultFrequency; + return SensorTraits<SensorType::kMaxValue>::kDefaultFrequency; } } // namespace device diff --git a/chromium/services/device/public/java/src/org/chromium/device/geolocation/LocationProvider.java b/chromium/services/device/public/java/src/org/chromium/device/geolocation/LocationProvider.java new file mode 100644 index 00000000000..4c6d0bf9d4f --- /dev/null +++ b/chromium/services/device/public/java/src/org/chromium/device/geolocation/LocationProvider.java @@ -0,0 +1,27 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.geolocation; + +/** + * LocationProvider interface. + */ +public interface LocationProvider { + /** + * Start listening for location updates. Calling several times before stop() is interpreted + * as restart. + * @param enableHighAccuracy Whether or not to enable high accuracy location. + */ + public void start(boolean enableHighAccuracy); + + /** + * Stop listening for location updates. + */ + public void stop(); + + /** + * Returns true if we are currently listening for location updates, false if not. + */ + public boolean isRunning(); +} diff --git a/chromium/services/device/public/java/src/org/chromium/device/geolocation/LocationProviderOverrider.java b/chromium/services/device/public/java/src/org/chromium/device/geolocation/LocationProviderOverrider.java new file mode 100644 index 00000000000..a1cfab1c1f7 --- /dev/null +++ b/chromium/services/device/public/java/src/org/chromium/device/geolocation/LocationProviderOverrider.java @@ -0,0 +1,16 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.geolocation; + +/** + * Set the MockLocationProvider to LocationProviderFactory. Used for test only. + */ +final public class LocationProviderOverrider { + public static void setLocationProviderImpl(LocationProvider provider) { + LocationProviderFactory.setLocationProviderImpl(provider); + } + + private LocationProviderOverrider() {} +}; diff --git a/chromium/services/device/public/java/src/org/chromium/device/geolocation/MockLocationProvider.java b/chromium/services/device/public/java/src/org/chromium/device/geolocation/MockLocationProvider.java new file mode 100644 index 00000000000..6b85d437be9 --- /dev/null +++ b/chromium/services/device/public/java/src/org/chromium/device/geolocation/MockLocationProvider.java @@ -0,0 +1,86 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.geolocation; + +import android.location.Location; +import android.os.Handler; +import android.os.HandlerThread; +import android.os.Message; + +/** + * A mock location provider. When started, runs a background thread that periodically + * posts location updates. This does not involve any system Location APIs and thus + * does not require any special permissions in the test app or on the device. + */ +public class MockLocationProvider implements LocationProvider { + private boolean mIsRunning; + private Handler mHandler; + private HandlerThread mHandlerThread; + private final Object mLock = new Object(); + + private static final int UPDATE_LOCATION_MSG = 100; + + public MockLocationProvider() {} + + public void stopUpdates() { + if (mHandlerThread != null) { + mHandlerThread.quit(); + } + } + + @Override + public void start(boolean enableHighAccuracy) { + if (mIsRunning) return; + + if (mHandlerThread == null) { + startMockLocationProviderThread(); + } + + mIsRunning = true; + synchronized (mLock) { + mHandler.sendEmptyMessage(UPDATE_LOCATION_MSG); + } + } + + @Override + public void stop() { + if (!mIsRunning) return; + mIsRunning = false; + synchronized (mLock) { + mHandler.removeMessages(UPDATE_LOCATION_MSG); + } + } + + @Override + public boolean isRunning() { + return mIsRunning; + } + + private void startMockLocationProviderThread() { + assert mHandlerThread == null; + assert mHandler == null; + + mHandlerThread = new HandlerThread("MockLocationProviderImpl"); + mHandlerThread.start(); + mHandler = new Handler(mHandlerThread.getLooper()) { + @Override + public void handleMessage(Message msg) { + synchronized (mLock) { + if (msg.what == UPDATE_LOCATION_MSG) { + newLocation(); + sendEmptyMessageDelayed(UPDATE_LOCATION_MSG, 250); + } + } + } + }; + } + + private void newLocation() { + Location location = new Location("MockLocationProvider"); + location.setTime(System.currentTimeMillis()); + location.setAccuracy(0.5f); + LocationProviderAdapter.onNewLocationAvailable(location); + } +}; diff --git a/chromium/services/device/public/java/src/org/chromium/device/nfc/NfcDelegate.java b/chromium/services/device/public/java/src/org/chromium/device/nfc/NfcDelegate.java new file mode 100644 index 00000000000..9b144dc3d4c --- /dev/null +++ b/chromium/services/device/public/java/src/org/chromium/device/nfc/NfcDelegate.java @@ -0,0 +1,24 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.nfc; + +import android.app.Activity; + +import org.chromium.base.Callback; + +/** Interface that allows the NFC implementation to access the Activity associated with a given + * client. |hostId| is the same ID passed in NFCProvider::GetNFCForHost(). + */ +public interface NfcDelegate { + /** Calls |callback| with the Activity associated with |hostId|, and subsequently calls + * |callback| again whenever the Activity associated with |hostId| changes. + */ + void trackActivityForHost(int hostId, Callback<Activity> callback); + + /** Called when the NFC implementation no longer needs to track the Activity associated with + * |hostId|. + */ + void stopTrackingActivityForHost(int hostId); +} diff --git a/chromium/services/device/public/mojom/BUILD.gn b/chromium/services/device/public/mojom/BUILD.gn index 212fb1be442..4a786dbf0ba 100644 --- a/chromium/services/device/public/mojom/BUILD.gn +++ b/chromium/services/device/public/mojom/BUILD.gn @@ -9,6 +9,7 @@ mojom("mojom") { sources = [ "battery_monitor.mojom", "battery_status.mojom", + "bluetooth_system.mojom", "fingerprint.mojom", "geolocation.mojom", "geolocation_config.mojom", diff --git a/chromium/services/device/public/mojom/bluetooth_system.mojom b/chromium/services/device/public/mojom/bluetooth_system.mojom new file mode 100644 index 00000000000..bc522525b28 --- /dev/null +++ b/chromium/services/device/public/mojom/bluetooth_system.mojom @@ -0,0 +1,44 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +module device.mojom; + +// Factory to get an instance of the BluetoothSystem interface. +interface BluetoothSystemFactory { + Create(BluetoothSystem& system, BluetoothSystemClient system_client); +}; + +// High level interface targeted towards UI level components that: +// - Show the BT Radio state and allow users to change it. +// - Show a list of nearby, connected and paired BT Devices. +// - Start and stop BT scans. +// - Connect to and pair with BT devices. +// +// This interface is implemented only on Chrome OS and lives in the Device +// Service. +interface BluetoothSystem { + + // State of Bluetooth. + enum State { + // The platform does not support Bluetooth. + kUnsupported, + // The platform supports Bluetooth but we can’t use it right now e.g. a BT + // radio is not present. + kUnavailable, + // Bluetooth Radio is off. + kPoweredOff, + // State is transitioning between PoweredOff and PoweredOn or vice versa. + kTransitioning, + // Bluetooth Radio is on. + kPoweredOn, + }; + + GetState() => (State state); +}; + +// Interface used by clients of BluetoothSystem to get notified of events +// like Bluetooth State changes. +interface BluetoothSystemClient { + OnStateChanged(BluetoothSystem.State new_state); +}; diff --git a/chromium/services/device/public/mojom/sensor.mojom b/chromium/services/device/public/mojom/sensor.mojom index 19e9fe8eed7..ed8bd83d09e 100644 --- a/chromium/services/device/public/mojom/sensor.mojom +++ b/chromium/services/device/public/mojom/sensor.mojom @@ -8,8 +8,7 @@ module device.mojom; // When adding new sensor type, update the documentation of sensor data values // in SensorReading struct at sensor_reading.h file. enum SensorType { - FIRST = 1, - AMBIENT_LIGHT = FIRST, + AMBIENT_LIGHT, PROXIMITY, ACCELEROMETER, LINEAR_ACCELERATION, @@ -24,7 +23,6 @@ enum SensorType { RELATIVE_ORIENTATION_EULER_ANGLES, // Recommended for new code. RELATIVE_ORIENTATION_QUATERNION, - LAST = RELATIVE_ORIENTATION_QUATERNION // Note: LAST is also equal to the types count. }; // Reporting mode supported by the Sensor. diff --git a/chromium/services/device/screen_orientation/android/java/src/org/chromium/device/screen_orientation/ScreenOrientationListener.java b/chromium/services/device/screen_orientation/android/java/src/org/chromium/device/screen_orientation/ScreenOrientationListener.java new file mode 100644 index 00000000000..f6f9c9a3c30 --- /dev/null +++ b/chromium/services/device/screen_orientation/android/java/src/org/chromium/device/screen_orientation/ScreenOrientationListener.java @@ -0,0 +1,46 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.screen_orientation; + +import android.provider.Settings; + +import org.chromium.base.ContextUtils; +import org.chromium.base.ThreadUtils; +import org.chromium.base.annotations.CalledByNative; +import org.chromium.base.annotations.JNINamespace; +import org.chromium.ui.display.DisplayAndroid; + +/** + * Android implementation details for device::ScreenOrientationListenerAndroid. + */ +@JNINamespace("device") +class ScreenOrientationListener { + @CalledByNative + static void startAccurateListening() { + ThreadUtils.runOnUiThread(new Runnable() { + @Override + public void run() { + DisplayAndroid.startAccurateListening(); + } + }); + } + + @CalledByNative + static void stopAccurateListening() { + ThreadUtils.runOnUiThread(new Runnable() { + @Override + public void run() { + DisplayAndroid.stopAccurateListening(); + } + }); + } + + @CalledByNative + static boolean isAutoRotateEnabledByUser() { + return Settings.System.getInt(ContextUtils.getApplicationContext().getContentResolver(), + Settings.System.ACCELEROMETER_ROTATION, 0) + == 1; + } +} diff --git a/chromium/services/device/time_zone_monitor/android/java/src/org/chromium/device/time_zone_monitor/TimeZoneMonitor.java b/chromium/services/device/time_zone_monitor/android/java/src/org/chromium/device/time_zone_monitor/TimeZoneMonitor.java new file mode 100644 index 00000000000..f927cd8195f --- /dev/null +++ b/chromium/services/device/time_zone_monitor/android/java/src/org/chromium/device/time_zone_monitor/TimeZoneMonitor.java @@ -0,0 +1,67 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.time_zone_monitor; + +import android.content.BroadcastReceiver; +import android.content.Context; +import android.content.Intent; +import android.content.IntentFilter; + +import org.chromium.base.ContextUtils; +import org.chromium.base.Log; +import org.chromium.base.annotations.CalledByNative; +import org.chromium.base.annotations.JNINamespace; + +/** + * Android implementation details for device::TimeZoneMonitorAndroid. + */ +@JNINamespace("device") +class TimeZoneMonitor { + private static final String TAG = "cr_TimeZoneMonitor"; + + private final IntentFilter mFilter = new IntentFilter(Intent.ACTION_TIMEZONE_CHANGED); + private final BroadcastReceiver mBroadcastReceiver = new BroadcastReceiver() { + @Override + public void onReceive(Context context, Intent intent) { + if (!intent.getAction().equals(Intent.ACTION_TIMEZONE_CHANGED)) { + Log.e(TAG, "unexpected intent"); + return; + } + + nativeTimeZoneChangedFromJava(mNativePtr); + } + }; + + private long mNativePtr; + + /** + * Start listening for intents. + * @param nativePtr The native device::TimeZoneMonitorAndroid to notify of time zone changes. + */ + private TimeZoneMonitor(long nativePtr) { + mNativePtr = nativePtr; + ContextUtils.getApplicationContext().registerReceiver(mBroadcastReceiver, mFilter); + } + + @CalledByNative + static TimeZoneMonitor getInstance(long nativePtr) { + return new TimeZoneMonitor(nativePtr); + } + + /** + * Stop listening for intents. + */ + @CalledByNative + void stop() { + ContextUtils.getApplicationContext().unregisterReceiver(mBroadcastReceiver); + mNativePtr = 0; + } + + /** + * Native JNI call to device::TimeZoneMonitorAndroid::TimeZoneChanged. + * See device/time_zone_monitor/time_zone_monitor_android.cc. + */ + private native void nativeTimeZoneChangedFromJava(long nativeTimeZoneMonitorAndroid); +} diff --git a/chromium/services/device/unittest_manifest.json b/chromium/services/device/unittest_manifest.json index b6889ea27fb..ce96a0ee9a1 100644 --- a/chromium/services/device/unittest_manifest.json +++ b/chromium/services/device/unittest_manifest.json @@ -10,6 +10,7 @@ }, "requires": { "device": [ + "device:bluetooth_system", "device:battery_monitor", "device:generic_sensor", "device:geolocation_config", diff --git a/chromium/services/device/vibration/android/java/src/org/chromium/device/vibration/VibrationManagerImpl.java b/chromium/services/device/vibration/android/java/src/org/chromium/device/vibration/VibrationManagerImpl.java new file mode 100644 index 00000000000..0c5be4df7a7 --- /dev/null +++ b/chromium/services/device/vibration/android/java/src/org/chromium/device/vibration/VibrationManagerImpl.java @@ -0,0 +1,110 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.vibration; + +import android.content.Context; +import android.content.pm.PackageManager; +import android.media.AudioManager; +import android.os.Vibrator; + +import org.chromium.base.ContextUtils; +import org.chromium.base.Log; +import org.chromium.base.annotations.CalledByNative; +import org.chromium.base.annotations.JNINamespace; +import org.chromium.device.mojom.VibrationManager; +import org.chromium.mojo.system.MojoException; +import org.chromium.services.service_manager.InterfaceFactory; + +/** + * Android implementation of the VibrationManager interface defined in + * services/device/public/mojom/vibration_manager.mojom. + */ +@JNINamespace("device") +public class VibrationManagerImpl implements VibrationManager { + private static final String TAG = "VibrationManagerImpl"; + + private static final long MINIMUM_VIBRATION_DURATION_MS = 1; // 1 millisecond + private static final long MAXIMUM_VIBRATION_DURATION_MS = 10000; // 10 seconds + + private final AudioManager mAudioManager; + private final Vibrator mVibrator; + private final boolean mHasVibratePermission; + + private static long sVibrateMilliSecondsForTesting = -1; + private static boolean sVibrateCancelledForTesting; + + public VibrationManagerImpl() { + Context appContext = ContextUtils.getApplicationContext(); + mAudioManager = (AudioManager) appContext.getSystemService(Context.AUDIO_SERVICE); + mVibrator = (Vibrator) appContext.getSystemService(Context.VIBRATOR_SERVICE); + // TODO(mvanouwerkerk): What happens if permission is revoked? Handle this better. + mHasVibratePermission = + appContext.checkCallingOrSelfPermission(android.Manifest.permission.VIBRATE) + == PackageManager.PERMISSION_GRANTED; + if (!mHasVibratePermission) { + Log.w(TAG, "Failed to use vibrate API, requires VIBRATE permission."); + } + } + + @Override + public void close() {} + + @Override + public void onConnectionError(MojoException e) {} + + @Override + public void vibrate(long milliseconds, VibrateResponse callback) { + // Though the Blink implementation already sanitizes vibration times, don't + // trust any values passed from the client. + long sanitizedMilliseconds = Math.max(MINIMUM_VIBRATION_DURATION_MS, + Math.min(milliseconds, MAXIMUM_VIBRATION_DURATION_MS)); + + if (mAudioManager.getRingerMode() != AudioManager.RINGER_MODE_SILENT + && mHasVibratePermission) { + mVibrator.vibrate(sanitizedMilliseconds); + } + setVibrateMilliSecondsForTesting(sanitizedMilliseconds); + callback.call(); + } + + @Override + public void cancel(CancelResponse callback) { + if (mHasVibratePermission) { + mVibrator.cancel(); + } + setVibrateCancelledForTesting(true); + callback.call(); + } + + /** + * A factory for implementations of the VibrationManager interface. + */ + public static class Factory implements InterfaceFactory<VibrationManager> { + public Factory() {} + + @Override + public VibrationManager createImpl() { + return new VibrationManagerImpl(); + } + } + + static void setVibrateMilliSecondsForTesting(long milliseconds) { + sVibrateMilliSecondsForTesting = milliseconds; + } + + static void setVibrateCancelledForTesting(boolean cancelled) { + sVibrateCancelledForTesting = cancelled; + } + + @CalledByNative + static long getVibrateMilliSecondsForTesting() { + return sVibrateMilliSecondsForTesting; + } + + @CalledByNative + static boolean getVibrateCancelledForTesting() { + return sVibrateCancelledForTesting; + } +} diff --git a/chromium/services/device/wake_lock/power_save_blocker/android/java/src/org/chromium/device/power_save_blocker/PowerSaveBlocker.java b/chromium/services/device/wake_lock/power_save_blocker/android/java/src/org/chromium/device/power_save_blocker/PowerSaveBlocker.java new file mode 100644 index 00000000000..f83cb43fa2a --- /dev/null +++ b/chromium/services/device/wake_lock/power_save_blocker/android/java/src/org/chromium/device/power_save_blocker/PowerSaveBlocker.java @@ -0,0 +1,44 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.device.power_save_blocker; + +import android.view.View; + +import org.chromium.base.annotations.CalledByNative; +import org.chromium.base.annotations.JNINamespace; + +import java.lang.ref.WeakReference; + +@JNINamespace("device") +class PowerSaveBlocker { + // WeakReference to prevent leaks in Android WebView. + private WeakReference<View> mKeepScreenOnView; + + @CalledByNative + private static PowerSaveBlocker create() { + return new PowerSaveBlocker(); + } + + private PowerSaveBlocker() {} + + @CalledByNative + private void applyBlock(View view) { + assert mKeepScreenOnView == null; + mKeepScreenOnView = new WeakReference<>(view); + view.setKeepScreenOn(true); + } + + @CalledByNative + private void removeBlock() { + // mKeepScreenOnView may be null since it's possible that |applyBlock()| was + // not invoked due to having failed to get a view to call |setKeepScrenOn| on. + if (mKeepScreenOnView == null) return; + View view = mKeepScreenOnView.get(); + mKeepScreenOnView = null; + if (view == null) return; + + view.setKeepScreenOn(false); + } +} diff --git a/chromium/services/file/BUILD.gn b/chromium/services/file/BUILD.gn index 2fe2ea25f06..ff7ff53e148 100644 --- a/chromium/services/file/BUILD.gn +++ b/chromium/services/file/BUILD.gn @@ -15,6 +15,8 @@ source_set("lib") { "user_id_map.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ "//base", "//components/services/filesystem:lib", diff --git a/chromium/services/identity/BUILD.gn b/chromium/services/identity/BUILD.gn index 045e94b4b43..938c3bf42ec 100644 --- a/chromium/services/identity/BUILD.gn +++ b/chromium/services/identity/BUILD.gn @@ -14,6 +14,8 @@ source_set("lib") { "identity_service.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ "//base", "//components/signin/core/browser", diff --git a/chromium/services/identity/identity_manager_impl_unittest.cc b/chromium/services/identity/identity_manager_impl_unittest.cc index 0e9bb8923ac..5f14f535987 100644 --- a/chromium/services/identity/identity_manager_impl_unittest.cc +++ b/chromium/services/identity/identity_manager_impl_unittest.cc @@ -101,6 +101,7 @@ class IdentityManagerImplTest : public service_manager::test::ServiceTest { nullptr) { #endif AccountTrackerService::RegisterPrefs(pref_service_.registry()); + ProfileOAuth2TokenService::RegisterProfilePrefs(pref_service_.registry()); SigninManagerBase::RegisterProfilePrefs(pref_service_.registry()); SigninManagerBase::RegisterPrefs(pref_service_.registry()); diff --git a/chromium/services/identity/public/cpp/BUILD.gn b/chromium/services/identity/public/cpp/BUILD.gn index 0207855b2fc..da86b1dffe4 100644 --- a/chromium/services/identity/public/cpp/BUILD.gn +++ b/chromium/services/identity/public/cpp/BUILD.gn @@ -14,6 +14,8 @@ source_set("cpp") { "primary_account_access_token_fetcher.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + public_deps = [ "//components/signin/core/browser", ] @@ -29,6 +31,8 @@ source_set("cpp_types") { "account_state.h", "scope_set.h", ] + + configs += [ "//build/config/compiler:wexit_time_destructors" ] } source_set("test_support") { diff --git a/chromium/services/identity/public/cpp/DEPS b/chromium/services/identity/public/cpp/DEPS index 6efe08bc7b4..a972c566bfd 100644 --- a/chromium/services/identity/public/cpp/DEPS +++ b/chromium/services/identity/public/cpp/DEPS @@ -8,5 +8,6 @@ include_rules = [ "+components/signin/core/browser/signin_switches.h", "+google_apis/gaia/gaia_auth_util.h", "+google_apis/gaia/google_service_auth_error.h", + "+google_apis/gaia/oauth2_access_token_consumer.h", "+google_apis/gaia/oauth2_token_service.h", ] diff --git a/chromium/services/identity/public/cpp/README.md b/chromium/services/identity/public/cpp/README.md index 0bba2301b56..71aa75c2ee8 100644 --- a/chromium/services/identity/public/cpp/README.md +++ b/chromium/services/identity/public/cpp/README.md @@ -7,8 +7,14 @@ over the bare Identity Service Mojo interfaces such as: - Synchronous access to the information of the primary account (via caching) -A cheat sheet for developers migrating from usage of //components/signin and -//google_apis/gaia: +Documentation on the mapping between usage of legacy signin +classes (notably SigninManager(Base) and ProfileOAuth2TokenService) and usage of +IdentityManager is available here: + +https://docs.google.com/document/d/14f3qqkDM9IE4Ff_l6wuXvCMeHfSC9TxKezXTCyeaPUY/edit# + +A quick inline cheat sheet for developers migrating from usage of //components/ +signin and //google_apis/gaia: - "Primary account" in IdentityManager refers to what is called the "authenticated account" in SigninManager, i.e., the account that has been diff --git a/chromium/services/identity/public/cpp/access_token_fetcher.cc b/chromium/services/identity/public/cpp/access_token_fetcher.cc index 7069df9c59a..2e4043b3004 100644 --- a/chromium/services/identity/public/cpp/access_token_fetcher.cc +++ b/chromium/services/identity/public/cpp/access_token_fetcher.cc @@ -81,9 +81,10 @@ void AccessTokenFetcher::OnGetTokenSuccess( std::unique_ptr<OAuth2TokenService::Request> request_deleter( std::move(access_token_request_)); - RunCallbackAndMaybeDie(GoogleServiceAuthError::AuthErrorNone(), - AccessTokenInfo(token_response.access_token, - token_response.expiration_time)); + RunCallbackAndMaybeDie( + GoogleServiceAuthError::AuthErrorNone(), + AccessTokenInfo(token_response.access_token, + token_response.expiration_time, token_response.id_token)); // Potentially dead after the above invocation; nothing to do except return. } diff --git a/chromium/services/identity/public/cpp/access_token_fetcher_unittest.cc b/chromium/services/identity/public/cpp/access_token_fetcher_unittest.cc index 9862f5338ed..f1059406592 100644 --- a/chromium/services/identity/public/cpp/access_token_fetcher_unittest.cc +++ b/chromium/services/identity/public/cpp/access_token_fetcher_unittest.cc @@ -16,6 +16,7 @@ #include "components/signin/core/browser/fake_profile_oauth2_token_service.h" #include "components/signin/core/browser/test_signin_client.h" #include "components/sync_preferences/testing_pref_service_syncable.h" +#include "google_apis/gaia/oauth2_access_token_consumer.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" @@ -33,6 +34,11 @@ const char kTestGaiaId2[] = "dummyId2"; const char kTestEmail[] = "me@gmail.com"; const char kTestEmail2[] = "me2@gmail.com"; +// Used just to check that the id_token is passed along. +const char kIdTokenEmptyServices[] = + "dummy-header." + "eyAic2VydmljZXMiOiBbXSB9" // payload: { "services": [] } + ".dummy-signature"; } // namespace class AccessTokenFetcherTest : public testing::Test, @@ -45,7 +51,8 @@ class AccessTokenFetcherTest : public testing::Test, : signin_client_(&pref_service_), token_service_(&pref_service_), access_token_info_("access token", - base::Time::Now() + base::TimeDelta::FromHours(1)) { + base::Time::Now() + base::TimeDelta::FromHours(1), + std::string(kIdTokenEmptyServices)) { AccountTrackerService::RegisterPrefs(pref_service_.registry()); account_tracker_ = std::make_unique<AccountTrackerService>(); @@ -126,8 +133,10 @@ TEST_F(AccessTokenFetcherTest, OneShotShouldCallBackOnFulfilledRequest) { access_token_info())); token_service()->IssueAllTokensForAccount( - account_id, access_token_info().token, - access_token_info().expiration_time); + account_id, + OAuth2AccessTokenConsumer::TokenResponse( + access_token_info().token, access_token_info().expiration_time, + access_token_info().id_token)); } TEST_F(AccessTokenFetcherTest, @@ -154,8 +163,10 @@ TEST_F(AccessTokenFetcherTest, access_token_info())); token_service()->IssueAllTokensForAccount( - account_id, access_token_info().token, - access_token_info().expiration_time); + account_id, + OAuth2AccessTokenConsumer::TokenResponse( + access_token_info().token, access_token_info().expiration_time, + access_token_info().id_token)); } TEST_F(AccessTokenFetcherTest, @@ -176,8 +187,10 @@ TEST_F(AccessTokenFetcherTest, // Before the refresh token is available, the callback shouldn't get called. EXPECT_CALL(callback, Run(_, _)).Times(0); token_service()->IssueAllTokensForAccount( - account_id, access_token_info().token, - access_token_info().expiration_time); + account_id, + OAuth2AccessTokenConsumer::TokenResponse( + access_token_info().token, access_token_info().expiration_time, + access_token_info().id_token)); // Once the refresh token becomes available, we should get an access token // request. @@ -191,8 +204,10 @@ TEST_F(AccessTokenFetcherTest, access_token_info())); token_service()->IssueAllTokensForAccount( - account_id, access_token_info().token, - access_token_info().expiration_time); + account_id, + OAuth2AccessTokenConsumer::TokenResponse( + access_token_info().token, access_token_info().expiration_time, + access_token_info().id_token)); } TEST_F(AccessTokenFetcherTest, @@ -239,8 +254,10 @@ TEST_F(AccessTokenFetcherTest, ShouldNotReplyIfDestroyed) { // Now fulfilling the access token request should have no effect. token_service()->IssueAllTokensForAccount( - account_id, access_token_info().token, - access_token_info().expiration_time); + account_id, + OAuth2AccessTokenConsumer::TokenResponse( + access_token_info().token, access_token_info().expiration_time, + access_token_info().id_token)); } TEST_F(AccessTokenFetcherTest, ReturnsErrorWhenAccountIsUnknown) { @@ -391,8 +408,10 @@ TEST_F(AccessTokenFetcherTest, MultipleRequestsForSameAccountFulfilled) { EXPECT_CALL(callback2, Run(GoogleServiceAuthError::AuthErrorNone(), access_token_info())); token_service()->IssueAllTokensForAccount( - account_id, access_token_info().token, - access_token_info().expiration_time); + account_id, + OAuth2AccessTokenConsumer::TokenResponse( + access_token_info().token, access_token_info().expiration_time, + access_token_info().id_token)); } TEST_F(AccessTokenFetcherTest, MultipleRequestsForDifferentAccountsFulfilled) { @@ -425,16 +444,20 @@ TEST_F(AccessTokenFetcherTest, MultipleRequestsForDifferentAccountsFulfilled) { EXPECT_CALL(callback, Run(GoogleServiceAuthError::AuthErrorNone(), access_token_info())); token_service()->IssueAllTokensForAccount( - account_id, access_token_info().token, - access_token_info().expiration_time); + account_id, + OAuth2AccessTokenConsumer::TokenResponse( + access_token_info().token, access_token_info().expiration_time, + access_token_info().id_token)); // Once the second access token request is fulfilled, it should get // called back with the access token. EXPECT_CALL(callback2, Run(GoogleServiceAuthError::AuthErrorNone(), access_token_info())); token_service()->IssueAllTokensForAccount( - account_id2, access_token_info().token, - access_token_info().expiration_time); + account_id2, + OAuth2AccessTokenConsumer::TokenResponse( + access_token_info().token, access_token_info().expiration_time, + access_token_info().id_token)); } TEST_F(AccessTokenFetcherTest, @@ -486,8 +509,10 @@ TEST_F(AccessTokenFetcherTest, Run(GoogleServiceAuthError::AuthErrorNone(), access_token_info())) .WillOnce(testing::InvokeWithoutArgs(&run_loop4, &base::RunLoop::Quit)); token_service()->IssueAllTokensForAccount( - account_id2, access_token_info().token, - access_token_info().expiration_time); + account_id2, + OAuth2AccessTokenConsumer::TokenResponse( + access_token_info().token, access_token_info().expiration_time, + access_token_info().id_token)); run_loop4.Run(); } diff --git a/chromium/services/identity/public/cpp/access_token_info.cc b/chromium/services/identity/public/cpp/access_token_info.cc index d4c68091ed1..f29bebd1881 100644 --- a/chromium/services/identity/public/cpp/access_token_info.cc +++ b/chromium/services/identity/public/cpp/access_token_info.cc @@ -8,7 +8,8 @@ namespace identity { bool operator==(const AccessTokenInfo& lhs, const AccessTokenInfo& rhs) { return (lhs.token == rhs.token) && - (lhs.expiration_time == rhs.expiration_time); + (lhs.expiration_time == rhs.expiration_time) && + (lhs.id_token == rhs.id_token); } } // namespace identity diff --git a/chromium/services/identity/public/cpp/access_token_info.h b/chromium/services/identity/public/cpp/access_token_info.h index 0a615f9c325..4d12757e520 100644 --- a/chromium/services/identity/public/cpp/access_token_info.h +++ b/chromium/services/identity/public/cpp/access_token_info.h @@ -20,10 +20,18 @@ struct AccessTokenInfo { // The time at which this access token will expire. base::Time expiration_time; + // Contains extra information regarding the user's currently registered + // services. It is uncommon for consumers to need to interact with this field. + // To interact with it, first parse it via gaia::ParseServiceFlags(). + std::string id_token; + AccessTokenInfo() = default; AccessTokenInfo(const std::string& token_param, - const base::Time& expiration_time_param) - : token(token_param), expiration_time(expiration_time_param) {} + const base::Time& expiration_time_param, + const std::string& id_token) + : token(token_param), + expiration_time(expiration_time_param), + id_token(id_token) {} }; // Defined for testing purposes only. diff --git a/chromium/services/identity/public/cpp/identity_manager.cc b/chromium/services/identity/public/cpp/identity_manager.cc index 6bfbe198ecc..ce70d354cdb 100644 --- a/chromium/services/identity/public/cpp/identity_manager.cc +++ b/chromium/services/identity/public/cpp/identity_manager.cc @@ -51,107 +51,29 @@ IdentityManager::IdentityManager( token_service_(token_service), account_tracker_service_(account_tracker_service), gaia_cookie_manager_service_(gaia_cookie_manager_service) { - // Initialize the state of the primary account. - primary_account_info_ = signin_manager_->GetAuthenticatedAccountInfo(); - - // Initialize the state of accounts with refresh tokens. - // |account_id| is moved into |accounts_with_refresh_tokens_|. - // Do not change this to "const std::string&". - for (std::string account_id : token_service->GetAccounts()) { - AccountInfo account_info = - account_tracker_service_->GetAccountInfo(account_id); - - // In the context of supervised users, the ProfileOAuth2TokenService is used - // without the AccountTrackerService being used. This is the only case in - // which the AccountTrackerService will potentially not know about the - // account. In this context, |account_id| is always set to - // kSupervisedUserPseudoEmail. - // TODO(860492): Remove this special case once supervised user support is - // removed. - DCHECK(!account_info.IsEmpty() || account_id == kSupervisedUserPseudoEmail); - - if (account_id == kSupervisedUserPseudoEmail && account_info.IsEmpty()) { - // Populate the information manually to maintain the invariant that the - // account ID, gaia ID, and email are always set. - account_info.account_id = account_id; - account_info.email = kSupervisedUserPseudoEmail; - account_info.gaia = kSupervisedUserPseudoGaiaID; - } - - accounts_with_refresh_tokens_.emplace(std::move(account_id), - std::move(account_info)); - } - signin_manager_->AddObserver(this); -#if !defined(OS_CHROMEOS) - SigninManager::FromSigninManagerBase(signin_manager_) - ->set_diagnostics_client(this); -#endif token_service_->AddDiagnosticsObserver(this); token_service_->AddObserver(this); - token_service_->set_diagnostics_client(this); gaia_cookie_manager_service_->AddObserver(this); } IdentityManager::~IdentityManager() { signin_manager_->RemoveObserver(this); -#if !defined(OS_CHROMEOS) - SigninManager::FromSigninManagerBase(signin_manager_) - ->set_diagnostics_client(nullptr); -#endif token_service_->RemoveObserver(this); token_service_->RemoveDiagnosticsObserver(this); - token_service_->set_diagnostics_client(nullptr); gaia_cookie_manager_service_->RemoveObserver(this); } AccountInfo IdentityManager::GetPrimaryAccountInfo() const { -#if defined(OS_CHROMEOS) - // On ChromeOS in production, the authenticated account is set very early in - // startup and never changed. Hence, the information held by the - // IdentityManager should always correspond to that held by SigninManager. - // NOTE: the above invariant is not guaranteed to hold in tests. If you - // are seeing this DCHECK go off in a testing context, it means that you need - // to set the IdentityManager's primary account info in the test at the place - // where you are setting the authenticated account info in the SigninManager. - // TODO(blundell): Add the API to do this once we hit the first case and - // document the API to use here. - DCHECK_EQ(signin_manager_->GetAuthenticatedAccountId(), - primary_account_info_.account_id); - - // Note: If the primary account's refresh token gets revoked, then the account - // gets removed from AccountTrackerService (via - // AccountFetcherService::OnRefreshTokenRevoked), and so SigninManager's - // GetAuthenticatedAccountInfo is empty (even though - // GetAuthenticatedAccountId is NOT empty). - if (!signin_manager_->GetAuthenticatedAccountInfo().account_id.empty()) { - DCHECK_EQ(signin_manager_->GetAuthenticatedAccountInfo().account_id, - primary_account_info_.account_id); - DCHECK_EQ(signin_manager_->GetAuthenticatedAccountInfo().gaia, - primary_account_info_.gaia); - - // TODO(842670): As described in the bug, AccountTrackerService's email - // address can be updated after it is initially set on ChromeOS. Figure out - // right long-term solution for this problem. - if (signin_manager_->GetAuthenticatedAccountInfo().email != - primary_account_info_.email) { - // This update should only be to move it from normalized form to the form - // in which the user entered the email when creating the account. The - // below check verifies that the normalized forms of the two email - // addresses are identical. - DCHECK(gaia::AreEmailsSame( - signin_manager_->GetAuthenticatedAccountInfo().email, - primary_account_info_.email)); - primary_account_info_.email = - signin_manager_->GetAuthenticatedAccountInfo().email; - } - } -#endif // defined(OS_CHROMEOS) - return primary_account_info_; + return signin_manager_->GetAuthenticatedAccountInfo(); +} + +const std::string& IdentityManager::GetPrimaryAccountId() const { + return signin_manager_->GetAuthenticatedAccountId(); } bool IdentityManager::HasPrimaryAccount() const { - return !primary_account_info_.account_id.empty(); + return signin_manager_->IsAuthenticated(); } #if !defined(OS_CHROMEOS) @@ -176,20 +98,19 @@ void IdentityManager::ClearPrimaryAccount( break; } - // NOTE: |primary_account_| member is cleared in WillFireGoogleSignedOut() - // and IdentityManager::Observers are notified in GoogleSignedOut(); + // NOTE: IdentityManager::Observers are notified in GoogleSignedOut(). } #endif // defined(OS_CHROMEOS) std::vector<AccountInfo> IdentityManager::GetAccountsWithRefreshTokens() const { - // TODO(blundell): It seems wasteful to construct this vector every time this - // method is called, but it also seems bad to maintain the vector as an ivar - // along the map. + std::vector<std::string> account_ids_with_tokens = + token_service_->GetAccounts(); + std::vector<AccountInfo> accounts; - accounts.reserve(accounts_with_refresh_tokens_.size()); + accounts.reserve(account_ids_with_tokens.size()); - for (const auto& pair : accounts_with_refresh_tokens_) { - accounts.push_back(pair.second); + for (const std::string& account_id : account_ids_with_tokens) { + accounts.push_back(GetAccountInfoForAccountWithRefreshToken(account_id)); } return accounts; @@ -208,11 +129,21 @@ std::vector<AccountInfo> IdentityManager::GetAccountsInCookieJar( bool IdentityManager::HasAccountWithRefreshToken( const std::string& account_id) const { - return base::ContainsKey(accounts_with_refresh_tokens_, account_id); + return token_service_->RefreshTokenIsAvailable(account_id); +} + +bool IdentityManager::HasAccountWithRefreshTokenInPersistentErrorState( + const std::string& account_id) const { + return GetErrorStateOfRefreshTokenForAccount(account_id).IsPersistentError(); +} + +GoogleServiceAuthError IdentityManager::GetErrorStateOfRefreshTokenForAccount( + const std::string& account_id) const { + return token_service_->GetAuthError(account_id); } bool IdentityManager::HasPrimaryAccountWithRefreshToken() const { - return HasAccountWithRefreshToken(GetPrimaryAccountInfo().account_id); + return HasAccountWithRefreshToken(GetPrimaryAccountId()); } std::unique_ptr<AccessTokenFetcher> @@ -269,54 +200,18 @@ void IdentityManager::SetPrimaryAccountSynchronously( const std::string& email_address, const std::string& refresh_token) { signin_manager_->SetAuthenticatedAccountInfo(gaia_id, email_address); - primary_account_info_ = signin_manager_->GetAuthenticatedAccountInfo(); if (!refresh_token.empty()) { - token_service_->UpdateCredentials(primary_account_info_.account_id, - refresh_token); - } -} - -#if !defined(OS_CHROMEOS) -void IdentityManager::WillFireGoogleSigninSucceeded( - const AccountInfo& account_info) { - // TODO(843510): Consider setting this info and notifying observers - // asynchronously in response to GoogleSigninSucceeded() once there are no - // direct clients of SigninManager. - primary_account_info_ = account_info; -} - -void IdentityManager::WillFireGoogleSignedOut(const AccountInfo& account_info) { - // TODO(843510): Consider setting this info and notifying observers - // asynchronously in response to GoogleSigninSucceeded() once there are no - // direct clients of SigninManager. - DCHECK_EQ(account_info.account_id, primary_account_info_.account_id); - DCHECK_EQ(account_info.gaia, primary_account_info_.gaia); - DCHECK(gaia::AreEmailsSame(account_info.email, primary_account_info_.email)); - primary_account_info_ = AccountInfo(); -} -#endif - -void IdentityManager::GoogleSigninSucceeded(const AccountInfo& account_info) { - DCHECK(account_info.account_id == primary_account_info_.account_id); - DCHECK(account_info.gaia == primary_account_info_.gaia); - DCHECK(account_info.email == primary_account_info_.email); - for (auto& observer : observer_list_) { - observer.OnPrimaryAccountSet(account_info); + token_service_->UpdateCredentials(GetPrimaryAccountId(), refresh_token); } } -void IdentityManager::GoogleSignedOut(const AccountInfo& account_info) { - DCHECK(!HasPrimaryAccount()); - for (auto& observer : observer_list_) { - observer.OnPrimaryAccountCleared(account_info); - } -} +// Populates and returns an AccountInfo object corresponding to |account_id|, +// which must be an account with a refresh token. +AccountInfo IdentityManager::GetAccountInfoForAccountWithRefreshToken( + std::string account_id) const { + DCHECK(HasAccountWithRefreshToken(account_id)); -void IdentityManager::WillFireOnRefreshTokenAvailable( - const std::string& account_id, - bool is_valid) { - DCHECK(!pending_token_available_state_); AccountInfo account_info = account_tracker_service_->GetAccountInfo(account_id); @@ -324,122 +219,71 @@ void IdentityManager::WillFireOnRefreshTokenAvailable( // without the AccountTrackerService being used. This is the only case in // which the AccountTrackerService will potentially not know about the // account. In this context, |account_id| is always set to - // kSupervisedUserPseudoEmail. + // kSupervisedUserPseudoEmail. Populate the information manually in this case + // to maintain the invariant that the account ID, gaia ID, and email are + // always set. // TODO(860492): Remove this special case once supervised user support is // removed. DCHECK(!account_info.IsEmpty() || account_id == kSupervisedUserPseudoEmail); if (account_id == kSupervisedUserPseudoEmail && account_info.IsEmpty()) { - // Populate the information manually to maintain the invariant that the - // account ID, gaia ID, and email are always set. account_info.account_id = account_id; account_info.email = kSupervisedUserPseudoEmail; account_info.gaia = kSupervisedUserPseudoGaiaID; } - // The account might have already been present (e.g., this method can fire on - // updating an invalid token to a valid one or vice versa); in this case we - // sanity-check that the cached account info has the expected values. - auto iterator = accounts_with_refresh_tokens_.find(account_id); - if (iterator != accounts_with_refresh_tokens_.end()) { - DCHECK_EQ(account_info.gaia, iterator->second.gaia); - DCHECK(gaia::AreEmailsSame(account_info.email, iterator->second.email)); - } else { - auto insertion_result = accounts_with_refresh_tokens_.emplace( - account_id, std::move(account_info)); - DCHECK(insertion_result.second); - iterator = insertion_result.first; - } - - PendingTokenAvailableState pending_token_available_state; - pending_token_available_state.account_info = iterator->second; - pending_token_available_state.refresh_token_is_valid = is_valid; - pending_token_available_state_ = std::move(pending_token_available_state); + return account_info; } -void IdentityManager::WillFireOnRefreshTokenRevoked( - const std::string& account_id) { - DCHECK(!pending_token_revoked_info_); - - auto iterator = accounts_with_refresh_tokens_.find(account_id); - if (iterator == accounts_with_refresh_tokens_.end()) { - // A corner case exists wherein the token service revokes tokens while - // loading tokens during initial startup. This is the only case in which it - // is expected that we could receive this notification without having - // previously received a notification that this account was available. In - // this case, we simply do not forward on the notification, for the - // following reasons: (1) We may not have a fully-populated |account_info| - // to send as the argument. (2) Sending the notification breaks clients' - // expectations that IdentityManager will only fire RefreshTokenRemoved - // notifications for accounts that it previously knew about. - DCHECK(!token_service_->AreAllCredentialsLoaded()); - return; +void IdentityManager::GoogleSigninSucceeded(const AccountInfo& account_info) { + for (auto& observer : observer_list_) { + observer.OnPrimaryAccountSet(account_info); } +} - accounts_with_refresh_tokens_.erase(iterator); - - pending_token_revoked_info_ = - account_tracker_service_->GetAccountInfo(account_id); - - // In the context of supervised users, the ProfileOAuth2TokenService is used - // without the AccountTrackerService being used. This is the only case in - // which the AccountTrackerService will potentially not know about the - // account. In this context, |account_id| is always set to - // kSupervisedUserPseudoEmail. - - // TODO(860492): Remove this special case once supervised user support is - // removed. - DCHECK(!pending_token_revoked_info_->IsEmpty() || - account_id == kSupervisedUserPseudoEmail); - if (account_id == kSupervisedUserPseudoEmail && - pending_token_revoked_info_->IsEmpty()) { - // Populate the information manually to maintain the invariant that the - // account ID, gaia ID, and email are always set. - pending_token_revoked_info_->account_id = account_id; - pending_token_revoked_info_->email = account_id; - pending_token_revoked_info_->gaia = kSupervisedUserPseudoGaiaID; +void IdentityManager::GoogleSignedOut(const AccountInfo& account_info) { + DCHECK(!HasPrimaryAccount()); + for (auto& observer : observer_list_) { + observer.OnPrimaryAccountCleared(account_info); } } -void IdentityManager::OnRefreshTokenAvailable(const std::string& account_id) { - DCHECK(pending_token_available_state_); - DCHECK_EQ(pending_token_available_state_->account_info.account_id, - account_id); +void IdentityManager::GoogleSigninFailed(const GoogleServiceAuthError& error) { + for (auto& observer : observer_list_) + observer.OnPrimaryAccountSigninFailed(error); +} - // Move the state out of |pending_token_available_state_| in case any observer - // callbacks fired below result in mutations of refresh tokens. +void IdentityManager::OnRefreshTokenAvailable(const std::string& account_id) { AccountInfo account_info = - std::move(pending_token_available_state_->account_info); - bool refresh_token_is_valid = - pending_token_available_state_->refresh_token_is_valid; - - pending_token_available_state_.reset(); + GetAccountInfoForAccountWithRefreshToken(account_id); + + // Compute the validity of the new refresh token: PO2TS sets an account's + // refresh token to be invalid (error CREDENTIALS_REJECTED_BY_CLIENT) if the + // user signs out of that account on the web. + // TODO(blundell): Hide this logic inside PO2TS. + bool is_valid = true; + GoogleServiceAuthError token_error = token_service_->GetAuthError(account_id); + if (token_error == GoogleServiceAuthError::FromInvalidGaiaCredentialsReason( + GoogleServiceAuthError::InvalidGaiaCredentialsReason:: + CREDENTIALS_REJECTED_BY_CLIENT)) { + is_valid = false; + } for (auto& observer : observer_list_) { - observer.OnRefreshTokenUpdatedForAccount(account_info, - refresh_token_is_valid); + observer.OnRefreshTokenUpdatedForAccount(account_info, is_valid); } } void IdentityManager::OnRefreshTokenRevoked(const std::string& account_id) { - // NOTE: It is possible for |pending_token_revoked_info_| to be null in the - // corner case of tokens being revoked during initial startup (see - // WillFireOnRefreshTokenRevoked() above). - if (!pending_token_revoked_info_) { - return; - } - - DCHECK_EQ(pending_token_revoked_info_->account_id, account_id); - - // Move the state out of |pending_token_revoked_info_| in case any observer - // callbacks fired below result in mutations of refresh tokens. - AccountInfo account_info = pending_token_revoked_info_.value(); - pending_token_revoked_info_.reset(); - for (auto& observer : observer_list_) { - observer.OnRefreshTokenRemovedForAccount(account_info); + observer.OnRefreshTokenRemovedForAccount(account_id); } } +void IdentityManager::OnRefreshTokensLoaded() { + for (auto& observer : observer_list_) + observer.OnRefreshTokensLoaded(); +} + void IdentityManager::OnGaiaAccountsInCookieUpdated( const std::vector<gaia::ListedAccount>& accounts, const std::vector<gaia::ListedAccount>& signed_out_accounts, diff --git a/chromium/services/identity/public/cpp/identity_manager.h b/chromium/services/identity/public/cpp/identity_manager.h index 2c7e02f5cef..3d101dc6a3e 100644 --- a/chromium/services/identity/public/cpp/identity_manager.h +++ b/chromium/services/identity/public/cpp/identity_manager.h @@ -43,10 +43,6 @@ namespace identity { // Gives access to information about the user's Google identities. See // ./README.md for detailed documentation. class IdentityManager : public SigninManagerBase::Observer, -#if !defined(OS_CHROMEOS) - public SigninManager::DiagnosticsClient, -#endif - public ProfileOAuth2TokenService::DiagnosticsClient, public OAuth2TokenService::DiagnosticsObserver, public OAuth2TokenService::Observer, public GaiaCookieManagerService::Observer { @@ -68,8 +64,10 @@ class IdentityManager : public SigninManagerBase::Observer, virtual void OnPrimaryAccountCleared( const AccountInfo& previous_primary_account_info) {} - // TODO(https://crbug/869418): Eventually we might need a callback for - // failure to log in to the primary account. + // Called when the user attempts but fails to set their primary + // account. |error| gives the reason for the failure. + virtual void OnPrimaryAccountSigninFailed( + const GoogleServiceAuthError& error) {} // Called when a new refresh token is associated with |account_info|. // |is_valid| indicates whether the new refresh token is valid. @@ -82,13 +80,25 @@ class IdentityManager : public SigninManagerBase::Observer, const AccountInfo& account_info, bool is_valid) {} - // Called when the refresh token previously associated with |account_info| - // has been removed. + // Called when the refresh token previously associated with |account_id| + // has been removed. At the time that this callback is invoked, there is + // no longer guaranteed to be any AccountInfo associated with + // |account_id|. + // NOTE: It is not guaranteed that a call to + // OnRefreshTokenUpdatedForAccount() has previously occurred for this + // account due to corner cases. + // TODO(https://crbug.com/884731): Eliminate these corner cases. // NOTE: On a signout event, the ordering of this callback wrt the // OnPrimaryAccountCleared() callback is undefined.If this lack of ordering // is problematic for your use case, please contact blundell@chromium.org. virtual void OnRefreshTokenRemovedForAccount( - const AccountInfo& account_info) {} + const std::string& account_id) {} + + // Called after refresh tokens are loaded. + // CAVEAT: On ChromeOS, this callback is not invoked during + // startup in all cases. See https://crbug.com/749535, which + // details the cases where it's not invoked. + virtual void OnRefreshTokensLoaded() {} // Called whenever the list of Gaia accounts in the cookie jar has changed. // |accounts| is ordered by the order of the accounts in the cookie. @@ -119,13 +129,23 @@ class IdentityManager : public SigninManagerBase::Observer, GaiaCookieManagerService* gaia_cookie_manager_service); ~IdentityManager() override; - // Provides access to the latest cached information of the user's primary - // account. + // Provides access to the extended information of the user's primary account. + // Returns an empty struct if no such info is available, either because there + // is no primary account or because the extended information for the primary + // account has been removed (this happens when the refresh token is revoked, + // for example). AccountInfo GetPrimaryAccountInfo() const; - // Returns whether the primary account is available, according to the latest - // cached information. Simple convenience wrapper over checking whether the - // primary account info has a valid account ID. + // Provides access to the account ID of the user's primary account. Note that + // this may return a valid string even in cases where GetPrimaryAccountInfo() + // returns an empty struct, as the extended information for the primary + // account is removed on certain events (e.g., when its refresh token is + // revoked). + const std::string& GetPrimaryAccountId() const; + + // Returns whether the primary account is available. Simple convenience + // wrapper over checking whether GetPrimaryAccountId() returns a non-empty + // string. bool HasPrimaryAccount() const; // For ChromeOS, mutation of primary account state is not managed externally. @@ -177,6 +197,21 @@ class IdentityManager : public SigninManagerBase::Observer, // Returns true if a refresh token exists for |account_id|. bool HasAccountWithRefreshToken(const std::string& account_id) const; + // Returns true if (a) a refresh token exists for |account_id|, and (b) the + // refresh token is in a persistent error state (defined as + // GoogleServiceAuthError::IsPersistentError() returning true for the error + // returned by GetErrorStateOfRefreshTokenForAccount(account_id)). + bool HasAccountWithRefreshTokenInPersistentErrorState( + const std::string& account_id) const; + + // Returns the error state of the refresh token associated with |account_id|. + // In particular: Returns GoogleServiceAuthError::AuthErrorNone() if either + // (a) no refresh token exists for |account_id|, or (b) the refresh token is + // not in a persistent error state. Otherwise, returns the last persistent + // error that was detected when using the refresh token. + GoogleServiceAuthError GetErrorStateOfRefreshTokenForAccount( + const std::string& account_id) const; + // Returns true if (a) the primary account exists, and (b) a refresh token // exists for the primary account. bool HasPrimaryAccountWithRefreshToken() const; @@ -204,11 +239,6 @@ class IdentityManager : public SigninManagerBase::Observer, void RemoveDiagnosticsObserver(DiagnosticsObserver* observer); private: - struct PendingTokenAvailableState { - AccountInfo account_info; - bool refresh_token_is_valid = false; - }; - // These clients need to call SetPrimaryAccountSynchronouslyForTests(). friend AccountInfo SetPrimaryAccount(SigninManagerBase* signin_manager, IdentityManager* identity_manager, @@ -240,28 +270,20 @@ class IdentityManager : public SigninManagerBase::Observer, const std::string& email_address, const std::string& refresh_token); + // Populates and returns an AccountInfo object corresponding to |account_id|, + // which must be an account with a refresh token. + AccountInfo GetAccountInfoForAccountWithRefreshToken( + std::string account_id) const; + // SigninManagerBase::Observer: void GoogleSigninSucceeded(const AccountInfo& account_info) override; void GoogleSignedOut(const AccountInfo& account_info) override; - - // ProfileOAuth2TokenService::DiagnosticsClient: - void WillFireOnRefreshTokenAvailable(const std::string& account_id, - bool is_valid) override; - void WillFireOnRefreshTokenRevoked(const std::string& account_id) override; + void GoogleSigninFailed(const GoogleServiceAuthError& error) override; // OAuth2TokenService::Observer: void OnRefreshTokenAvailable(const std::string& account_id) override; void OnRefreshTokenRevoked(const std::string& account_id) override; - -#if !defined(OS_CHROMEOS) - // SigninManager::DiagnosticsClient: - // Override these to update |primary_account_info_| before any observers of - // SigninManager are notified of the signin state change, ensuring that any - // such observer flows that eventually interact with IdentityManager observe - // its state as being consistent with that of SigninManager. - void WillFireGoogleSigninSucceeded(const AccountInfo& account_info) override; - void WillFireGoogleSignedOut(const AccountInfo& account_info) override; -#endif + void OnRefreshTokensLoaded() override; // GaiaCookieManagerService::Observer: void OnGaiaAccountsInCookieUpdated( @@ -284,27 +306,6 @@ class IdentityManager : public SigninManagerBase::Observer, AccountTrackerService* account_tracker_service_; GaiaCookieManagerService* gaia_cookie_manager_service_; - // The latest (cached) value of the primary account. -#if defined(OS_CHROMEOS) - // On ChromeOS the primary account's email address needs to be modified from - // within GetPrimaryAccountInfo(). TODO(842670): Remove this field being - // mutable if possible as part of solving the larger issue. - mutable AccountInfo primary_account_info_; -#else - AccountInfo primary_account_info_; -#endif - - // The latest (cached) value of the accounts with refresh tokens. - using AccountIDToAccountInfoMap = std::map<std::string, AccountInfo>; - AccountIDToAccountInfoMap accounts_with_refresh_tokens_; - - // Info that is cached from the PO2TS::DiagnosticsClient callbacks in order to - // forward on to the observers of this class in the corresponding - // O2TS::Observer callbacks (the information is not directly available at the - // time of receiving the O2TS::Observer callbacks). - base::Optional<PendingTokenAvailableState> pending_token_available_state_; - base::Optional<AccountInfo> pending_token_revoked_info_; - // Lists of observers. // Makes sure lists are empty on destruction. base::ObserverList<Observer, true>::Unchecked observer_list_; diff --git a/chromium/services/identity/public/cpp/identity_manager_unittest.cc b/chromium/services/identity/public/cpp/identity_manager_unittest.cc index 256a2dd95f5..6746c46535e 100644 --- a/chromium/services/identity/public/cpp/identity_manager_unittest.cc +++ b/chromium/services/identity/public/cpp/identity_manager_unittest.cc @@ -104,9 +104,6 @@ class TestSigninManagerObserver : public SigninManagerBase::Observer { void set_on_google_signin_succeeded_callback(base::OnceClosure callback) { on_google_signin_succeeded_callback_ = std::move(callback); } - void set_on_google_signin_failed_callback(base::OnceClosure callback) { - on_google_signin_failed_callback_ = std::move(callback); - } void set_on_google_signed_out_callback(base::OnceClosure callback) { on_google_signed_out_callback_ = std::move(callback); } @@ -114,9 +111,6 @@ class TestSigninManagerObserver : public SigninManagerBase::Observer { const AccountInfo& primary_account_from_signin_callback() const { return primary_account_from_signin_callback_; } - const GoogleServiceAuthError& error_from_signin_failed_callback() const { - return google_signin_failed_error_; - } const AccountInfo& primary_account_from_signout_callback() const { return primary_account_from_signout_callback_; } @@ -130,11 +124,6 @@ class TestSigninManagerObserver : public SigninManagerBase::Observer { if (on_google_signin_succeeded_callback_) std::move(on_google_signin_succeeded_callback_).Run(); } - void GoogleSigninFailed(const GoogleServiceAuthError& error) override { - google_signin_failed_error_ = error; - if (on_google_signin_failed_callback_) - std::move(on_google_signin_failed_callback_).Run(); - } void GoogleSignedOut(const AccountInfo& account_info) override { ASSERT_TRUE(identity_manager_); primary_account_from_signout_callback_ = @@ -150,7 +139,6 @@ class TestSigninManagerObserver : public SigninManagerBase::Observer { base::OnceClosure on_google_signed_out_callback_; AccountInfo primary_account_from_signin_callback_; AccountInfo primary_account_from_signout_callback_; - GoogleServiceAuthError google_signin_failed_error_; }; // Class that observes updates from ProfileOAuth2TokenService and and verifies @@ -215,6 +203,10 @@ class TestIdentityManagerObserver : IdentityManager::Observer { void set_on_primary_account_cleared_callback(base::OnceClosure callback) { on_primary_account_cleared_callback_ = std::move(callback); } + void set_on_primary_account_signin_failed_callback( + base::OnceClosure callback) { + on_primary_account_signin_failed_callback_ = std::move(callback); + } const AccountInfo& primary_account_from_set_callback() { return primary_account_from_set_callback_; @@ -229,17 +221,21 @@ class TestIdentityManagerObserver : IdentityManager::Observer { // This method uses a RepeatingCallback to simplify verification of multiple // removed tokens. void set_on_refresh_token_removed_callback( - base::RepeatingCallback<void(const AccountInfo&)> callback) { + base::RepeatingCallback<void(const std::string&)> callback) { on_refresh_token_removed_callback_ = std::move(callback); } + void set_on_refresh_tokens_loaded_callback(base::OnceClosure callback) { + on_refresh_tokens_loaded_callback_ = std::move(callback); + } + const AccountInfo& account_from_refresh_token_updated_callback() { return account_from_refresh_token_updated_callback_; } bool validity_from_refresh_token_updated_callback() { return validity_from_refresh_token_updated_callback_; } - const AccountInfo& account_from_refresh_token_removed_callback() { + const std::string& account_from_refresh_token_removed_callback() { return account_from_refresh_token_removed_callback_; } @@ -251,6 +247,10 @@ class TestIdentityManagerObserver : IdentityManager::Observer { return accounts_from_cookie_change_callback_; } + const GoogleServiceAuthError& error_from_signin_failed_callback() const { + return google_signin_failed_error_; + } + private: // IdentityManager::Observer: void OnPrimaryAccountSet(const AccountInfo& primary_account_info) override { @@ -264,6 +264,12 @@ class TestIdentityManagerObserver : IdentityManager::Observer { if (on_primary_account_cleared_callback_) std::move(on_primary_account_cleared_callback_).Run(); } + void OnPrimaryAccountSigninFailed( + const GoogleServiceAuthError& error) override { + google_signin_failed_error_ = error; + if (on_primary_account_signin_failed_callback_) + std::move(on_primary_account_signin_failed_callback_).Run(); + } void OnRefreshTokenUpdatedForAccount(const AccountInfo& account_info, bool is_valid) override { account_from_refresh_token_updated_callback_ = account_info; @@ -271,11 +277,14 @@ class TestIdentityManagerObserver : IdentityManager::Observer { if (on_refresh_token_updated_callback_) std::move(on_refresh_token_updated_callback_).Run(); } - void OnRefreshTokenRemovedForAccount( - const AccountInfo& account_info) override { - account_from_refresh_token_removed_callback_ = account_info; + void OnRefreshTokenRemovedForAccount(const std::string& account_id) override { + account_from_refresh_token_removed_callback_ = account_id; if (on_refresh_token_removed_callback_) - on_refresh_token_removed_callback_.Run(account_info); + on_refresh_token_removed_callback_.Run(account_id); + } + void OnRefreshTokensLoaded() override { + if (on_refresh_tokens_loaded_callback_) + std::move(on_refresh_tokens_loaded_callback_).Run(); } void OnAccountsInCookieUpdated( const std::vector<AccountInfo>& accounts) override { @@ -287,16 +296,19 @@ class TestIdentityManagerObserver : IdentityManager::Observer { IdentityManager* identity_manager_; base::OnceClosure on_primary_account_set_callback_; base::OnceClosure on_primary_account_cleared_callback_; + base::OnceClosure on_primary_account_signin_failed_callback_; base::OnceClosure on_refresh_token_updated_callback_; - base::RepeatingCallback<void(const AccountInfo&)> + base::RepeatingCallback<void(const std::string&)> on_refresh_token_removed_callback_; + base::OnceClosure on_refresh_tokens_loaded_callback_; base::OnceClosure on_accounts_in_cookie_updated_callback_; AccountInfo primary_account_from_set_callback_; AccountInfo primary_account_from_cleared_callback_; AccountInfo account_from_refresh_token_updated_callback_; bool validity_from_refresh_token_updated_callback_; - AccountInfo account_from_refresh_token_removed_callback_; + std::string account_from_refresh_token_removed_callback_; std::vector<AccountInfo> accounts_from_cookie_change_callback_; + GoogleServiceAuthError google_signin_failed_error_; }; class TestIdentityManagerDiagnosticsObserver @@ -357,6 +369,7 @@ class IdentityManagerTest : public testing::Test { "identity_manager_unittest", &signin_client_) { AccountTrackerService::RegisterPrefs(pref_service_.registry()); + ProfileOAuth2TokenService::RegisterProfilePrefs(pref_service_.registry()); SigninManagerBase::RegisterProfilePrefs(pref_service_.registry()); SigninManagerBase::RegisterPrefs(pref_service_.registry()); @@ -405,6 +418,11 @@ class IdentityManagerTest : public testing::Test { identity_manager_diagnostics_observer_.reset(); identity_manager_.reset(); + if (signin_manager_) { + signin_manager_->Shutdown(); + signin_manager_.reset(); + } + #if defined(OS_CHROMEOS) DCHECK_EQ(account_consistency, signin::AccountConsistencyMethod::kDisabled) << "AccountConsistency is not used by SigninManagerBase"; @@ -483,8 +501,8 @@ class IdentityManagerTest : public testing::Test { identity_manager_observer()->set_on_refresh_token_removed_callback( base::BindRepeating( [](base::flat_set<std::string>* observed_removals, - const AccountInfo& removed_account) { - observed_removals->insert(removed_account.email); + const std::string& removed_account) { + observed_removals->insert(removed_account); }, &observed_removals)); @@ -517,16 +535,20 @@ class IdentityManagerTest : public testing::Test { former_primary_account.account_id)); EXPECT_TRUE(identity_manager()->HasAccountWithRefreshToken( secondary_account_info.account_id)); - EXPECT_TRUE(base::ContainsKey(observed_removals, kTestEmail)); - EXPECT_FALSE(base::ContainsKey(observed_removals, kTestEmail2)); + EXPECT_TRUE(base::ContainsKey(observed_removals, + former_primary_account.account_id)); + EXPECT_FALSE(base::ContainsKey(observed_removals, + secondary_account_info.account_id)); break; case RemoveTokenExpectation::kRemoveAll: EXPECT_FALSE(identity_manager()->HasAccountWithRefreshToken( former_primary_account.account_id)); EXPECT_FALSE(identity_manager()->HasAccountWithRefreshToken( secondary_account_info.account_id)); - EXPECT_TRUE(base::ContainsKey(observed_removals, kTestEmail)); - EXPECT_TRUE(base::ContainsKey(observed_removals, kTestEmail2)); + EXPECT_TRUE(base::ContainsKey(observed_removals, + former_primary_account.account_id)); + EXPECT_TRUE(base::ContainsKey(observed_removals, + secondary_account_info.account_id)); break; } } @@ -578,6 +600,10 @@ TEST_F(IdentityManagerTest, PrimaryAccountInfoAfterSignin) { identity_manager()->GetPrimaryAccountInfo(); EXPECT_EQ(kTestGaiaId, primary_account_info.gaia); EXPECT_EQ(kTestEmail, primary_account_info.email); + + std::string primary_account_id = identity_manager()->GetPrimaryAccountId(); + EXPECT_EQ(primary_account_id, kTestGaiaId); + EXPECT_EQ(primary_account_id, primary_account_info.account_id); } TEST_F(IdentityManagerTest, ClearPrimaryAccount_RemoveAll) { @@ -644,7 +670,7 @@ TEST_F(IdentityManagerTest, // Set primary account to have authentication error. SetRefreshTokenForPrimaryAccount(token_service(), identity_manager()); token_service()->UpdateAuthErrorForTesting( - identity_manager()->GetPrimaryAccountInfo().account_id, + identity_manager()->GetPrimaryAccountId(), GoogleServiceAuthError( GoogleServiceAuthError::State::INVALID_GAIA_CREDENTIALS)); @@ -674,24 +700,13 @@ TEST_F(IdentityManagerTest, ClearPrimaryAccount_AuthInProgress) { secondary_account_info.account_id)); // Observe that in-progress authentication is *canceled* and quit the RunLoop. - // TODO(https://crbug/869418): Determine if signin failed notifications should - // be part of the IdentityManager::Observer interface. base::RunLoop run_loop; - GoogleServiceAuthError::State observed_error = - GoogleServiceAuthError::State::NONE; - TestSigninManagerObserver signin_manager_observer(signin_manager()); - signin_manager_observer.set_identity_manager(identity_manager()); - signin_manager_observer.set_on_google_signin_failed_callback(base::BindOnce( - [](TestSigninManagerObserver* observer, - GoogleServiceAuthError::State* error, base::OnceClosure callback) { - *error = observer->error_from_signin_failed_callback().state(); - std::move(callback).Run(); - }, - &signin_manager_observer, &observed_error, run_loop.QuitClosure())); + identity_manager_observer()->set_on_primary_account_signin_failed_callback( + run_loop.QuitClosure()); // Observer should not be notified of any token removals. identity_manager_observer()->set_on_refresh_token_removed_callback( - base::BindRepeating([](const AccountInfo&) { EXPECT_TRUE(false); })); + base::BindRepeating([](const std::string&) { EXPECT_TRUE(false); })); // No primary account to "clear", so no callback. identity_manager_observer()->set_on_primary_account_cleared_callback( @@ -704,7 +719,9 @@ TEST_F(IdentityManagerTest, ClearPrimaryAccount_AuthInProgress) { run_loop.Run(); // Verify in-progress authentication was canceled. - EXPECT_EQ(observed_error, GoogleServiceAuthError::State::REQUEST_CANCELED); + EXPECT_EQ( + identity_manager_observer()->error_from_signin_failed_callback().state(), + GoogleServiceAuthError::State::REQUEST_CANCELED); EXPECT_FALSE(signin_manager()->AuthInProgress()); // We didn't have a primary account to start with, we shouldn't have one now @@ -746,13 +763,50 @@ TEST_F(IdentityManagerTest, PrimaryAccountInfoAfterSigninAndSignout) { identity_manager()->GetPrimaryAccountInfo(); EXPECT_EQ("", primary_account_info.gaia); EXPECT_EQ("", primary_account_info.email); + + std::string primary_account_id = identity_manager()->GetPrimaryAccountId(); + EXPECT_EQ("", primary_account_id); + EXPECT_EQ(primary_account_id, primary_account_info.account_id); +} + +// Test that the primary account's ID remains tracked by the IdentityManager +// after signing in even after having removed the account without signing out. +TEST_F(IdentityManagerTest, PrimaryAccountInfoAfterSigninAndAccountRemoval) { + // First ensure that the user is signed in from the POV of the + // IdentityManager. + base::RunLoop run_loop; + identity_manager_observer()->set_on_primary_account_set_callback( + run_loop.QuitClosure()); + signin_manager()->SignIn(kTestGaiaId, kTestEmail, "password"); + run_loop.Run(); + + // Remove the account from the AccountTrackerService and check that + // the returned AccountInfo won't have a valid ID anymore, even if + // the IdentityManager is still storing the primary account's ID. + account_tracker()->RemoveAccount(kTestGaiaId); + + AccountInfo primary_account_info = + identity_manager()->GetPrimaryAccountInfo(); + EXPECT_EQ("", primary_account_info.gaia); + EXPECT_EQ("", primary_account_info.email); + EXPECT_EQ("", primary_account_info.account_id); + + std::string primary_account_id = identity_manager()->GetPrimaryAccountId(); + EXPECT_EQ(primary_account_id, kTestGaiaId); } #endif // !defined(OS_CHROMEOS) TEST_F(IdentityManagerTest, HasPrimaryAccount) { EXPECT_TRUE(identity_manager()->HasPrimaryAccount()); + // Removing the account from the AccountTrackerService should not cause + // IdentityManager to think that there is no longer a primary account. + account_tracker()->RemoveAccount(identity_manager()->GetPrimaryAccountId()); + EXPECT_TRUE(identity_manager()->HasPrimaryAccount()); + #if !defined(OS_CHROMEOS) + // Signing out should cause IdentityManager to recognize that there is no + // longer a primary account. base::RunLoop run_loop; identity_manager_observer()->set_on_primary_account_cleared_callback( run_loop.QuitClosure()); @@ -1152,6 +1206,117 @@ TEST_F( identity_manager()->HasAccountWithRefreshToken(account_info2.account_id)); } +TEST_F(IdentityManagerTest, GetErrorStateOfRefreshTokenForAccount) { + AccountInfo primary_account_info = + identity_manager()->GetPrimaryAccountInfo(); + std::string primary_account_id = primary_account_info.account_id; + + // A primary account without a refresh token should not be in an error + // state, and setting a refresh token should not affect that. + EXPECT_EQ(GoogleServiceAuthError::AuthErrorNone(), + identity_manager()->GetErrorStateOfRefreshTokenForAccount( + primary_account_id)); + EXPECT_FALSE( + identity_manager()->HasAccountWithRefreshTokenInPersistentErrorState( + primary_account_id)); + + SetRefreshTokenForPrimaryAccount(token_service(), identity_manager()); + EXPECT_EQ(GoogleServiceAuthError::AuthErrorNone(), + identity_manager()->GetErrorStateOfRefreshTokenForAccount( + primary_account_id)); + EXPECT_FALSE( + identity_manager()->HasAccountWithRefreshTokenInPersistentErrorState( + primary_account_id)); + + // A secondary account without a refresh token should not be in an error + // state, and setting a refresh token should not affect that. + account_tracker()->SeedAccountInfo(kTestGaiaId2, kTestEmail2); + AccountInfo account_info2 = + account_tracker()->FindAccountInfoByGaiaId(kTestGaiaId2); + std::string account_id2 = account_info2.account_id; + EXPECT_EQ( + GoogleServiceAuthError::AuthErrorNone(), + identity_manager()->GetErrorStateOfRefreshTokenForAccount(account_id2)); + EXPECT_FALSE( + identity_manager()->HasAccountWithRefreshTokenInPersistentErrorState( + account_id2)); + + SetRefreshTokenForAccount(token_service(), identity_manager(), account_id2); + EXPECT_EQ( + GoogleServiceAuthError::AuthErrorNone(), + identity_manager()->GetErrorStateOfRefreshTokenForAccount(account_id2)); + EXPECT_FALSE( + identity_manager()->HasAccountWithRefreshTokenInPersistentErrorState( + account_id2)); + + GoogleServiceAuthError account_deleted_error = + GoogleServiceAuthError(GoogleServiceAuthError::State::ACCOUNT_DELETED); + GoogleServiceAuthError account_disabled_error = + GoogleServiceAuthError(GoogleServiceAuthError::State::ACCOUNT_DISABLED); + GoogleServiceAuthError transient_error = GoogleServiceAuthError( + GoogleServiceAuthError::State::SERVICE_UNAVAILABLE); + + // Set a persistent error for |account_id2| and check that it's reflected. + token_service()->UpdateAuthErrorForTesting(account_id2, + account_deleted_error); + EXPECT_EQ( + account_deleted_error, + identity_manager()->GetErrorStateOfRefreshTokenForAccount(account_id2)); + EXPECT_TRUE( + identity_manager()->HasAccountWithRefreshTokenInPersistentErrorState( + account_id2)); + EXPECT_EQ(GoogleServiceAuthError::AuthErrorNone(), + identity_manager()->GetErrorStateOfRefreshTokenForAccount( + primary_account_id)); + EXPECT_FALSE( + identity_manager()->HasAccountWithRefreshTokenInPersistentErrorState( + primary_account_id)); + + // A transient error should cause no change in the error state. + token_service()->UpdateAuthErrorForTesting(primary_account_id, + transient_error); + EXPECT_EQ(GoogleServiceAuthError::AuthErrorNone(), + identity_manager()->GetErrorStateOfRefreshTokenForAccount( + primary_account_id)); + EXPECT_FALSE( + identity_manager()->HasAccountWithRefreshTokenInPersistentErrorState( + primary_account_id)); + + // Set a different persistent error for the primary account and check that + // it's reflected. + token_service()->UpdateAuthErrorForTesting(primary_account_id, + account_disabled_error); + EXPECT_EQ( + account_deleted_error, + identity_manager()->GetErrorStateOfRefreshTokenForAccount(account_id2)); + EXPECT_TRUE( + identity_manager()->HasAccountWithRefreshTokenInPersistentErrorState( + account_id2)); + EXPECT_EQ(account_disabled_error, + identity_manager()->GetErrorStateOfRefreshTokenForAccount( + primary_account_id)); + EXPECT_TRUE( + identity_manager()->HasAccountWithRefreshTokenInPersistentErrorState( + primary_account_id)); + + // Remove the token for account2 and check that it goes back to having no + // error. + RemoveRefreshTokenForAccount(token_service(), identity_manager(), + account_id2); + EXPECT_EQ( + GoogleServiceAuthError::AuthErrorNone(), + identity_manager()->GetErrorStateOfRefreshTokenForAccount(account_id2)); + EXPECT_FALSE( + identity_manager()->HasAccountWithRefreshTokenInPersistentErrorState( + account_id2)); + EXPECT_EQ(account_disabled_error, + identity_manager()->GetErrorStateOfRefreshTokenForAccount( + primary_account_id)); + EXPECT_TRUE( + identity_manager()->HasAccountWithRefreshTokenInPersistentErrorState( + primary_account_id)); +} + TEST_F(IdentityManagerTest, RemoveAccessTokenFromCache) { std::set<std::string> scopes{"scope"}; std::string access_token = "access_token"; @@ -1176,9 +1341,8 @@ TEST_F(IdentityManagerTest, CreateAccessTokenFetcher) { [](GoogleServiceAuthError error, AccessTokenInfo access_token_info) {}); std::unique_ptr<AccessTokenFetcher> token_fetcher = identity_manager()->CreateAccessTokenFetcherForAccount( - identity_manager()->GetPrimaryAccountInfo().account_id, - "dummy_consumer", scopes, std::move(callback), - AccessTokenFetcher::Mode::kImmediate); + identity_manager()->GetPrimaryAccountId(), "dummy_consumer", scopes, + std::move(callback), AccessTokenFetcher::Mode::kImmediate); EXPECT_TRUE(token_fetcher); } @@ -1196,9 +1360,8 @@ TEST_F(IdentityManagerTest, ObserveAccessTokenFetch) { [](GoogleServiceAuthError error, AccessTokenInfo access_token_info) {}); std::unique_ptr<AccessTokenFetcher> token_fetcher = identity_manager()->CreateAccessTokenFetcherForAccount( - identity_manager()->GetPrimaryAccountInfo().account_id, - "dummy_consumer", scopes, std::move(callback), - AccessTokenFetcher::Mode::kImmediate); + identity_manager()->GetPrimaryAccountId(), "dummy_consumer", scopes, + std::move(callback), AccessTokenFetcher::Mode::kImmediate); run_loop.Run(); @@ -1213,8 +1376,9 @@ TEST_F(IdentityManagerTest, ObserveAccessTokenFetch) { } #if !defined(OS_CHROMEOS) -TEST_F(IdentityManagerTest, - IdentityManagerGetsSignInEventBeforeSigninManagerObserver) { +TEST_F( + IdentityManagerTest, + IdentityManagerGivesConsistentValuesFromSigninManagerObserverNotificationOfSignIn) { signin_manager()->ForceSignOut(); base::RunLoop run_loop; @@ -1239,8 +1403,9 @@ TEST_F(IdentityManagerTest, EXPECT_EQ(kTestEmail, primary_account_from_signin_callback.email); } -TEST_F(IdentityManagerTest, - IdentityManagerGetsSignOutEventBeforeSigninManagerObserver) { +TEST_F( + IdentityManagerTest, + IdentityManagerGivesConsistentValuesFromSigninManagerObserverNotificationOfSignOut) { base::RunLoop run_loop; TestSigninManagerObserver signin_manager_observer(signin_manager()); signin_manager_observer.set_on_google_signed_out_callback( @@ -1333,11 +1498,8 @@ TEST_F(IdentityManagerTest, CallbackSentOnPrimaryAccountRefreshTokenRemoval) { RemoveRefreshTokenForPrimaryAccount(token_service(), identity_manager()); - AccountInfo account_info = - identity_manager_observer() - ->account_from_refresh_token_removed_callback(); - EXPECT_EQ(kTestGaiaId, account_info.gaia); - EXPECT_EQ(kTestEmail, account_info.email); + EXPECT_EQ(account_id, identity_manager_observer() + ->account_from_refresh_token_removed_callback()); } TEST_F(IdentityManagerTest, @@ -1385,12 +1547,9 @@ TEST_F(IdentityManagerTest, CallbackSentOnSecondaryAccountRefreshTokenRemoval) { RemoveRefreshTokenForAccount(token_service(), identity_manager(), expected_account_info.account_id); - AccountInfo account_info = - identity_manager_observer() - ->account_from_refresh_token_removed_callback(); - EXPECT_EQ(expected_account_info.account_id, account_info.account_id); - EXPECT_EQ(expected_account_info.gaia, account_info.gaia); - EXPECT_EQ(expected_account_info.email, account_info.email); + EXPECT_EQ(expected_account_info.account_id, + identity_manager_observer() + ->account_from_refresh_token_removed_callback()); } #if !defined(OS_CHROMEOS) @@ -1460,34 +1619,34 @@ TEST_F(IdentityManagerTest, RemoveRefreshTokenForAccount(token_service(), identity_manager(), expected_account_info.account_id); - AccountInfo account_info = - identity_manager_observer() - ->account_from_refresh_token_removed_callback(); - EXPECT_EQ(expected_account_info.account_id, account_info.account_id); - EXPECT_EQ(expected_account_info.gaia, account_info.gaia); - EXPECT_EQ(expected_account_info.email, account_info.email); + EXPECT_EQ(expected_account_info.account_id, + identity_manager_observer() + ->account_from_refresh_token_removed_callback()); } #endif -TEST_F(IdentityManagerTest, - CallbackNotSentOnRefreshTokenRemovalOfUnknownAccount) { - // RemoveCredentials expects (and DCHECKS) that either the caller passes a - // known account ID, or the account is unknown because the token service is - // still loading credentials. Our common test setup actually completes this - // loading, so use the *for_testing() method below to simulate the race - // condition. +TEST_F(IdentityManagerTest, CallbackSentOnRefreshTokenRemovalOfUnknownAccount) { + // When the token service is still loading credentials, it may send token + // revoked callbacks for accounts that it has never sent a token available + // callback. Our common test setup actually completes this loading, so use the + // *for_testing() method below to simulate the race condition and ensure that + // IdentityManager passes on the callback in this case. token_service()->set_all_credentials_loaded_for_testing(false); - base::RunLoop run_loop; - identity_manager_observer()->set_on_refresh_token_removed_callback( - base::BindRepeating([](const AccountInfo&) { EXPECT_TRUE(false); })); - token_service()->RevokeCredentials("dummy_account"); + std::string dummy_account_id = "dummy_account"; + base::RunLoop run_loop; + token_service()->RevokeCredentials(dummy_account_id); run_loop.RunUntilIdle(); + + EXPECT_EQ(dummy_account_id, + identity_manager_observer() + ->account_from_refresh_token_removed_callback()); } -TEST_F(IdentityManagerTest, - IdentityManagerGetsTokenUpdateEventBeforeTokenServiceObserver) { +TEST_F( + IdentityManagerTest, + IdentityManagerGivesConsistentValuesFromTokenServiceObserverNotificationOfTokenUpdate) { std::string account_id = signin_manager()->GetAuthenticatedAccountId(); base::RunLoop run_loop; @@ -1512,8 +1671,9 @@ TEST_F(IdentityManagerTest, run_loop.Run(); } -TEST_F(IdentityManagerTest, - IdentityManagerGetsTokenRemovalEventBeforeTokenServiceObserver) { +TEST_F( + IdentityManagerTest, + IdentityManagerGivesConsistentValuesFromTokenServiceObserverNotificationOfTokenRemoval) { std::string account_id = signin_manager()->GetAuthenticatedAccountId(); base::RunLoop run_loop; @@ -1543,6 +1703,22 @@ TEST_F(IdentityManagerTest, run_loop2.Run(); } +TEST_F(IdentityManagerTest, IdentityManagerGetsTokensLoadedEvent) { + std::string account_id = signin_manager()->GetAuthenticatedAccountId(); + + base::RunLoop run_loop; + identity_manager_observer()->set_on_refresh_tokens_loaded_callback( + run_loop.QuitClosure()); + + // Credentials are already loaded in SigninManager::Initialize() + // which runs even before the IdentityManager is created. That's why + // we fake the credentials loaded state and force another load in + // order to be able to capture the TokensLoaded event. + token_service()->set_all_credentials_loaded_for_testing(false); + token_service()->LoadCredentials(""); + run_loop.Run(); +} + TEST_F(IdentityManagerTest, CallbackSentOnUpdateToAccountsInCookieWithNoAccounts) { base::RunLoop run_loop; diff --git a/chromium/services/identity/public/cpp/identity_test_environment.cc b/chromium/services/identity/public/cpp/identity_test_environment.cc index bcb239010aa..686229f0de5 100644 --- a/chromium/services/identity/public/cpp/identity_test_environment.cc +++ b/chromium/services/identity/public/cpp/identity_test_environment.cc @@ -12,6 +12,7 @@ #include "components/signin/core/browser/profile_management_switches.h" #include "components/signin/core/browser/test_signin_client.h" #include "components/sync_preferences/testing_pref_service_syncable.h" +#include "google_apis/gaia/oauth2_access_token_consumer.h" #include "services/identity/public/cpp/identity_test_utils.h" #if defined(OS_CHROMEOS) @@ -25,7 +26,8 @@ namespace identity { // Internal class that abstracts the dependencies out of the public interface. class IdentityTestEnvironmentInternal { public: - IdentityTestEnvironmentInternal(); + IdentityTestEnvironmentInternal( + bool use_fake_url_loader_for_gaia_cookie_manager); ~IdentityTestEnvironmentInternal(); // The IdentityManager instance created and owned by this instance. @@ -37,6 +39,8 @@ class IdentityTestEnvironmentInternal { FakeProfileOAuth2TokenService* token_service(); + FakeGaiaCookieManagerService* gaia_cookie_manager_service(); + private: sync_preferences::TestingPrefServiceSyncable pref_service_; AccountTrackerService account_tracker_; @@ -49,7 +53,8 @@ class IdentityTestEnvironmentInternal { DISALLOW_COPY_AND_ASSIGN(IdentityTestEnvironmentInternal); }; -IdentityTestEnvironmentInternal::IdentityTestEnvironmentInternal() +IdentityTestEnvironmentInternal::IdentityTestEnvironmentInternal( + bool use_fake_url_loader_for_gaia_cookie_manager) : signin_client_(&pref_service_), token_service_(&pref_service_), #if defined(OS_CHROMEOS) @@ -63,19 +68,19 @@ IdentityTestEnvironmentInternal::IdentityTestEnvironmentInternal() // NOTE: Some unittests set up their own TestURLFetcherFactory. In these // contexts FakeGaiaCookieManagerService can't set up its own // FakeURLFetcherFactory, as {Test, Fake}URLFetcherFactory allow only one - // instance to be alive at a time. If some users of - // IdentityTestEnvironment require that GaiaCookieManagerService have a - // FakeURLFetcherFactory, we'll need to pass a config param in to - // IdentityTestEnvironment to specify this. If some users want that - // behavior while *also* having their own FakeURLFetcherFactory, we'll - // need to pass the actual object in and have GaiaCookieManagerService - // have a reference to the object (or figure out the sharing some other - // way). Contact blundell@chromium.org if you come up against this issue. - gaia_cookie_manager_service_(&token_service_, - "identity_test_environment", - &signin_client_, - /*use_fake_url_fetcher=*/false) { + // instance to be alive at a time. If some users require that + // GaiaCookieManagerService have a FakeURLFetcherFactory while *also* + // having their own FakeURLFetcherFactory, we'll need to pass the actual + // object in and have GaiaCookieManagerService have a reference to the + // object (or figure out the sharing some other way). Contact + // blundell@chromium.org if you come up against this issue. + gaia_cookie_manager_service_( + &token_service_, + "identity_test_environment", + &signin_client_, + use_fake_url_loader_for_gaia_cookie_manager) { AccountTrackerService::RegisterPrefs(pref_service_.registry()); + ProfileOAuth2TokenService::RegisterProfilePrefs(pref_service_.registry()); SigninManagerBase::RegisterProfilePrefs(pref_service_.registry()); SigninManagerBase::RegisterPrefs(pref_service_.registry()); @@ -106,8 +111,15 @@ IdentityTestEnvironmentInternal::token_service() { return &token_service_; } -IdentityTestEnvironment::IdentityTestEnvironment() - : internals_(std::make_unique<IdentityTestEnvironmentInternal>()) { +FakeGaiaCookieManagerService* +IdentityTestEnvironmentInternal::gaia_cookie_manager_service() { + return &gaia_cookie_manager_service_; +} + +IdentityTestEnvironment::IdentityTestEnvironment( + bool use_fake_url_loader_for_gaia_cookie_manager) + : internals_(std::make_unique<IdentityTestEnvironmentInternal>( + use_fake_url_loader_for_gaia_cookie_manager)) { internals_->identity_manager()->AddDiagnosticsObserver(this); } @@ -178,6 +190,12 @@ void IdentityTestEnvironment::RemoveRefreshTokenForAccount( internals_->token_service(), internals_->identity_manager(), account_id); } +void IdentityTestEnvironment::SetCookieAccounts( + const std::vector<CookieParams>& cookie_accounts) { + identity::SetCookieAccounts(internals_->gaia_cookie_manager_service(), + internals_->identity_manager(), cookie_accounts); +} + void IdentityTestEnvironment::SetAutomaticIssueOfAccessTokens(bool grant) { internals_->token_service()->set_auto_post_fetch_response_on_message_loop( grant); @@ -194,6 +212,16 @@ void IdentityTestEnvironment:: void IdentityTestEnvironment:: WaitForAccessTokenRequestIfNecessaryAndRespondWithToken( + const std::string& token, + const base::Time& expiration, + const std::string& id_token) { + WaitForAccessTokenRequestIfNecessary(base::nullopt); + internals_->token_service()->IssueTokenForAllPendingRequests( + OAuth2AccessTokenConsumer::TokenResponse(token, expiration, id_token)); +} + +void IdentityTestEnvironment:: + WaitForAccessTokenRequestIfNecessaryAndRespondWithToken( const std::string& account_id, const std::string& token, const base::Time& expiration) { diff --git a/chromium/services/identity/public/cpp/identity_test_environment.h b/chromium/services/identity/public/cpp/identity_test_environment.h index 4ad038b7dc2..e473eece75d 100644 --- a/chromium/services/identity/public/cpp/identity_test_environment.h +++ b/chromium/services/identity/public/cpp/identity_test_environment.h @@ -19,7 +19,8 @@ class IdentityTestEnvironmentInternal; // not available; call MakePrimaryAccountAvailable() as needed. class IdentityTestEnvironment : public IdentityManager::DiagnosticsObserver { public: - IdentityTestEnvironment(); + IdentityTestEnvironment( + bool use_fake_url_loader_for_gaia_cookie_manager = false); ~IdentityTestEnvironment() override; // The IdentityManager instance created and owned by this instance. @@ -80,6 +81,10 @@ class IdentityTestEnvironment : public IdentityManager::DiagnosticsObserver { // NOTE: See disclaimer at top of file re: direct usage. void RemoveRefreshTokenForAccount(const std::string& account_id); + // Puts the given accounts into the Gaia cookie, replacing any previous + // accounts. Blocks until the accounts have been set. + void SetCookieAccounts(const std::vector<CookieParams>& cookie_accounts); + // When this is set, access token requests will be automatically granted with // an access token value of "access_token". void SetAutomaticIssueOfAccessTokens(bool grant); @@ -98,6 +103,23 @@ class IdentityTestEnvironment : public IdentityManager::DiagnosticsObserver { const std::string& token, const base::Time& expiration); + // Issues |token| in response to any access token request that either has (a) + // already occurred and has not been matched by a previous call to this or + // other WaitFor... method, or (b) will occur in the future. In the latter + // case, waits until the access token request occurs. + // NOTE: This method behaves this way to allow IdentityTestEnvironment to be + // agnostic with respect to whether access token requests are handled + // synchronously or asynchronously in the production code. + // NOTE: This version is suitable for use in the common context where access + // token requests are only being made for one account. If you need to + // disambiguate requests coming for different accounts, see the version below. + // NOTE: This version allows passing the uncommon id_token parameter which is + // needed to test some cases where checking for that extra info is required. + void WaitForAccessTokenRequestIfNecessaryAndRespondWithToken( + const std::string& token, + const base::Time& expiration, + const std::string& id_token); + // Issues |token| in response to an access token request for |account_id| that // either already occurred and has not been matched by a previous call to this // or other WaitFor... method , or (b) will occur in the future. In the latter diff --git a/chromium/services/identity/public/cpp/identity_test_utils.cc b/chromium/services/identity/public/cpp/identity_test_utils.cc index e4156d32a4e..a76d42dcedc 100644 --- a/chromium/services/identity/public/cpp/identity_test_utils.cc +++ b/chromium/services/identity/public/cpp/identity_test_utils.cc @@ -7,6 +7,7 @@ #include "base/run_loop.h" #include "base/strings/string_split.h" #include "components/signin/core/browser/account_tracker_service.h" +#include "components/signin/core/browser/fake_gaia_cookie_manager_service.h" #include "components/signin/core/browser/fake_signin_manager.h" #include "components/signin/core/browser/profile_oauth2_token_service.h" #include "services/identity/public/cpp/identity_manager.h" @@ -19,7 +20,8 @@ enum class IdentityManagerEvent { PRIMARY_ACCOUNT_SET, PRIMARY_ACCOUNT_CLEARED, REFRESH_TOKEN_UPDATED, - REFRESH_TOKEN_REMOVED + REFRESH_TOKEN_REMOVED, + ACCOUNTS_IN_COOKIE_UPDATED, }; class OneShotIdentityManagerObserver : public IdentityManager::Observer { @@ -36,8 +38,9 @@ class OneShotIdentityManagerObserver : public IdentityManager::Observer { const AccountInfo& previous_primary_account_info) override; void OnRefreshTokenUpdatedForAccount(const AccountInfo& account_info, bool is_valid) override; - void OnRefreshTokenRemovedForAccount( - const AccountInfo& account_info) override; + void OnRefreshTokenRemovedForAccount(const std::string& account_id) override; + void OnAccountsInCookieUpdated( + const std::vector<AccountInfo>& accounts) override; IdentityManager* identity_manager_; base::OnceClosure done_closure_; @@ -89,7 +92,7 @@ void OneShotIdentityManagerObserver::OnRefreshTokenUpdatedForAccount( } void OneShotIdentityManagerObserver::OnRefreshTokenRemovedForAccount( - const AccountInfo& account_info) { + const std::string& account_id) { if (event_to_wait_on_ != IdentityManagerEvent::REFRESH_TOKEN_REMOVED) return; @@ -97,6 +100,15 @@ void OneShotIdentityManagerObserver::OnRefreshTokenRemovedForAccount( std::move(done_closure_).Run(); } +void OneShotIdentityManagerObserver::OnAccountsInCookieUpdated( + const std::vector<AccountInfo>& accounts) { + if (event_to_wait_on_ != IdentityManagerEvent::ACCOUNTS_IN_COOKIE_UPDATED) + return; + + DCHECK(done_closure_); + std::move(done_closure_).Run(); +} + // Helper function that updates the refresh token for |account_id| to // |new_token|. Blocks until the update is processed by |identity_manager|. void UpdateRefreshTokenForAccount(ProfileOAuth2TokenService* token_service, @@ -159,7 +171,7 @@ AccountInfo SetPrimaryAccount(SigninManagerBase* signin_manager, void SetRefreshTokenForPrimaryAccount(ProfileOAuth2TokenService* token_service, IdentityManager* identity_manager) { DCHECK(identity_manager->HasPrimaryAccount()); - std::string account_id = identity_manager->GetPrimaryAccountInfo().account_id; + std::string account_id = identity_manager->GetPrimaryAccountId(); std::string refresh_token = "refresh_token_for_" + account_id; SetRefreshTokenForAccount(token_service, identity_manager, account_id); @@ -169,7 +181,7 @@ void SetInvalidRefreshTokenForPrimaryAccount( ProfileOAuth2TokenService* token_service, IdentityManager* identity_manager) { DCHECK(identity_manager->HasPrimaryAccount()); - std::string account_id = identity_manager->GetPrimaryAccountInfo().account_id; + std::string account_id = identity_manager->GetPrimaryAccountId(); SetInvalidRefreshTokenForAccount(token_service, identity_manager, account_id); } @@ -180,7 +192,7 @@ void RemoveRefreshTokenForPrimaryAccount( if (!identity_manager->HasPrimaryAccount()) return; - std::string account_id = identity_manager->GetPrimaryAccountInfo().account_id; + std::string account_id = identity_manager->GetPrimaryAccountId(); RemoveRefreshTokenForAccount(token_service, identity_manager, account_id); } @@ -290,4 +302,28 @@ void RemoveRefreshTokenForAccount(ProfileOAuth2TokenService* token_service, run_loop.Run(); } +void SetCookieAccounts(FakeGaiaCookieManagerService* cookie_manager, + IdentityManager* identity_manager, + const std::vector<CookieParams>& cookie_accounts) { + // Convert |cookie_accounts| to the format FakeGaiaCookieManagerService wants. + std::vector<FakeGaiaCookieManagerService::CookieParams> gaia_cookie_accounts; + for (const CookieParams& params : cookie_accounts) { + gaia_cookie_accounts.push_back({params.email, params.gaia_id, + /*valid=*/true, /*signed_out=*/false, + /*verified=*/true}); + } + + base::RunLoop run_loop; + OneShotIdentityManagerObserver cookie_observer( + identity_manager, run_loop.QuitClosure(), + IdentityManagerEvent::ACCOUNTS_IN_COOKIE_UPDATED); + + cookie_manager->SetListAccountsResponseWithParams(gaia_cookie_accounts); + + cookie_manager->set_list_accounts_stale_for_testing(true); + cookie_manager->ListAccounts(nullptr, nullptr, "test"); + + run_loop.Run(); +} + } // namespace identity diff --git a/chromium/services/identity/public/cpp/identity_test_utils.h b/chromium/services/identity/public/cpp/identity_test_utils.h index 88e6735bf3a..11619bf8a14 100644 --- a/chromium/services/identity/public/cpp/identity_test_utils.h +++ b/chromium/services/identity/public/cpp/identity_test_utils.h @@ -11,6 +11,7 @@ #include "components/signin/core/browser/account_info.h" class AccountTrackerService; +class FakeGaiaCookieManagerService; class FakeSigninManagerBase; class FakeSigninManager; class ProfileOAuth2TokenService; @@ -43,6 +44,11 @@ enum class ClearPrimaryAccountPolicy { REMOVE_ALL_ACCOUNTS }; +struct CookieParams { + std::string email; + std::string gaia_id; +}; + class IdentityManager; // Sets the primary account (which must not already be set) to the given email @@ -130,6 +136,13 @@ void RemoveRefreshTokenForAccount(ProfileOAuth2TokenService* token_service, IdentityManager* identity_manager, const std::string& account_id); +// Puts the given accounts into the Gaia cookie, replacing any previous +// accounts. Blocks until the accounts have been set. +// NOTE: See disclaimer at top of file re: direct usage. +void SetCookieAccounts(FakeGaiaCookieManagerService* cookie_manager, + IdentityManager* identity_manager, + const std::vector<CookieParams>& cookie_accounts); + } // namespace identity #endif // SERVICES_IDENTITY_PUBLIC_CPP_IDENTITY_TEST_UTILS_H_ diff --git a/chromium/services/identity/public/cpp/primary_account_access_token_fetcher.cc b/chromium/services/identity/public/cpp/primary_account_access_token_fetcher.cc index 627a233aaa9..08ddb0d12a8 100644 --- a/chromium/services/identity/public/cpp/primary_account_access_token_fetcher.cc +++ b/chromium/services/identity/public/cpp/primary_account_access_token_fetcher.cc @@ -62,8 +62,7 @@ void PrimaryAccountAccessTokenFetcher::StartAccessTokenRequest() { // token available. AccessTokenFetcher used in // |kWaitUntilRefreshTokenAvailable| mode would guarantee only the latter. access_token_fetcher_ = identity_manager_->CreateAccessTokenFetcherForAccount( - identity_manager_->GetPrimaryAccountInfo().account_id, - oauth_consumer_name_, scopes_, + identity_manager_->GetPrimaryAccountId(), oauth_consumer_name_, scopes_, base::BindOnce( &PrimaryAccountAccessTokenFetcher::OnAccessTokenFetchComplete, base::Unretained(this)), diff --git a/chromium/services/identity/public/cpp/primary_account_access_token_fetcher_unittest.cc b/chromium/services/identity/public/cpp/primary_account_access_token_fetcher_unittest.cc index d58fb784832..a79464dc1cf 100644 --- a/chromium/services/identity/public/cpp/primary_account_access_token_fetcher_unittest.cc +++ b/chromium/services/identity/public/cpp/primary_account_access_token_fetcher_unittest.cc @@ -49,7 +49,8 @@ class PrimaryAccountAccessTokenFetcherTest : public testing::Test, PrimaryAccountAccessTokenFetcherTest() : access_token_info_("access token", - base::Time::Now() + base::TimeDelta::FromHours(1)) {} + base::Time::Now() + base::TimeDelta::FromHours(1), + "id_token") {} ~PrimaryAccountAccessTokenFetcherTest() override { } @@ -98,7 +99,8 @@ TEST_F(PrimaryAccountAccessTokenFetcherTest, OneShotShouldReturnAccessToken) { EXPECT_CALL(callback, Run(GoogleServiceAuthError::AuthErrorNone(), access_token_info())); identity_test_env()->WaitForAccessTokenRequestIfNecessaryAndRespondWithToken( - access_token_info().token, access_token_info().expiration_time); + access_token_info().token, access_token_info().expiration_time, + access_token_info().id_token); } TEST_F(PrimaryAccountAccessTokenFetcherTest, @@ -118,7 +120,8 @@ TEST_F(PrimaryAccountAccessTokenFetcherTest, EXPECT_CALL(callback, Run(GoogleServiceAuthError::AuthErrorNone(), access_token_info())); identity_test_env()->WaitForAccessTokenRequestIfNecessaryAndRespondWithToken( - access_token_info().token, access_token_info().expiration_time); + access_token_info().token, access_token_info().expiration_time, + access_token_info().id_token); } TEST_F(PrimaryAccountAccessTokenFetcherTest, ShouldNotReplyIfDestroyed) { @@ -136,7 +139,8 @@ TEST_F(PrimaryAccountAccessTokenFetcherTest, ShouldNotReplyIfDestroyed) { // Fulfilling the request now should have no effect. identity_test_env()->WaitForAccessTokenRequestIfNecessaryAndRespondWithToken( - access_token_info().token, access_token_info().expiration_time); + access_token_info().token, access_token_info().expiration_time, + access_token_info().id_token); } TEST_F(PrimaryAccountAccessTokenFetcherTest, OneShotCallsBackWhenSignedOut) { @@ -200,7 +204,8 @@ TEST_F(PrimaryAccountAccessTokenFetcherTest, ShouldWaitForSignIn) { EXPECT_CALL(callback, Run(GoogleServiceAuthError::AuthErrorNone(), access_token_info())); identity_test_env()->WaitForAccessTokenRequestIfNecessaryAndRespondWithToken( - access_token_info().token, access_token_info().expiration_time); + access_token_info().token, access_token_info().expiration_time, + access_token_info().id_token); // The request should not have to have been retried. EXPECT_FALSE(fetcher->access_token_request_retried()); @@ -228,7 +233,8 @@ TEST_F(PrimaryAccountAccessTokenFetcherTest, ShouldWaitForRefreshToken) { EXPECT_CALL(callback, Run(GoogleServiceAuthError::AuthErrorNone(), access_token_info())); identity_test_env()->WaitForAccessTokenRequestIfNecessaryAndRespondWithToken( - access_token_info().token, access_token_info().expiration_time); + access_token_info().token, access_token_info().expiration_time, + access_token_info().id_token); // The request should not have to have been retried. EXPECT_FALSE(fetcher->access_token_request_retried()); @@ -294,7 +300,8 @@ TEST_F(PrimaryAccountAccessTokenFetcherTest, EXPECT_CALL(callback, Run(GoogleServiceAuthError::AuthErrorNone(), access_token_info())); identity_test_env()->WaitForAccessTokenRequestIfNecessaryAndRespondWithToken( - access_token_info().token, access_token_info().expiration_time); + access_token_info().token, access_token_info().expiration_time, + access_token_info().id_token); } TEST_F(PrimaryAccountAccessTokenFetcherTest, diff --git a/chromium/services/identity/public/objc/BUILD.gn b/chromium/services/identity/public/objc/BUILD.gn new file mode 100644 index 00000000000..0e9cf2dd08d --- /dev/null +++ b/chromium/services/identity/public/objc/BUILD.gn @@ -0,0 +1,15 @@ +# Copyright 2018 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +source_set("objc") { + configs += [ "//build/config/compiler:enable_arc" ] + sources = [ + "identity_manager_observer_bridge.h", + "identity_manager_observer_bridge.mm", + ] + + public_deps = [ + "//services/identity/public/cpp", + ] +} diff --git a/chromium/services/identity/public/objc/identity_manager_observer_bridge.h b/chromium/services/identity/public/objc/identity_manager_observer_bridge.h new file mode 100644 index 00000000000..9285e4f694e --- /dev/null +++ b/chromium/services/identity/public/objc/identity_manager_observer_bridge.h @@ -0,0 +1,67 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_IDENTITY_PUBLIC_OBJC_IDENTITY_MANAGER_OBSERVER_BRIDGE_H_ +#define SERVICES_IDENTITY_PUBLIC_OBJC_IDENTITY_MANAGER_OBSERVER_BRIDGE_H_ + +#import <Foundation/Foundation.h> +#include <vector> + +#include "services/identity/public/cpp/identity_manager.h" + +// Implement this protocol and pass your implementation into an +// IdentityManagerObserverBridge object to receive IdentityManager observer +// callbacks in Objective-C. +@protocol IdentityManagerObserverBridgeDelegate<NSObject> + +@optional + +// These callbacks follow the semantics of the corresponding +// IdentityManager::Observer callbacks. See the comments on +// IdentityManager::Observer in identity_manager.h for the specification of +// these semantics. + +- (void)onPrimaryAccountSet:(const AccountInfo&)primaryAccountInfo; +- (void)onPrimaryAccountCleared:(const AccountInfo&)previousPrimaryAccountInfo; +- (void)onRefreshTokenUpdatedForAccount:(const AccountInfo&)accountInfo + valid:(BOOL)isValid; +- (void)onRefreshTokenRemovedForAccount:(const AccountInfo&)accountInfo; +- (void)onAccountsInCookieUpdated:(const std::vector<AccountInfo>&)accounts; + +@end + +namespace identity { + +// Bridge class that listens for |IdentityManager| notifications and +// passes them to its Objective-C delegate. +class IdentityManagerObserverBridge : public IdentityManager::Observer { + public: + IdentityManagerObserverBridge( + IdentityManager* identity_manager, + id<IdentityManagerObserverBridgeDelegate> delegate); + ~IdentityManagerObserverBridge() override; + + // IdentityManager::Observer. + void OnPrimaryAccountSet(const AccountInfo& primary_account_info) override; + void OnPrimaryAccountCleared( + const AccountInfo& previous_primary_account_info) override; + void OnRefreshTokenUpdatedForAccount(const AccountInfo& account_info, + bool is_valid) override; + void OnRefreshTokenRemovedForAccount( + const AccountInfo& account_info) override; + void OnAccountsInCookieUpdated( + const std::vector<AccountInfo>& accounts) override; + + private: + // Identity manager to observe. + IdentityManager* identity_manager_; + // Delegate to call. + __weak id<IdentityManagerObserverBridgeDelegate> delegate_; + + DISALLOW_COPY_AND_ASSIGN(IdentityManagerObserverBridge); +}; + +} // namespace identity + +#endif // SERVICES_IDENTITY_PUBLIC_OBJC_IDENTITY_MANAGER_OBSERVER_BRIDGE_H_ diff --git a/chromium/services/identity/public/objc/identity_manager_observer_bridge.mm b/chromium/services/identity/public/objc/identity_manager_observer_bridge.mm new file mode 100644 index 00000000000..3880b14db7b --- /dev/null +++ b/chromium/services/identity/public/objc/identity_manager_observer_bridge.mm @@ -0,0 +1,62 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#import "services/identity/public/objc/identity_manager_observer_bridge.h" + +#if !defined(__has_feature) || !__has_feature(objc_arc) +#error "This file requires ARC support." +#endif + +namespace identity { + +IdentityManagerObserverBridge::IdentityManagerObserverBridge( + IdentityManager* identity_manager, + id<IdentityManagerObserverBridgeDelegate> delegate) + : identity_manager_(identity_manager), delegate_(delegate) { + identity_manager_->AddObserver(this); +} + +IdentityManagerObserverBridge::~IdentityManagerObserverBridge() { + identity_manager_->RemoveObserver(this); +} + +void IdentityManagerObserverBridge::OnPrimaryAccountSet( + const AccountInfo& primary_account_info) { + if ([delegate_ respondsToSelector:@selector(onPrimaryAccountSet:)]) { + [delegate_ onPrimaryAccountSet:primary_account_info]; + } +} + +void IdentityManagerObserverBridge::OnPrimaryAccountCleared( + const AccountInfo& previous_primary_account_info) { + if ([delegate_ respondsToSelector:@selector(onPrimaryAccountCleared:)]) { + [delegate_ onPrimaryAccountCleared:previous_primary_account_info]; + } +} + +void IdentityManagerObserverBridge::OnRefreshTokenUpdatedForAccount( + const AccountInfo& account_info, + bool is_valid) { + if ([delegate_ respondsToSelector:@selector + (onRefreshTokenUpdatedForAccount:valid:)]) { + [delegate_ onRefreshTokenUpdatedForAccount:account_info valid:is_valid]; + } +} + +void IdentityManagerObserverBridge::OnRefreshTokenRemovedForAccount( + const AccountInfo& account_info) { + if ([delegate_ + respondsToSelector:@selector(onRefreshTokenRemovedForAccount:)]) { + [delegate_ onRefreshTokenRemovedForAccount:account_info]; + } +} + +void IdentityManagerObserverBridge::OnAccountsInCookieUpdated( + const std::vector<AccountInfo>& accounts) { + if ([delegate_ respondsToSelector:@selector(onAccountsInCookieUpdated:)]) { + [delegate_ onAccountsInCookieUpdated:accounts]; + } +} + +} // namespace identity diff --git a/chromium/services/media_session/BUILD.gn b/chromium/services/media_session/BUILD.gn index c2d97c5aced..71936902c6b 100644 --- a/chromium/services/media_session/BUILD.gn +++ b/chromium/services/media_session/BUILD.gn @@ -12,12 +12,20 @@ import("//testing/test.gni") source_set("lib") { sources = [ + "audio_focus_manager.cc", + "audio_focus_manager.h", + "audio_focus_manager_metrics_helper.cc", + "audio_focus_manager_metrics_helper.h", "media_session_service.cc", "media_session_service.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ "//mojo/public/cpp/bindings", + "//services/media_session/public/cpp", + "//services/media_session/public/mojom", ] public_deps = [ @@ -34,13 +42,19 @@ service_manifest("manifest") { source_set("tests") { testonly = true sources = [ + "audio_focus_manager_unittest.cc", "media_session_service_unittest.cc", + "mock_media_session.cc", + "mock_media_session.h", ] deps = [ ":lib", "//base", "//base/test:test_support", + "//services/media_session/public/cpp", + "//services/media_session/public/cpp/test:test_support", + "//services/media_session/public/mojom", "//services/service_manager/public/cpp/test:test_support", "//testing/gtest", ] diff --git a/chromium/services/media_session/audio_focus_manager.cc b/chromium/services/media_session/audio_focus_manager.cc new file mode 100644 index 00000000000..1e15c00b885 --- /dev/null +++ b/chromium/services/media_session/audio_focus_manager.cc @@ -0,0 +1,381 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/media_session/audio_focus_manager.h" + +#include <iterator> +#include <utility> + +#include "base/containers/adapters.h" +#include "base/threading/thread_task_runner_handle.h" +#include "base/unguessable_token.h" +#include "mojo/public/cpp/bindings/interface_request.h" +#include "services/media_session/audio_focus_manager_metrics_helper.h" +#include "services/media_session/public/cpp/switches.h" +#include "services/media_session/public/mojom/audio_focus.mojom.h" + +namespace media_session { + +class AudioFocusManager::StackRow : public mojom::AudioFocusRequestClient { + public: + StackRow(AudioFocusManager* owner, + mojom::AudioFocusRequestClientRequest request, + mojom::MediaSessionPtr session, + mojom::MediaSessionInfoPtr session_info, + mojom::AudioFocusType audio_focus_type, + RequestId id, + const std::string& source_name) + : id_(id), + source_name_(source_name), + metrics_helper_(source_name), + session_(std::move(session)), + session_info_(std::move(session_info)), + audio_focus_type_(audio_focus_type), + binding_(this, std::move(request)), + owner_(owner) { + // Listen for mojo errors. + binding_.set_connection_error_handler( + base::BindOnce(&AudioFocusManager::StackRow::OnConnectionError, + base::Unretained(this))); + session_.set_connection_error_handler( + base::BindOnce(&AudioFocusManager::StackRow::OnConnectionError, + base::Unretained(this))); + + metrics_helper_.OnRequestAudioFocus( + AudioFocusManagerMetricsHelper::AudioFocusRequestSource::kInitial, + audio_focus_type); + } + + ~StackRow() override = default; + + // mojom::AudioFocusRequestClient. + void RequestAudioFocus(mojom::MediaSessionInfoPtr session_info, + mojom::AudioFocusType type, + RequestAudioFocusCallback callback) override { + session_info_ = std::move(session_info); + + if (IsActive() && owner_->IsSessionOnTopOfAudioFocusStack(id(), type)) { + // Early returning if |media_session| is already on top (has focus) and is + // active. + std::move(callback).Run(); + return; + } + + // Remove this StackRow for the audio focus stack. + std::unique_ptr<StackRow> row = owner_->RemoveFocusEntryIfPresent(id()); + DCHECK(row); + + owner_->RequestAudioFocusInternal(std::move(row), type, + std::move(callback)); + + metrics_helper_.OnRequestAudioFocus( + AudioFocusManagerMetricsHelper::AudioFocusRequestSource::kUpdate, + audio_focus_type_); + } + + void AbandonAudioFocus() override { + metrics_helper_.OnAbandonAudioFocus( + AudioFocusManagerMetricsHelper::AudioFocusAbandonSource::kAPI); + + owner_->AbandonAudioFocusInternal(id_); + } + + void MediaSessionInfoChanged(mojom::MediaSessionInfoPtr info) override { + session_info_ = std::move(info); + } + + void GetRequestId(GetRequestIdCallback callback) override { + std::move(callback).Run(id()); + } + + mojom::MediaSession* session() { return session_.get(); } + const mojom::MediaSessionInfoPtr& info() const { return session_info_; } + mojom::AudioFocusType audio_focus_type() const { return audio_focus_type_; } + + void SetAudioFocusType(mojom::AudioFocusType type) { + audio_focus_type_ = type; + } + + bool IsActive() const { + return session_info_->state == + mojom::MediaSessionInfo::SessionState::kActive; + } + + RequestId id() const { return id_; } + + const std::string& source_name() const { return source_name_; } + + private: + void OnConnectionError() { + // Since we have multiple pathways that can call |OnConnectionError| we + // should use the |encountered_error_| bit to make sure we abandon focus + // just the first time. + if (encountered_error_) + return; + encountered_error_ = true; + + metrics_helper_.OnAbandonAudioFocus( + AudioFocusManagerMetricsHelper::AudioFocusAbandonSource:: + kConnectionError); + + base::ThreadTaskRunnerHandle::Get()->PostTask( + FROM_HERE, base::BindOnce(&AudioFocusManager::AbandonAudioFocusInternal, + base::Unretained(owner_), id_)); + } + + const RequestId id_; + const std::string source_name_; + + AudioFocusManagerMetricsHelper metrics_helper_; + bool encountered_error_ = false; + + mojom::MediaSessionPtr session_; + mojom::MediaSessionInfoPtr session_info_; + mojom::AudioFocusType audio_focus_type_; + mojo::Binding<mojom::AudioFocusRequestClient> binding_; + + // Weak pointer to the owning |AudioFocusManager| instance. + AudioFocusManager* owner_; + + DISALLOW_COPY_AND_ASSIGN(StackRow); +}; + +void AudioFocusManager::RequestAudioFocus( + mojom::AudioFocusRequestClientRequest request, + mojom::MediaSessionPtr media_session, + mojom::MediaSessionInfoPtr session_info, + mojom::AudioFocusType type, + RequestAudioFocusCallback callback) { + RequestAudioFocusInternal( + std::make_unique<StackRow>( + this, std::move(request), std::move(media_session), + std::move(session_info), type, base::UnguessableToken::Create(), + GetBindingSourceName()), + type, std::move(callback)); +} + +void AudioFocusManager::GetFocusRequests(GetFocusRequestsCallback callback) { + std::vector<mojom::AudioFocusRequestStatePtr> requests; + + for (const auto& row : audio_focus_stack_) { + auto request = mojom::AudioFocusRequestState::New(); + request->session_info = row->info().Clone(); + request->audio_focus_type = row->audio_focus_type(); + request->request_id = row->id(); + request->source_name = row->source_name(); + requests.push_back(std::move(request)); + } + + std::move(callback).Run(std::move(requests)); +} + +void AudioFocusManager::GetDebugInfoForRequest( + const RequestId& request_id, + GetDebugInfoForRequestCallback callback) { + for (auto& row : audio_focus_stack_) { + if (row->id() != request_id) + continue; + + row->session()->GetDebugInfo(std::move(callback)); + return; + } + + std::move(callback).Run(mojom::MediaSessionDebugInfo::New()); +} + +void AudioFocusManager::AbandonAudioFocusInternal(RequestId id) { + if (audio_focus_stack_.empty()) + return; + + if (audio_focus_stack_.back()->id() != id) { + RemoveFocusEntryIfPresent(id); + return; + } + + auto row = std::move(audio_focus_stack_.back()); + audio_focus_stack_.pop_back(); + + if (audio_focus_stack_.empty()) { + // Notify observers that we lost audio focus. + observers_.ForAllPtrs([&row](mojom::AudioFocusObserver* observer) { + observer->OnFocusLost(row->info().Clone()); + }); + return; + } + + if (IsAudioFocusEnforcementEnabled()) + EnforceAudioFocusAbandon(row->audio_focus_type()); + + // Notify observers that we lost audio focus. + observers_.ForAllPtrs([&row](mojom::AudioFocusObserver* observer) { + observer->OnFocusLost(row->info().Clone()); + }); +} + +void AudioFocusManager::AddObserver(mojom::AudioFocusObserverPtr observer) { + DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); + observers_.AddPtr(std::move(observer)); +} + +void AudioFocusManager::SetSourceName(const std::string& name) { + DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); + bindings_.dispatch_context()->source_name = name; +} + +void AudioFocusManager::BindToInterface( + mojom::AudioFocusManagerRequest request) { + DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); + bindings_.AddBinding(this, std::move(request), + std::make_unique<BindingContext>()); +} + +void AudioFocusManager::BindToDebugInterface( + mojom::AudioFocusManagerDebugRequest request) { + DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); + debug_bindings_.AddBinding(this, std::move(request)); +} + +void AudioFocusManager::CloseAllMojoObjects() { + DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); + observers_.CloseAll(); + bindings_.CloseAllBindings(); + debug_bindings_.CloseAllBindings(); +} + +void AudioFocusManager::RequestAudioFocusInternal( + std::unique_ptr<StackRow> row, + mojom::AudioFocusType type, + base::OnceCallback<void()> callback) { + // If audio focus is enabled then we should enforce this request and make sure + // the new active session is not ducking. + if (IsAudioFocusEnforcementEnabled()) { + EnforceAudioFocusRequest(type); + row->session()->StopDucking(); + } + + row->SetAudioFocusType(type); + audio_focus_stack_.push_back(std::move(row)); + + // Notify observers that we were gained audio focus. + mojom::MediaSessionInfoPtr session_info = + audio_focus_stack_.back()->info().Clone(); + observers_.ForAllPtrs( + [&session_info, type](mojom::AudioFocusObserver* observer) { + observer->OnFocusGained(session_info.Clone(), type); + }); + + // We always grant the audio focus request but this may not always be the case + // in the future. + std::move(callback).Run(); +} + +void AudioFocusManager::EnforceAudioFocusRequest(mojom::AudioFocusType type) { + DCHECK(IsAudioFocusEnforcementEnabled()); + + for (auto& old_session : audio_focus_stack_) { + // If the session has the force duck flag set then we should always duck it. + if (old_session->info()->force_duck) { + old_session->session()->StartDucking(); + continue; + } + + switch (type) { + case mojom::AudioFocusType::kGain: + case mojom::AudioFocusType::kGainTransient: + old_session->session()->Suspend( + mojom::MediaSession::SuspendType::kSystem); + break; + case mojom::AudioFocusType::kGainTransientMayDuck: + old_session->session()->StartDucking(); + break; + } + } +} + +void AudioFocusManager::EnforceAudioFocusAbandon(mojom::AudioFocusType type) { + DCHECK(IsAudioFocusEnforcementEnabled()); + + // Allow the top-most MediaSession having force duck to unduck even if + // it is not active. + for (auto iter = audio_focus_stack_.rbegin(); + iter != audio_focus_stack_.rend(); ++iter) { + if (!(*iter)->info()->force_duck) + continue; + + // TODO(beccahughes): Replace with std::rotate. + auto duck_row = std::move(*iter); + duck_row->session()->StopDucking(); + audio_focus_stack_.erase(std::next(iter).base()); + audio_focus_stack_.push_back(std::move(duck_row)); + return; + } + + DCHECK(!audio_focus_stack_.empty()); + StackRow* top = audio_focus_stack_.back().get(); + + switch (type) { + case mojom::AudioFocusType::kGain: + // Do nothing. The abandoned session suspended all the media sessions and + // they should stay suspended to avoid surprising the user. + break; + case mojom::AudioFocusType::kGainTransient: + // The abandoned session suspended all the media sessions but we should + // start playing the top one again as the abandoned media was transient. + top->session()->Resume(mojom::MediaSession::SuspendType::kSystem); + break; + case mojom::AudioFocusType::kGainTransientMayDuck: + // The abandoned session ducked all the media sessions so we should unduck + // them. If they are not playing then they will not resume. + for (auto& session : base::Reversed(audio_focus_stack_)) { + session->session()->StopDucking(); + + // If the new session is ducking then we should continue ducking all but + // the new session. + if (top->audio_focus_type() == + mojom::AudioFocusType::kGainTransientMayDuck) + break; + } + break; + } +} + +AudioFocusManager::AudioFocusManager() { + DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); +} + +AudioFocusManager::~AudioFocusManager() { + DCHECK(observers_.empty()); + DCHECK(bindings_.empty()); + DCHECK(debug_bindings_.empty()); +} + +std::unique_ptr<AudioFocusManager::StackRow> +AudioFocusManager::RemoveFocusEntryIfPresent(RequestId id) { + std::unique_ptr<StackRow> row; + + for (auto iter = audio_focus_stack_.begin(); iter != audio_focus_stack_.end(); + ++iter) { + if ((*iter)->id() == id) { + row.swap((*iter)); + audio_focus_stack_.erase(iter); + break; + } + } + + return row; +} + +const std::string& AudioFocusManager::GetBindingSourceName() const { + DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); + return bindings_.dispatch_context()->source_name; +} + +bool AudioFocusManager::IsSessionOnTopOfAudioFocusStack( + RequestId id, + mojom::AudioFocusType type) const { + return !audio_focus_stack_.empty() && audio_focus_stack_.back()->id() == id && + audio_focus_stack_.back()->audio_focus_type() == type; +} + +} // namespace media_session diff --git a/chromium/services/media_session/audio_focus_manager.h b/chromium/services/media_session/audio_focus_manager.h new file mode 100644 index 00000000000..4ac825ef121 --- /dev/null +++ b/chromium/services/media_session/audio_focus_manager.h @@ -0,0 +1,120 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_MEDIA_SESSION_AUDIO_FOCUS_MANAGER_H_ +#define SERVICES_MEDIA_SESSION_AUDIO_FOCUS_MANAGER_H_ + +#include <list> +#include <string> +#include <unordered_map> + +#include "base/threading/thread_checker.h" +#include "mojo/public/cpp/bindings/binding.h" +#include "mojo/public/cpp/bindings/binding_set.h" +#include "mojo/public/cpp/bindings/interface_ptr.h" +#include "mojo/public/cpp/bindings/interface_ptr_set.h" +#include "services/media_session/public/mojom/audio_focus.mojom.h" + +namespace base { +class UnguessableToken; +} // namespace base + +namespace media_session { + +namespace test { +class MockMediaSession; +} // namespace test + +class AudioFocusManager : public mojom::AudioFocusManager, + public mojom::AudioFocusManagerDebug { + public: + AudioFocusManager(); + ~AudioFocusManager() override; + + // TODO(beccahughes): Remove this. + using RequestId = base::UnguessableToken; + + // mojom::AudioFocusManager. + void RequestAudioFocus(mojom::AudioFocusRequestClientRequest request, + mojom::MediaSessionPtr media_session, + mojom::MediaSessionInfoPtr session_info, + mojom::AudioFocusType type, + RequestAudioFocusCallback callback) override; + void GetFocusRequests(GetFocusRequestsCallback callback) override; + void AddObserver(mojom::AudioFocusObserverPtr observer) override; + void SetSourceName(const std::string& name) override; + + // mojom::AudioFocusManagerDebug. + void GetDebugInfoForRequest(const RequestId& request_id, + GetDebugInfoForRequestCallback callback) override; + + // Bind to a mojom::AudioFocusManagerRequest. + void BindToInterface(mojom::AudioFocusManagerRequest request); + + // Bind to a mojom::AudioFocusManagerDebugRequest. + void BindToDebugInterface(mojom::AudioFocusManagerDebugRequest request); + + // This will close all Mojo bindings and interface pointers. This should be + // called by the MediaSession service before it is destroyed. + void CloseAllMojoObjects(); + + private: + friend class AudioFocusManagerTest; + friend class test::MockMediaSession; + + // StackRow is an AudioFocusRequestClient and allows a media session to + // control its audio focus. + class StackRow; + + // BindingContext stores associated metadata for mojo binding. + struct BindingContext { + // The source name is associated with a binding when a client calls + // |SetSourceName|. It is used to provide more granularity than a + // service_manager::Identity for metrics and for identifying where an audio + // focus request originated from. + std::string source_name; + }; + + void RequestAudioFocusInternal(std::unique_ptr<StackRow>, + mojom::AudioFocusType, + base::OnceCallback<void()>); + void EnforceAudioFocusRequest(mojom::AudioFocusType); + + void AbandonAudioFocusInternal(RequestId); + void EnforceAudioFocusAbandon(mojom::AudioFocusType); + + std::unique_ptr<StackRow> RemoveFocusEntryIfPresent(RequestId id); + + // Returns the source name of the binding currently accessing the Audio + // Focus Manager API over mojo. + const std::string& GetBindingSourceName() const; + + bool IsSessionOnTopOfAudioFocusStack(RequestId id, + mojom::AudioFocusType type) const; + + // Holds mojo bindings for the Audio Focus Manager API. + mojo::BindingSet<mojom::AudioFocusManager, std::unique_ptr<BindingContext>> + bindings_; + + // Holds mojo bindings for the Audio Focus Manager Debug API. + mojo::BindingSet<mojom::AudioFocusManagerDebug> debug_bindings_; + + // Weak reference of managed observers. Observers are expected to remove + // themselves before being destroyed. + mojo::InterfacePtrSet<mojom::AudioFocusObserver> observers_; + + // A stack of Mojo interface pointers and their requested audio focus type. + // A MediaSession must abandon audio focus before its destruction. + std::list<std::unique_ptr<StackRow>> audio_focus_stack_; + + // Adding observers should happen on the same thread that the service is + // running on. + THREAD_CHECKER(thread_checker_); + + DISALLOW_COPY_AND_ASSIGN(AudioFocusManager); +}; + +} // namespace media_session + +#endif // SERVICES_MEDIA_SESSION_AUDIO_FOCUS_MANAGER_H_ diff --git a/chromium/services/media_session/audio_focus_manager_metrics_helper.cc b/chromium/services/media_session/audio_focus_manager_metrics_helper.cc new file mode 100644 index 00000000000..bbf5e442a6c --- /dev/null +++ b/chromium/services/media_session/audio_focus_manager_metrics_helper.cc @@ -0,0 +1,101 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/media_session/audio_focus_manager_metrics_helper.h" + +#include "base/metrics/histogram.h" +#include "base/metrics/histogram_base.h" +#include "base/strings/string_util.h" +#include "services/media_session/public/mojom/audio_focus.mojom.h" + +namespace media_session { + +namespace { + +static const char kHistogramPrefix[] = "Media.Session.AudioFocus."; +static const char kHistogramSeparator[] = "."; + +static const char kRequestAudioFocusName[] = "Request"; +static const char kAudioFocusTypeName[] = "Type"; +static const char kAbandonAudioFocusName[] = "Abandon"; + +static constexpr base::HistogramBase::Sample kHistogramMinimum = 1; + +} // namespace + +AudioFocusManagerMetricsHelper::AudioFocusManagerMetricsHelper( + const std::string& source_name) + : source_name_(source_name), + request_source_histogram_(GetHistogram( + kRequestAudioFocusName, + static_cast<Sample>(AudioFocusRequestSource::kMaxValue))), + focus_type_histogram_( + GetHistogram(kAudioFocusTypeName, + static_cast<Sample>(AudioFocusType::kMaxValue))), + abandon_source_histogram_(GetHistogram( + kAbandonAudioFocusName, + static_cast<Sample>(AudioFocusAbandonSource::kMaxValue))) {} + +AudioFocusManagerMetricsHelper::~AudioFocusManagerMetricsHelper() = default; + +void AudioFocusManagerMetricsHelper::OnRequestAudioFocus( + AudioFocusManagerMetricsHelper::AudioFocusRequestSource source, + mojom::AudioFocusType type) { + if (!ShouldRecordMetrics()) + return; + + request_source_histogram_->Add(static_cast<Sample>(source)); + focus_type_histogram_->Add(static_cast<Sample>(FromMojoFocusType(type))); +} + +void AudioFocusManagerMetricsHelper::OnAbandonAudioFocus( + AudioFocusManagerMetricsHelper::AudioFocusAbandonSource source) { + if (!ShouldRecordMetrics()) + return; + + abandon_source_histogram_->Add(static_cast<Sample>(source)); +} + +base::HistogramBase* AudioFocusManagerMetricsHelper::GetHistogram( + const char* name, + Sample max) const { + std::string histogram_name; + histogram_name.append(kHistogramPrefix); + histogram_name.append(name); + histogram_name.append(kHistogramSeparator); + + // This will ensure that |source_name| starts with an upper case letter. + for (auto it = source_name_.begin(); it < source_name_.end(); ++it) { + if (it == source_name_.begin()) + histogram_name.push_back(base::ToUpperASCII(*it)); + else + histogram_name.push_back(*it); + } + + return base::LinearHistogram::FactoryGet(histogram_name, kHistogramMinimum, + max, max + 1, + base::HistogramBase::kNoFlags); +} + +// static +AudioFocusManagerMetricsHelper::AudioFocusType +AudioFocusManagerMetricsHelper::FromMojoFocusType(mojom::AudioFocusType type) { + switch (type) { + case mojom::AudioFocusType::kGain: + return AudioFocusType::kGain; + case mojom::AudioFocusType::kGainTransientMayDuck: + return AudioFocusType::kGainTransientMayDuck; + case mojom::AudioFocusType::kGainTransient: + return AudioFocusType::kGainTransient; + } + + NOTREACHED(); + return AudioFocusType::kUnknown; +} + +bool AudioFocusManagerMetricsHelper::ShouldRecordMetrics() const { + return !source_name_.empty(); +} + +} // namespace media_session diff --git a/chromium/services/media_session/audio_focus_manager_metrics_helper.h b/chromium/services/media_session/audio_focus_manager_metrics_helper.h new file mode 100644 index 00000000000..28b8f9d1308 --- /dev/null +++ b/chromium/services/media_session/audio_focus_manager_metrics_helper.h @@ -0,0 +1,79 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_MEDIA_SESSION_AUDIO_FOCUS_MANAGER_METRICS_HELPER_H_ +#define SERVICES_MEDIA_SESSION_AUDIO_FOCUS_MANAGER_METRICS_HELPER_H_ + +#include <string> + +#include "base/macros.h" +#include "base/metrics/histogram_base.h" + +namespace media_session { + +namespace mojom { +enum class AudioFocusType; +} // namespace mojom + +class AudioFocusManagerMetricsHelper { + public: + using Sample = base::HistogramBase::Sample; + + AudioFocusManagerMetricsHelper(const std::string& source_name); + ~AudioFocusManagerMetricsHelper(); + + // This is used for UMA histogram + // (Media.Session.AudioFocus.*.RequestAudioFocus). New values should be + // appended only and update |kMaxValue|. + enum class AudioFocusRequestSource { + kUnknown = 0, + kInitial = 1, + kUpdate = 2, + kMaxValue = kUpdate // Leave at the end. + }; + + // This is used for UMA histogram + // (Media.Session.AudioFocus.*.AbandonAudioFocus). New values should be + // appended only and update |kMaxValue|. + enum class AudioFocusAbandonSource { + kUnknown = 0, + kAPI = 1, + kConnectionError = 2, + kMaxValue = kConnectionError // Leave at the end. + }; + + // This is used for UMA histogram + // (Media.Session.AudioFocus.*.AudioFocusType). New values should be + // appended only and update |kMaxValue|. It should mirror the + // media_session::mojom::AudioFocusType enum. + enum class AudioFocusType { + kUnknown = 0, + kGain = 1, + kGainTransientMayDuck = 2, + kGainTransient = 3, + kMaxValue = kGainTransient // Leave at the end. + }; + + void OnRequestAudioFocus(AudioFocusRequestSource, mojom::AudioFocusType); + void OnAbandonAudioFocus(AudioFocusAbandonSource); + + private: + static AudioFocusType FromMojoFocusType(mojom::AudioFocusType); + + base::HistogramBase* GetHistogram(const char* name, Sample max) const; + + bool ShouldRecordMetrics() const; + + const std::string& source_name_; + + base::HistogramBase* const request_source_histogram_; + base::HistogramBase* const focus_type_histogram_; + base::HistogramBase* const abandon_source_histogram_; + + DISALLOW_COPY_AND_ASSIGN(AudioFocusManagerMetricsHelper); +}; + +} // namespace media_session + +#endif // SERVICES_MEDIA_SESSION_AUDIO_FOCUS_MANAGER_METRICS_HELPER_H_ diff --git a/chromium/services/media_session/audio_focus_manager_unittest.cc b/chromium/services/media_session/audio_focus_manager_unittest.cc new file mode 100644 index 00000000000..e5225fde109 --- /dev/null +++ b/chromium/services/media_session/audio_focus_manager_unittest.cc @@ -0,0 +1,923 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/media_session/audio_focus_manager.h" + +#include <memory> +#include <utility> +#include <vector> + +#include "base/callback.h" +#include "base/command_line.h" +#include "base/run_loop.h" +#include "base/test/metrics/histogram_tester.h" +#include "base/test/scoped_command_line.h" +#include "base/test/scoped_task_environment.h" +#include "mojo/public/cpp/bindings/binding_set.h" +#include "mojo/public/cpp/bindings/interface_request.h" +#include "services/media_session/audio_focus_manager_metrics_helper.h" +#include "services/media_session/media_session_service.h" +#include "services/media_session/mock_media_session.h" +#include "services/media_session/public/cpp/switches.h" +#include "services/media_session/public/cpp/test/audio_focus_test_util.h" +#include "services/media_session/public/mojom/audio_focus.mojom.h" +#include "services/service_manager/public/cpp/test/test_connector_factory.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace media_session { + +namespace { + +const char kExampleSourceName[] = "test"; +const char kExampleSourceName2[] = "test2"; + +} // anonymous namespace + +// This tests the Audio Focus Manager API. The parameter determines whether +// audio focus is enabled or not. If it is not enabled it should track the media +// sessions but not enforce single session focus. +class AudioFocusManagerTest : public testing::TestWithParam<bool> { + public: + AudioFocusManagerTest() = default; + + void SetUp() override { + if (!GetParam()) { + command_line_.GetProcessCommandLine()->AppendSwitchASCII( + switches::kEnableAudioFocus, switches::kEnableAudioFocusNoEnforce); + } + + ASSERT_EQ(GetParam(), IsAudioFocusEnforcementEnabled()); + + // Create an instance of the MediaSessionService. + connector_factory_ = + service_manager::TestConnectorFactory::CreateForUniqueService( + MediaSessionService::Create()); + connector_ = connector_factory_->CreateConnector(); + + // Bind |audio_focus_ptr_| to AudioFocusManager. + connector_->BindInterface("test", mojo::MakeRequest(&audio_focus_ptr_)); + + // Bind |audio_focus_debug_ptr_| to AudioFocusManagerDebug. + connector_->BindInterface("test", + mojo::MakeRequest(&audio_focus_debug_ptr_)); + } + + void TearDown() override { + // Run pending tasks. + base::RunLoop().RunUntilIdle(); + } + + AudioFocusManager::RequestId GetAudioFocusedSession() { + const auto audio_focus_requests = GetRequests(); + for (auto iter = audio_focus_requests.rbegin(); + iter != audio_focus_requests.rend(); ++iter) { + if ((*iter)->audio_focus_type == mojom::AudioFocusType::kGain) + return (*iter)->request_id.value(); + } + return base::UnguessableToken::Null(); + } + + int GetTransientCount() { + return GetCountForType(mojom::AudioFocusType::kGainTransient); + } + + int GetTransientMaybeDuckCount() { + return GetCountForType(mojom::AudioFocusType::kGainTransientMayDuck); + } + + void AbandonAudioFocusNoReset(test::MockMediaSession* session) { + session->audio_focus_request()->AbandonAudioFocus(); + session->FlushForTesting(); + audio_focus_ptr_.FlushForTesting(); + } + + AudioFocusManager::RequestId RequestAudioFocus( + test::MockMediaSession* session, + mojom::AudioFocusType audio_focus_type) { + return session->RequestAudioFocusFromService(audio_focus_ptr_, + audio_focus_type); + } + + mojom::MediaSessionDebugInfoPtr GetDebugInfo( + AudioFocusManager::RequestId request_id) { + mojom::MediaSessionDebugInfoPtr result; + base::OnceCallback<void(mojom::MediaSessionDebugInfoPtr)> callback = + base::BindOnce( + [](mojom::MediaSessionDebugInfoPtr* out_result, + mojom::MediaSessionDebugInfoPtr result) { + *out_result = std::move(result); + }, + &result); + + GetDebugService()->GetDebugInfoForRequest(request_id, std::move(callback)); + + audio_focus_ptr_.FlushForTesting(); + audio_focus_debug_ptr_.FlushForTesting(); + + return result; + } + + mojom::MediaSessionInfo::SessionState GetState( + test::MockMediaSession* session) { + mojom::MediaSessionInfo::SessionState state = session->GetState(); + + if (!GetParam()) { + // If audio focus enforcement is disabled then we should never see these + // states in the tests. + EXPECT_NE(mojom::MediaSessionInfo::SessionState::kSuspended, state); + EXPECT_NE(mojom::MediaSessionInfo::SessionState::kDucking, state); + } + + return state; + } + + std::unique_ptr<test::TestAudioFocusObserver> CreateObserver() { + std::unique_ptr<test::TestAudioFocusObserver> observer = + std::make_unique<test::TestAudioFocusObserver>(); + + mojom::AudioFocusObserverPtr observer_ptr; + observer->BindToMojoRequest(mojo::MakeRequest(&observer_ptr)); + GetService()->AddObserver(std::move(observer_ptr)); + + audio_focus_ptr_.FlushForTesting(); + return observer; + } + + mojom::MediaSessionInfo::SessionState GetStateFromParam( + mojom::MediaSessionInfo::SessionState state) { + // If enforcement is enabled then returns the provided state, otherwise + // returns kActive because without enforcement we did not change state. + if (GetParam()) + return state; + return mojom::MediaSessionInfo::SessionState::kActive; + } + + void SetSourceName(const std::string& name) { + GetService()->SetSourceName(name); + audio_focus_ptr_.FlushForTesting(); + } + + mojom::AudioFocusManagerPtr CreateAudioFocusManagerPtr() { + mojom::AudioFocusManagerPtr ptr; + connector_->BindInterface("test", mojo::MakeRequest(&ptr)); + return ptr; + } + + const std::string GetSourceNameForLastRequest() { + std::vector<mojom::AudioFocusRequestStatePtr> requests = GetRequests(); + EXPECT_TRUE(requests.back()); + return requests.back()->source_name.value(); + } + + std::unique_ptr<base::HistogramSamples> GetHistogramSamplesSinceTestStart( + const std::string& name) { + return histogram_tester_.GetHistogramSamplesSinceCreation(name); + } + + int GetAudioFocusHistogramCount() { + return histogram_tester_ + .GetTotalCountsForPrefix("Media.Session.AudioFocus.") + .size(); + } + + private: + int GetCountForType(mojom::AudioFocusType type) { + const auto audio_focus_requests = GetRequests(); + return std::count_if(audio_focus_requests.begin(), + audio_focus_requests.end(), + [type](const auto& session) { + return session->audio_focus_type == type; + }); + } + + std::vector<mojom::AudioFocusRequestStatePtr> GetRequests() { + std::vector<mojom::AudioFocusRequestStatePtr> result; + + GetService()->GetFocusRequests(base::BindOnce( + [](std::vector<mojom::AudioFocusRequestStatePtr>* out, + std::vector<mojom::AudioFocusRequestStatePtr> requests) { + for (auto& request : requests) + out->push_back(request.Clone()); + }, + &result)); + + audio_focus_ptr_.FlushForTesting(); + return result; + } + + mojom::AudioFocusManager* GetService() const { + return audio_focus_ptr_.get(); + } + + mojom::AudioFocusManagerDebug* GetDebugService() const { + return audio_focus_debug_ptr_.get(); + } + + void FlushForTestingIfEnabled() { + if (!GetParam()) + return; + + audio_focus_ptr_.FlushForTesting(); + } + + base::test::ScopedCommandLine command_line_; + base::test::ScopedTaskEnvironment task_environment_; + base::HistogramTester histogram_tester_; + + std::unique_ptr<service_manager::TestConnectorFactory> connector_factory_; + std::unique_ptr<service_manager::Connector> connector_; + + mojom::AudioFocusManagerPtr audio_focus_ptr_; + mojom::AudioFocusManagerDebugPtr audio_focus_debug_ptr_; + + DISALLOW_COPY_AND_ASSIGN(AudioFocusManagerTest); +}; + +INSTANTIATE_TEST_CASE_P(, AudioFocusManagerTest, testing::Bool()); + +TEST_P(AudioFocusManagerTest, RequestAudioFocusGain_ReplaceFocusedEntry) { + test::MockMediaSession media_session_1; + test::MockMediaSession media_session_2; + test::MockMediaSession media_session_3; + + EXPECT_EQ(base::UnguessableToken::Null(), GetAudioFocusedSession()); + EXPECT_EQ(mojom::MediaSessionInfo::SessionState::kInactive, + GetState(&media_session_1)); + EXPECT_EQ(mojom::MediaSessionInfo::SessionState::kInactive, + GetState(&media_session_2)); + EXPECT_EQ(mojom::MediaSessionInfo::SessionState::kInactive, + GetState(&media_session_3)); + + AudioFocusManager::RequestId request_id_1 = + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGain); + EXPECT_EQ(request_id_1, GetAudioFocusedSession()); + EXPECT_EQ(mojom::MediaSessionInfo::SessionState::kActive, + GetState(&media_session_1)); + + AudioFocusManager::RequestId request_id_2 = + RequestAudioFocus(&media_session_2, mojom::AudioFocusType::kGain); + EXPECT_EQ(request_id_2, GetAudioFocusedSession()); + EXPECT_EQ( + GetStateFromParam(mojom::MediaSessionInfo::SessionState::kSuspended), + GetState(&media_session_1)); + + AudioFocusManager::RequestId request_id_3 = + RequestAudioFocus(&media_session_3, mojom::AudioFocusType::kGain); + EXPECT_EQ(request_id_3, GetAudioFocusedSession()); + EXPECT_EQ( + GetStateFromParam(mojom::MediaSessionInfo::SessionState::kSuspended), + GetState(&media_session_2)); +} + +TEST_P(AudioFocusManagerTest, RequestAudioFocusGain_Duplicate) { + test::MockMediaSession media_session; + + EXPECT_EQ(base::UnguessableToken::Null(), GetAudioFocusedSession()); + + AudioFocusManager::RequestId request_id = + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + EXPECT_EQ(request_id, GetAudioFocusedSession()); + + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + EXPECT_EQ(request_id, GetAudioFocusedSession()); +} + +TEST_P(AudioFocusManagerTest, RequestAudioFocusGain_FromTransient) { + test::MockMediaSession media_session; + + AudioFocusManager::RequestId request_id = + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGainTransient); + EXPECT_EQ(base::UnguessableToken::Null(), GetAudioFocusedSession()); + EXPECT_EQ(1, GetTransientCount()); + + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + EXPECT_EQ(request_id, GetAudioFocusedSession()); + EXPECT_EQ(0, GetTransientCount()); +} + +TEST_P(AudioFocusManagerTest, RequestAudioFocusGain_FromTransientMayDuck) { + test::MockMediaSession media_session; + + AudioFocusManager::RequestId request_id = RequestAudioFocus( + &media_session, mojom::AudioFocusType::kGainTransientMayDuck); + EXPECT_EQ(base::UnguessableToken::Null(), GetAudioFocusedSession()); + EXPECT_EQ(1, GetTransientMaybeDuckCount()); + + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + EXPECT_EQ(request_id, GetAudioFocusedSession()); + EXPECT_EQ(0, GetTransientMaybeDuckCount()); +} + +TEST_P(AudioFocusManagerTest, RequestAudioFocusTransient_FromGain) { + test::MockMediaSession media_session; + + AudioFocusManager::RequestId request_id = + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + + EXPECT_EQ(request_id, GetAudioFocusedSession()); + EXPECT_EQ(0, GetTransientCount()); + + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGainTransient); + EXPECT_EQ(base::UnguessableToken::Null(), GetAudioFocusedSession()); + EXPECT_EQ(1, GetTransientCount()); + EXPECT_NE(mojom::MediaSessionInfo::SessionState::kSuspended, + GetState(&media_session)); +} + +TEST_P(AudioFocusManagerTest, RequestAudioFocusTransientMayDuck_FromGain) { + test::MockMediaSession media_session; + + AudioFocusManager::RequestId request_id = + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + + EXPECT_EQ(request_id, GetAudioFocusedSession()); + EXPECT_EQ(0, GetTransientMaybeDuckCount()); + + RequestAudioFocus(&media_session, + mojom::AudioFocusType::kGainTransientMayDuck); + EXPECT_EQ(base::UnguessableToken::Null(), GetAudioFocusedSession()); + EXPECT_EQ(1, GetTransientMaybeDuckCount()); + EXPECT_NE(mojom::MediaSessionInfo::SessionState::kDucking, + GetState(&media_session)); +} + +TEST_P(AudioFocusManagerTest, RequestAudioFocusTransient_FromGainWhileDucking) { + test::MockMediaSession media_session_1; + test::MockMediaSession media_session_2; + + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGain); + EXPECT_EQ(0, GetTransientMaybeDuckCount()); + EXPECT_EQ(mojom::MediaSessionInfo::SessionState::kActive, + GetState(&media_session_1)); + + RequestAudioFocus(&media_session_2, + mojom::AudioFocusType::kGainTransientMayDuck); + EXPECT_EQ(0, GetTransientCount()); + EXPECT_EQ(1, GetTransientMaybeDuckCount()); + EXPECT_EQ(GetStateFromParam(mojom::MediaSessionInfo::SessionState::kDucking), + GetState(&media_session_1)); + + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGainTransient); + EXPECT_EQ(1, GetTransientCount()); + EXPECT_EQ(1, GetTransientMaybeDuckCount()); + EXPECT_EQ(mojom::MediaSessionInfo::SessionState::kActive, + GetState(&media_session_1)); +} + +TEST_P(AudioFocusManagerTest, + RequestAudioFocusTransientMayDuck_FromGainWhileDucking) { + test::MockMediaSession media_session_1; + test::MockMediaSession media_session_2; + + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGain); + EXPECT_EQ(0, GetTransientMaybeDuckCount()); + EXPECT_EQ(mojom::MediaSessionInfo::SessionState::kActive, + GetState(&media_session_1)); + + RequestAudioFocus(&media_session_2, + mojom::AudioFocusType::kGainTransientMayDuck); + EXPECT_EQ(1, GetTransientMaybeDuckCount()); + EXPECT_EQ(GetStateFromParam(mojom::MediaSessionInfo::SessionState::kDucking), + GetState(&media_session_1)); + + RequestAudioFocus(&media_session_1, + mojom::AudioFocusType::kGainTransientMayDuck); + EXPECT_EQ(2, GetTransientMaybeDuckCount()); + EXPECT_EQ(mojom::MediaSessionInfo::SessionState::kActive, + GetState(&media_session_1)); +} + +TEST_P(AudioFocusManagerTest, AbandonAudioFocus_RemovesFocusedEntry) { + test::MockMediaSession media_session; + + AudioFocusManager::RequestId request_id = + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + EXPECT_EQ(request_id, GetAudioFocusedSession()); + + media_session.AbandonAudioFocusFromClient(); + EXPECT_EQ(base::UnguessableToken::Null(), GetAudioFocusedSession()); +} + +TEST_P(AudioFocusManagerTest, AbandonAudioFocus_MultipleCalls) { + test::MockMediaSession media_session; + + AudioFocusManager::RequestId request_id = + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + EXPECT_EQ(request_id, GetAudioFocusedSession()); + + AbandonAudioFocusNoReset(&media_session); + + std::unique_ptr<test::TestAudioFocusObserver> observer = CreateObserver(); + media_session.AbandonAudioFocusFromClient(); + + EXPECT_EQ(base::UnguessableToken::Null(), GetAudioFocusedSession()); + EXPECT_TRUE(observer->focus_lost_session_.is_null()); +} + +TEST_P(AudioFocusManagerTest, AbandonAudioFocus_RemovesTransientMayDuckEntry) { + test::MockMediaSession media_session; + + RequestAudioFocus(&media_session, + mojom::AudioFocusType::kGainTransientMayDuck); + EXPECT_EQ(1, GetTransientMaybeDuckCount()); + + { + std::unique_ptr<test::TestAudioFocusObserver> observer = CreateObserver(); + media_session.AbandonAudioFocusFromClient(); + + EXPECT_EQ(0, GetTransientMaybeDuckCount()); + EXPECT_TRUE(observer->focus_lost_session_.Equals( + test::GetMediaSessionInfoSync(&media_session))); + } +} + +TEST_P(AudioFocusManagerTest, AbandonAudioFocus_RemovesTransientEntry) { + test::MockMediaSession media_session; + + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGainTransient); + EXPECT_EQ(1, GetTransientCount()); + + { + std::unique_ptr<test::TestAudioFocusObserver> observer = CreateObserver(); + media_session.AbandonAudioFocusFromClient(); + + EXPECT_EQ(0, GetTransientCount()); + EXPECT_TRUE(observer->focus_lost_session_.Equals( + test::GetMediaSessionInfoSync(&media_session))); + } +} + +TEST_P(AudioFocusManagerTest, AbandonAudioFocus_WhileDuckingThenResume) { + test::MockMediaSession media_session_1; + test::MockMediaSession media_session_2; + + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGain); + EXPECT_EQ(0, GetTransientMaybeDuckCount()); + EXPECT_NE(mojom::MediaSessionInfo::SessionState::kDucking, + GetState(&media_session_1)); + + RequestAudioFocus(&media_session_2, + mojom::AudioFocusType::kGainTransientMayDuck); + EXPECT_EQ(1, GetTransientMaybeDuckCount()); + EXPECT_EQ(GetStateFromParam(mojom::MediaSessionInfo::SessionState::kDucking), + GetState(&media_session_1)); + + media_session_1.AbandonAudioFocusFromClient(); + EXPECT_EQ(1, GetTransientMaybeDuckCount()); + + media_session_2.AbandonAudioFocusFromClient(); + EXPECT_EQ(0, GetTransientMaybeDuckCount()); + + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGain); + EXPECT_NE(mojom::MediaSessionInfo::SessionState::kDucking, + GetState(&media_session_1)); +} + +TEST_P(AudioFocusManagerTest, AbandonAudioFocus_StopsDucking) { + test::MockMediaSession media_session_1; + test::MockMediaSession media_session_2; + + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGain); + EXPECT_EQ(0, GetTransientMaybeDuckCount()); + EXPECT_NE(mojom::MediaSessionInfo::SessionState::kDucking, + GetState(&media_session_1)); + + RequestAudioFocus(&media_session_2, + mojom::AudioFocusType::kGainTransientMayDuck); + EXPECT_EQ(1, GetTransientMaybeDuckCount()); + EXPECT_EQ(GetStateFromParam(mojom::MediaSessionInfo::SessionState::kDucking), + GetState(&media_session_1)); + + media_session_2.AbandonAudioFocusFromClient(); + EXPECT_EQ(0, GetTransientMaybeDuckCount()); + EXPECT_NE(mojom::MediaSessionInfo::SessionState::kDucking, + GetState(&media_session_1)); +} + +TEST_P(AudioFocusManagerTest, AbandonAudioFocus_ResumesPlayback) { + test::MockMediaSession media_session_1; + test::MockMediaSession media_session_2; + + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGain); + EXPECT_EQ(0, GetTransientCount()); + EXPECT_EQ(mojom::MediaSessionInfo::SessionState::kActive, + GetState(&media_session_1)); + + RequestAudioFocus(&media_session_2, mojom::AudioFocusType::kGainTransient); + EXPECT_EQ(1, GetTransientCount()); + EXPECT_EQ( + GetStateFromParam(mojom::MediaSessionInfo::SessionState::kSuspended), + GetState(&media_session_1)); + + media_session_2.AbandonAudioFocusFromClient(); + EXPECT_EQ(0, GetTransientCount()); + EXPECT_EQ(mojom::MediaSessionInfo::SessionState::kActive, + GetState(&media_session_1)); +} + +TEST_P(AudioFocusManagerTest, DuckWhilePlaying) { + test::MockMediaSession media_session_1; + test::MockMediaSession media_session_2; + + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGain); + EXPECT_NE(mojom::MediaSessionInfo::SessionState::kDucking, + GetState(&media_session_1)); + + RequestAudioFocus(&media_session_2, + mojom::AudioFocusType::kGainTransientMayDuck); + EXPECT_EQ(GetStateFromParam(mojom::MediaSessionInfo::SessionState::kDucking), + GetState(&media_session_1)); +} + +TEST_P(AudioFocusManagerTest, GainSuspendsTransient) { + test::MockMediaSession media_session_1; + test::MockMediaSession media_session_2; + + RequestAudioFocus(&media_session_2, mojom::AudioFocusType::kGainTransient); + + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGain); + EXPECT_EQ( + GetStateFromParam(mojom::MediaSessionInfo::SessionState::kSuspended), + GetState(&media_session_2)); +} + +TEST_P(AudioFocusManagerTest, GainSuspendsTransientMayDuck) { + test::MockMediaSession media_session_1; + test::MockMediaSession media_session_2; + + RequestAudioFocus(&media_session_2, + mojom::AudioFocusType::kGainTransientMayDuck); + + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGain); + EXPECT_EQ( + GetStateFromParam(mojom::MediaSessionInfo::SessionState::kSuspended), + GetState(&media_session_2)); +} + +TEST_P(AudioFocusManagerTest, DuckWithMultipleTransientMayDucks) { + test::MockMediaSession media_session_1; + test::MockMediaSession media_session_2; + test::MockMediaSession media_session_3; + test::MockMediaSession media_session_4; + + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGain); + EXPECT_NE(mojom::MediaSessionInfo::SessionState::kDucking, + GetState(&media_session_1)); + + RequestAudioFocus(&media_session_2, mojom::AudioFocusType::kGainTransient); + EXPECT_NE(mojom::MediaSessionInfo::SessionState::kDucking, + GetState(&media_session_2)); + + RequestAudioFocus(&media_session_3, + mojom::AudioFocusType::kGainTransientMayDuck); + EXPECT_EQ(GetStateFromParam(mojom::MediaSessionInfo::SessionState::kDucking), + GetState(&media_session_1)); + EXPECT_EQ(GetStateFromParam(mojom::MediaSessionInfo::SessionState::kDucking), + GetState(&media_session_2)); + + RequestAudioFocus(&media_session_4, + mojom::AudioFocusType::kGainTransientMayDuck); + EXPECT_EQ(GetStateFromParam(mojom::MediaSessionInfo::SessionState::kDucking), + GetState(&media_session_1)); + EXPECT_EQ(GetStateFromParam(mojom::MediaSessionInfo::SessionState::kDucking), + GetState(&media_session_2)); + + media_session_3.AbandonAudioFocusFromClient(); + EXPECT_EQ(GetStateFromParam(mojom::MediaSessionInfo::SessionState::kDucking), + GetState(&media_session_1)); + EXPECT_EQ(GetStateFromParam(mojom::MediaSessionInfo::SessionState::kDucking), + GetState(&media_session_2)); + + media_session_4.AbandonAudioFocusFromClient(); + EXPECT_NE(mojom::MediaSessionInfo::SessionState::kDucking, + GetState(&media_session_1)); + EXPECT_NE(mojom::MediaSessionInfo::SessionState::kDucking, + GetState(&media_session_2)); +} + +TEST_P(AudioFocusManagerTest, MediaSessionDestroyed_ReleasesFocus) { + { + test::MockMediaSession media_session; + + AudioFocusManager::RequestId request_id = + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + EXPECT_EQ(request_id, GetAudioFocusedSession()); + } + + // If the media session is destroyed without abandoning audio focus we do not + // know until we next interact with the manager. + test::MockMediaSession media_session; + RequestAudioFocus(&media_session, + mojom::AudioFocusType::kGainTransientMayDuck); + EXPECT_EQ(base::UnguessableToken::Null(), GetAudioFocusedSession()); +} + +TEST_P(AudioFocusManagerTest, MediaSessionDestroyed_ReleasesTransient) { + { + test::MockMediaSession media_session; + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGainTransient); + EXPECT_EQ(1, GetTransientCount()); + } + + // If the media session is destroyed without abandoning audio focus we do not + // know until we next interact with the manager. + test::MockMediaSession media_session; + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + EXPECT_EQ(0, GetTransientCount()); +} + +TEST_P(AudioFocusManagerTest, MediaSessionDestroyed_ReleasesTransientMayDucks) { + { + test::MockMediaSession media_session; + RequestAudioFocus(&media_session, + mojom::AudioFocusType::kGainTransientMayDuck); + EXPECT_EQ(1, GetTransientMaybeDuckCount()); + } + + // If the media session is destroyed without abandoning audio focus we do not + // know until we next interact with the manager. + test::MockMediaSession media_session; + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + EXPECT_EQ(0, GetTransientMaybeDuckCount()); +} + +TEST_P(AudioFocusManagerTest, GainDucksForceDuck) { + test::MockMediaSession media_session_1(true /* force_duck */); + test::MockMediaSession media_session_2; + + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGain); + + AudioFocusManager::RequestId request_id_2 = + RequestAudioFocus(&media_session_2, mojom::AudioFocusType::kGain); + + EXPECT_EQ(request_id_2, GetAudioFocusedSession()); + EXPECT_EQ(GetStateFromParam(mojom::MediaSessionInfo::SessionState::kDucking), + GetState(&media_session_1)); +} + +TEST_P(AudioFocusManagerTest, + AbandoningGainFocusRevokesTopMostForceDuckSession) { + test::MockMediaSession media_session_1(true /* force_duck */); + test::MockMediaSession media_session_2; + test::MockMediaSession media_session_3; + + AudioFocusManager::RequestId request_id_1 = + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGain); + AudioFocusManager::RequestId request_id_2 = + RequestAudioFocus(&media_session_2, mojom::AudioFocusType::kGain); + + AudioFocusManager::RequestId request_id_3 = + RequestAudioFocus(&media_session_3, mojom::AudioFocusType::kGain); + EXPECT_EQ(request_id_3, GetAudioFocusedSession()); + + EXPECT_EQ( + GetStateFromParam(mojom::MediaSessionInfo::SessionState::kSuspended), + GetState(&media_session_2)); + EXPECT_EQ(GetStateFromParam(mojom::MediaSessionInfo::SessionState::kDucking), + GetState(&media_session_1)); + + media_session_3.AbandonAudioFocusFromClient(); + EXPECT_EQ(GetParam() ? request_id_1 : request_id_2, GetAudioFocusedSession()); +} + +TEST_P(AudioFocusManagerTest, AudioFocusObserver_RequestNoop) { + test::MockMediaSession media_session; + AudioFocusManager::RequestId request_id; + + { + std::unique_ptr<test::TestAudioFocusObserver> observer = CreateObserver(); + request_id = + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + + EXPECT_EQ(request_id, GetAudioFocusedSession()); + EXPECT_EQ(mojom::AudioFocusType::kGain, observer->focus_gained_type()); + EXPECT_FALSE(observer->focus_gained_session_.is_null()); + } + + { + std::unique_ptr<test::TestAudioFocusObserver> observer = CreateObserver(); + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + + EXPECT_EQ(request_id, GetAudioFocusedSession()); + EXPECT_TRUE(observer->focus_gained_session_.is_null()); + } +} + +TEST_P(AudioFocusManagerTest, AudioFocusObserver_TransientMayDuck) { + test::MockMediaSession media_session; + + { + std::unique_ptr<test::TestAudioFocusObserver> observer = CreateObserver(); + RequestAudioFocus(&media_session, + mojom::AudioFocusType::kGainTransientMayDuck); + + EXPECT_EQ(1, GetTransientMaybeDuckCount()); + EXPECT_EQ(mojom::AudioFocusType::kGainTransientMayDuck, + observer->focus_gained_type()); + EXPECT_FALSE(observer->focus_gained_session_.is_null()); + } + + { + std::unique_ptr<test::TestAudioFocusObserver> observer = CreateObserver(); + media_session.AbandonAudioFocusFromClient(); + + EXPECT_EQ(0, GetTransientMaybeDuckCount()); + EXPECT_TRUE(observer->focus_lost_session_.Equals( + test::GetMediaSessionInfoSync(&media_session))); + } +} + +TEST_P(AudioFocusManagerTest, GetDebugInfo) { + test::MockMediaSession media_session; + AudioFocusManager::RequestId request_id = + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + + mojom::MediaSessionDebugInfoPtr debug_info = GetDebugInfo(request_id); + EXPECT_FALSE(debug_info->name.empty()); + EXPECT_FALSE(debug_info->owner.empty()); + EXPECT_FALSE(debug_info->state.empty()); +} + +TEST_P(AudioFocusManagerTest, GetDebugInfo_BadRequestId) { + mojom::MediaSessionDebugInfoPtr debug_info = + GetDebugInfo(base::UnguessableToken::Create()); + EXPECT_TRUE(debug_info->name.empty()); +} + +TEST_P(AudioFocusManagerTest, + RequestAudioFocusTransient_FromGainWhileSuspended) { + test::MockMediaSession media_session_1; + test::MockMediaSession media_session_2; + + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGain); + EXPECT_EQ(0, GetTransientCount()); + EXPECT_EQ(mojom::MediaSessionInfo::SessionState::kActive, + GetState(&media_session_1)); + + RequestAudioFocus(&media_session_2, mojom::AudioFocusType::kGainTransient); + EXPECT_EQ(1, GetTransientCount()); + EXPECT_EQ( + GetStateFromParam(mojom::MediaSessionInfo::SessionState::kSuspended), + GetState(&media_session_1)); + + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGainTransient); + EXPECT_EQ(2, GetTransientCount()); + EXPECT_EQ(mojom::MediaSessionInfo::SessionState::kActive, + GetState(&media_session_1)); +} + +TEST_P(AudioFocusManagerTest, + RequestAudioFocusTransientMayDuck_FromGainWhileSuspended) { + test::MockMediaSession media_session_1; + test::MockMediaSession media_session_2; + + RequestAudioFocus(&media_session_1, mojom::AudioFocusType::kGain); + EXPECT_EQ(0, GetTransientCount()); + EXPECT_EQ(0, GetTransientMaybeDuckCount()); + EXPECT_EQ(mojom::MediaSessionInfo::SessionState::kActive, + GetState(&media_session_1)); + + RequestAudioFocus(&media_session_2, mojom::AudioFocusType::kGainTransient); + EXPECT_EQ(1, GetTransientCount()); + EXPECT_EQ(0, GetTransientMaybeDuckCount()); + EXPECT_EQ( + GetStateFromParam(mojom::MediaSessionInfo::SessionState::kSuspended), + GetState(&media_session_1)); + + RequestAudioFocus(&media_session_1, + mojom::AudioFocusType::kGainTransientMayDuck); + EXPECT_EQ(1, GetTransientCount()); + EXPECT_EQ(1, GetTransientMaybeDuckCount()); + EXPECT_EQ(mojom::MediaSessionInfo::SessionState::kActive, + GetState(&media_session_1)); +} + +TEST_P(AudioFocusManagerTest, SourceName_AssociatedWithBinding) { + SetSourceName(kExampleSourceName); + + mojom::AudioFocusManagerPtr new_ptr = CreateAudioFocusManagerPtr(); + new_ptr->SetSourceName(kExampleSourceName2); + new_ptr.FlushForTesting(); + + test::MockMediaSession media_session; + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + EXPECT_EQ(kExampleSourceName, GetSourceNameForLastRequest()); +} + +TEST_P(AudioFocusManagerTest, SourceName_Empty) { + test::MockMediaSession media_session; + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + EXPECT_TRUE(GetSourceNameForLastRequest().empty()); +} + +TEST_P(AudioFocusManagerTest, SourceName_Updated) { + SetSourceName(kExampleSourceName); + + test::MockMediaSession media_session; + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + EXPECT_EQ(kExampleSourceName, GetSourceNameForLastRequest()); + + SetSourceName(kExampleSourceName2); + EXPECT_EQ(kExampleSourceName, GetSourceNameForLastRequest()); +} + +TEST_P(AudioFocusManagerTest, RecordUmaMetrics) { + EXPECT_EQ(0, GetAudioFocusHistogramCount()); + + SetSourceName(kExampleSourceName); + test::MockMediaSession media_session; + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGainTransient); + + { + std::unique_ptr<base::HistogramSamples> samples( + GetHistogramSamplesSinceTestStart( + "Media.Session.AudioFocus.Request.Test")); + EXPECT_EQ(1, samples->TotalCount()); + EXPECT_EQ(1, samples->GetCount(static_cast<base::HistogramBase::Sample>( + AudioFocusManagerMetricsHelper::AudioFocusRequestSource:: + kInitial))); + } + + { + std::unique_ptr<base::HistogramSamples> samples( + GetHistogramSamplesSinceTestStart( + "Media.Session.AudioFocus.Type.Test")); + EXPECT_EQ(1, samples->TotalCount()); + EXPECT_EQ( + 1, + samples->GetCount(static_cast<base::HistogramBase::Sample>( + AudioFocusManagerMetricsHelper::AudioFocusType::kGainTransient))); + } + + EXPECT_EQ(2, GetAudioFocusHistogramCount()); + + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + + { + std::unique_ptr<base::HistogramSamples> samples( + GetHistogramSamplesSinceTestStart( + "Media.Session.AudioFocus.Request.Test")); + EXPECT_EQ(2, samples->TotalCount()); + EXPECT_EQ( + 1, + samples->GetCount(static_cast<base::HistogramBase::Sample>( + AudioFocusManagerMetricsHelper::AudioFocusRequestSource::kUpdate))); + } + + { + std::unique_ptr<base::HistogramSamples> samples( + GetHistogramSamplesSinceTestStart( + "Media.Session.AudioFocus.Type.Test")); + EXPECT_EQ(2, samples->TotalCount()); + EXPECT_EQ(1, samples->GetCount(static_cast<base::HistogramBase::Sample>( + AudioFocusManagerMetricsHelper::AudioFocusType::kGain))); + } + + EXPECT_EQ(2, GetAudioFocusHistogramCount()); + + media_session.AbandonAudioFocusFromClient(); + + { + std::unique_ptr<base::HistogramSamples> samples( + GetHistogramSamplesSinceTestStart( + "Media.Session.AudioFocus.Abandon.Test")); + EXPECT_EQ(1, samples->TotalCount()); + EXPECT_EQ( + 1, samples->GetCount(static_cast<base::HistogramBase::Sample>( + AudioFocusManagerMetricsHelper::AudioFocusAbandonSource::kAPI))); + } + + EXPECT_EQ(3, GetAudioFocusHistogramCount()); +} + +TEST_P(AudioFocusManagerTest, RecordUmaMetrics_ConnectionError) { + SetSourceName(kExampleSourceName); + + { + test::MockMediaSession media_session; + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + } + + test::MockMediaSession media_session; + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + + { + std::unique_ptr<base::HistogramSamples> samples( + GetHistogramSamplesSinceTestStart( + "Media.Session.AudioFocus.Abandon.Test")); + EXPECT_EQ(1, samples->TotalCount()); + EXPECT_EQ(1, samples->GetCount(static_cast<base::HistogramBase::Sample>( + AudioFocusManagerMetricsHelper::AudioFocusAbandonSource:: + kConnectionError))); + } +} + +TEST_P(AudioFocusManagerTest, RecordUmaMetrics_NoSourceName) { + test::MockMediaSession media_session; + RequestAudioFocus(&media_session, mojom::AudioFocusType::kGain); + + EXPECT_EQ(0, GetAudioFocusHistogramCount()); +} + +} // namespace media_session diff --git a/chromium/services/media_session/manifest.json b/chromium/services/media_session/manifest.json index 6ee22022d27..ced68944766 100644 --- a/chromium/services/media_session/manifest.json +++ b/chromium/services/media_session/manifest.json @@ -8,7 +8,8 @@ "service_manager:connector": { "provides": { "app": [ - "media_session.mojom.AudioFocus" + "media_session.mojom.AudioFocusManager", + "media_session.mojom.AudioFocusManagerDebug" ], "tests": [ "*" ] } diff --git a/chromium/services/media_session/media_session_service.cc b/chromium/services/media_session/media_session_service.cc index 4e0ebe0f186..23aa1bb111d 100644 --- a/chromium/services/media_session/media_session_service.cc +++ b/chromium/services/media_session/media_session_service.cc @@ -4,6 +4,8 @@ #include "services/media_session/media_session_service.h" +#include "base/bind.h" +#include "services/media_session/audio_focus_manager.h" #include "services/service_manager/public/cpp/service_context.h" namespace media_session { @@ -12,14 +14,24 @@ std::unique_ptr<service_manager::Service> MediaSessionService::Create() { return std::make_unique<MediaSessionService>(); } -MediaSessionService::MediaSessionService() = default; +MediaSessionService::MediaSessionService() + : audio_focus_manager_(std::make_unique<AudioFocusManager>()) {} -MediaSessionService::~MediaSessionService() = default; +MediaSessionService::~MediaSessionService() { + audio_focus_manager_->CloseAllMojoObjects(); +} void MediaSessionService::OnStart() { - DLOG(ERROR) << "start"; ref_factory_.reset(new service_manager::ServiceContextRefFactory( context()->CreateQuitClosure())); + + registry_.AddInterface( + base::BindRepeating(&AudioFocusManager::BindToInterface, + base::Unretained(audio_focus_manager_.get()))); + + registry_.AddInterface( + base::BindRepeating(&AudioFocusManager::BindToDebugInterface, + base::Unretained(audio_focus_manager_.get()))); } void MediaSessionService::OnBindInterface( diff --git a/chromium/services/media_session/media_session_service.h b/chromium/services/media_session/media_session_service.h index 6603cfde0a2..86575c999c5 100644 --- a/chromium/services/media_session/media_session_service.h +++ b/chromium/services/media_session/media_session_service.h @@ -16,6 +16,8 @@ namespace media_session { +class AudioFocusManager; + class MediaSessionService : public service_manager::Service { public: MediaSessionService(); @@ -36,6 +38,8 @@ class MediaSessionService : public service_manager::Service { } private: + std::unique_ptr<AudioFocusManager> audio_focus_manager_; + service_manager::BinderRegistry registry_; std::unique_ptr<service_manager::ServiceContextRefFactory> ref_factory_; base::WeakPtrFactory<MediaSessionService> weak_factory_{this}; diff --git a/chromium/services/media_session/mock_media_session.cc b/chromium/services/media_session/mock_media_session.cc new file mode 100644 index 00000000000..d96cf3fea96 --- /dev/null +++ b/chromium/services/media_session/mock_media_session.cc @@ -0,0 +1,140 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/media_session/mock_media_session.h" + +#include <utility> + +#include "services/media_session/public/cpp/switches.h" + +namespace media_session { +namespace test { + +MockMediaSession::MockMediaSession() = default; + +MockMediaSession::MockMediaSession(bool force_duck) : force_duck_(force_duck) {} + +MockMediaSession::~MockMediaSession() {} + +void MockMediaSession::Suspend(SuspendType suspend_type) { + DCHECK_EQ(SuspendType::kSystem, suspend_type); + SetState(mojom::MediaSessionInfo::SessionState::kSuspended); +} + +void MockMediaSession::Resume(SuspendType suspend_type) { + DCHECK_EQ(SuspendType::kSystem, suspend_type); + SetState(mojom::MediaSessionInfo::SessionState::kActive); +} + +void MockMediaSession::StartDucking() { + is_ducking_ = true; + NotifyObservers(); +} + +void MockMediaSession::StopDucking() { + is_ducking_ = false; + NotifyObservers(); +} + +void MockMediaSession::GetMediaSessionInfo( + GetMediaSessionInfoCallback callback) { + std::move(callback).Run(GetMediaSessionInfoSync()); +} + +void MockMediaSession::AddObserver(mojom::MediaSessionObserverPtr observer) {} + +void MockMediaSession::GetDebugInfo(GetDebugInfoCallback callback) { + mojom::MediaSessionDebugInfoPtr debug_info( + mojom::MediaSessionDebugInfo::New()); + + debug_info->name = "name"; + debug_info->owner = "owner"; + debug_info->state = "state"; + + std::move(callback).Run(std::move(debug_info)); +} + +void MockMediaSession::AbandonAudioFocusFromClient() { + DCHECK(afr_client_.is_bound()); + afr_client_->AbandonAudioFocus(); + afr_client_.FlushForTesting(); + afr_client_.reset(); +} + +base::UnguessableToken MockMediaSession::GetRequestIdFromClient() { + DCHECK(afr_client_.is_bound()); + base::UnguessableToken id = base::UnguessableToken::Null(); + + afr_client_->GetRequestId(base::BindOnce( + [](base::UnguessableToken* id, + const base::UnguessableToken& received_id) { *id = received_id; }, + &id)); + + afr_client_.FlushForTesting(); + DCHECK_NE(base::UnguessableToken::Null(), id); + return id; +} + +base::UnguessableToken MockMediaSession::RequestAudioFocusFromService( + mojom::AudioFocusManagerPtr& service, + mojom::AudioFocusType audio_focus_type) { + bool result; + base::OnceClosure callback = + base::BindOnce([](bool* out_result) { *out_result = true; }, &result); + + if (afr_client_.is_bound()) { + // Request audio focus through the existing request. + afr_client_->RequestAudioFocus(GetMediaSessionInfoSync(), audio_focus_type, + std::move(callback)); + + afr_client_.FlushForTesting(); + } else { + // Build a new audio focus request. + mojom::MediaSessionPtr media_session; + bindings_.AddBinding(this, mojo::MakeRequest(&media_session)); + + service->RequestAudioFocus( + mojo::MakeRequest(&afr_client_), std::move(media_session), + GetMediaSessionInfoSync(), audio_focus_type, std::move(callback)); + + service.FlushForTesting(); + } + + // If the audio focus was granted then we should set the session state to + // active. + if (result) + SetState(mojom::MediaSessionInfo::SessionState::kActive); + + return GetRequestIdFromClient(); +} + +mojom::MediaSessionInfo::SessionState MockMediaSession::GetState() const { + return GetMediaSessionInfoSync()->state; +} + +void MockMediaSession::FlushForTesting() { + afr_client_.FlushForTesting(); +} + +void MockMediaSession::SetState(mojom::MediaSessionInfo::SessionState state) { + state_ = state; + NotifyObservers(); +} + +void MockMediaSession::NotifyObservers() { + if (afr_client_.is_bound()) + afr_client_->MediaSessionInfoChanged(GetMediaSessionInfoSync()); +} + +mojom::MediaSessionInfoPtr MockMediaSession::GetMediaSessionInfoSync() const { + mojom::MediaSessionInfoPtr info(mojom::MediaSessionInfo::New()); + info->force_duck = force_duck_; + info->state = state_; + if (is_ducking_) + info->state = mojom::MediaSessionInfo::SessionState::kDucking; + return info; +} + +} // namespace test +} // namespace media_session diff --git a/chromium/services/media_session/mock_media_session.h b/chromium/services/media_session/mock_media_session.h new file mode 100644 index 00000000000..6b6a6072d9f --- /dev/null +++ b/chromium/services/media_session/mock_media_session.h @@ -0,0 +1,72 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_MEDIA_SESSION_MOCK_MEDIA_SESSION_H_ +#define SERVICES_MEDIA_SESSION_MOCK_MEDIA_SESSION_H_ + +#include "mojo/public/cpp/bindings/binding_set.h" +#include "services/media_session/public/mojom/audio_focus.mojom.h" +#include "services/media_session/public/mojom/media_session.mojom.h" + +namespace base { +class UnguessableToken; +} // namespace base + +namespace media_session { +namespace test { + +// A mock MediaSession that can be used for interacting with the Media Session +// service during tests. +class MockMediaSession : public mojom::MediaSession { + public: + MockMediaSession(); + explicit MockMediaSession(bool force_duck); + + ~MockMediaSession() override; + + // mojom::MediaSession overrides. + void Suspend(SuspendType) override; + void Resume(SuspendType) override; + void StartDucking() override; + void StopDucking() override; + void GetMediaSessionInfo(GetMediaSessionInfoCallback) override; + void AddObserver(mojom::MediaSessionObserverPtr) override; + void GetDebugInfo(GetDebugInfoCallback) override; + + void AbandonAudioFocusFromClient(); + base::UnguessableToken GetRequestIdFromClient(); + + base::UnguessableToken RequestAudioFocusFromService( + mojom::AudioFocusManagerPtr&, + mojom::AudioFocusType); + + mojom::MediaSessionInfo::SessionState GetState() const; + + mojom::AudioFocusRequestClient* audio_focus_request() const { + return afr_client_.get(); + } + void FlushForTesting(); + + private: + void SetState(mojom::MediaSessionInfo::SessionState); + void NotifyObservers(); + mojom::MediaSessionInfoPtr GetMediaSessionInfoSync() const; + + mojom::AudioFocusRequestClientPtr afr_client_; + + const bool force_duck_ = false; + bool is_ducking_ = false; + + mojom::MediaSessionInfo::SessionState state_ = + mojom::MediaSessionInfo::SessionState::kInactive; + + mojo::BindingSet<mojom::MediaSession> bindings_; + + DISALLOW_COPY_AND_ASSIGN(MockMediaSession); +}; + +} // namespace test +} // namespace media_session + +#endif // SERVICES_MEDIA_SESSION_MOCK_MEDIA_SESSION_H_ diff --git a/chromium/services/media_session/public/cpp/BUILD.gn b/chromium/services/media_session/public/cpp/BUILD.gn index 21b04ac5f34..0f6b88d789b 100644 --- a/chromium/services/media_session/public/cpp/BUILD.gn +++ b/chromium/services/media_session/public/cpp/BUILD.gn @@ -14,5 +14,7 @@ component("cpp") { "//base", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + defines = [ "IS_MEDIA_SESSION_CPP_IMPL" ] } diff --git a/chromium/services/media_session/public/cpp/switches.cc b/chromium/services/media_session/public/cpp/switches.cc index 17474818350..9905fe10b88 100644 --- a/chromium/services/media_session/public/cpp/switches.cc +++ b/chromium/services/media_session/public/cpp/switches.cc @@ -12,13 +12,19 @@ namespace switches { // Enable a internal audio focus management between tabs in such a way that two // tabs can't play on top of each other. -// The allowed values are: "" (empty) or |switches::kEnableAudioFocusDuckFlash|. +// The allowed values are: "" (empty) or |kEnableAudioFocusDuckFlash| +// or |kEnableAudioFocusNoEnforce|. const char kEnableAudioFocus[] = "enable-audio-focus"; // This value is used as an option for |kEnableAudioFocus|. Flash will // be ducked when losing audio focus. const char kEnableAudioFocusDuckFlash[] = "duck-flash"; +// This value is used as an option for |kEnableAudioFocus|. If enabled then +// single media session audio focus will not be enforced. This should be used by +// embedders that wish to track audio focus but without the enforcement. +const char kEnableAudioFocusNoEnforce[] = "no-enforce"; + #if !defined(OS_ANDROID) // Turns on the internal media session backend. This should be used by embedders // that want to control the media playback with the media session interfaces. @@ -38,6 +44,12 @@ bool IsAudioFocusDuckFlashEnabled() { switches::kEnableAudioFocusDuckFlash; } +bool IsAudioFocusEnforcementEnabled() { + return base::CommandLine::ForCurrentProcess()->GetSwitchValueASCII( + switches::kEnableAudioFocus) != + switches::kEnableAudioFocusNoEnforce; +} + bool IsMediaSessionEnabled() { // Media session is enabled on Android and Chrome OS to allow control of media // players as needed. diff --git a/chromium/services/media_session/public/cpp/switches.h b/chromium/services/media_session/public/cpp/switches.h index 7d531beca33..7fadac8f393 100644 --- a/chromium/services/media_session/public/cpp/switches.h +++ b/chromium/services/media_session/public/cpp/switches.h @@ -14,6 +14,8 @@ namespace switches { COMPONENT_EXPORT(MEDIA_SESSION_CPP) extern const char kEnableAudioFocus[]; COMPONENT_EXPORT(MEDIA_SESSION_CPP) extern const char kEnableAudioFocusDuckFlash[]; +COMPONENT_EXPORT(MEDIA_SESSION_CPP) +extern const char kEnableAudioFocusNoEnforce[]; #if !defined(OS_ANDROID) COMPONENT_EXPORT(MEDIA_SESSION_CPP) @@ -28,6 +30,10 @@ COMPONENT_EXPORT(MEDIA_SESSION_CPP) bool IsAudioFocusEnabled(); // audio focus duck flash should be enabled. COMPONENT_EXPORT(MEDIA_SESSION_CPP) bool IsAudioFocusDuckFlashEnabled(); +// Based on the command line of the current process, determine if +// audio focus enforcement should be enabled. +COMPONENT_EXPORT(MEDIA_SESSION_CPP) bool IsAudioFocusEnforcementEnabled(); + COMPONENT_EXPORT(MEDIA_SESSION_CPP) bool IsMediaSessionEnabled(); } // namespace media_session diff --git a/chromium/services/media_session/public/cpp/test/BUILD.gn b/chromium/services/media_session/public/cpp/test/BUILD.gn new file mode 100644 index 00000000000..2b515f04dbf --- /dev/null +++ b/chromium/services/media_session/public/cpp/test/BUILD.gn @@ -0,0 +1,19 @@ +# Copyright 2018 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +component("test_support") { + output_name = "media_session_test_support_cpp" + + sources = [ + "audio_focus_test_util.cc", + "audio_focus_test_util.h", + ] + + deps = [ + "//base", + "//services/media_session/public/mojom", + ] + + defines = [ "IS_MEDIA_SESSION_TEST_SUPPORT_CPP_IMPL" ] +} diff --git a/chromium/services/media_session/public/mojom/BUILD.gn b/chromium/services/media_session/public/mojom/BUILD.gn index 0a3741eae4d..17ed45bb7ba 100644 --- a/chromium/services/media_session/public/mojom/BUILD.gn +++ b/chromium/services/media_session/public/mojom/BUILD.gn @@ -10,6 +10,7 @@ mojom_component("mojom") { sources = [ "audio_focus.mojom", + "constants.mojom", "media_session.mojom", ] diff --git a/chromium/services/media_session/public/mojom/audio_focus.mojom b/chromium/services/media_session/public/mojom/audio_focus.mojom index d6d88c7245d..9ad4b27aae2 100644 --- a/chromium/services/media_session/public/mojom/audio_focus.mojom +++ b/chromium/services/media_session/public/mojom/audio_focus.mojom @@ -4,25 +4,97 @@ module media_session.mojom; +import "mojo/public/mojom/base/unguessable_token.mojom"; import "services/media_session/public/mojom/media_session.mojom"; +// Next MinVersion: 4 + // These are the different types of audio focus that can be requested. +[Extensible] enum AudioFocusType { // Request permanent audio focus when you plan to play audio for the - // foreseeable future (for example, when playing music) and you expect - // the previous holder of audio focus to stop playing. + // foreseeable future (for example, when playing music) and you expect the + // previous holder of audio focus to stop playing. kGain, - // Request transient focus when you expect to play audio for only a - // short time and you expect the previous holder to pause playing. + // Request transient focus with ducking to indicate that you expect to play + // audio for only a short time and that it's OK for the previous focus owner + // to keep playing if it "ducks" (lowers) its audio output. kGainTransientMayDuck, + + // Request transient focus when you expect to play audio for only a short + // time and you expect the previous holder to pause playing. + kGainTransient, +}; + +// Contains information about |MediaSessions| that have requested audio focus +// and their current requested type. +struct AudioFocusRequestState { + MediaSessionInfo session_info; + AudioFocusType audio_focus_type; + + [MinVersion=2] string? source_name; + [MinVersion=3] mojo_base.mojom.UnguessableToken? request_id; }; // The observer for audio focus events. +// Next Method ID: 2 interface AudioFocusObserver { // The given |session| gained audio focus with the specified |type|. - OnFocusGained(MediaSession session, AudioFocusType type); + OnFocusGained@0(MediaSessionInfo session, AudioFocusType type); // The given |session| lost audio focus. - OnFocusLost(MediaSession session); + OnFocusLost@1(MediaSessionInfo session); +}; + +// Controls audio focus for an associated request. +// Next Method ID: 5 +// Deprecated method IDs: 3 +interface AudioFocusRequestClient { + // Requests updated audio focus for this request. If the request was granted + // then the callback will resolve. + RequestAudioFocus@0(MediaSessionInfo session_info, AudioFocusType type) => (); + + // Abandons audio focus for this request. + AbandonAudioFocus@1(); + + // Notifies the audio focus backend when the associated session info changes. + MediaSessionInfoChanged@2(MediaSessionInfo session_info); + + // Retrieve a unique ID for this request. + [MinVersion=3] GetRequestId@4() + => (mojo_base.mojom.UnguessableToken request_id); +}; + +// Controls audio focus across the entire system. +// Next Method ID: 4 +interface AudioFocusManager { + // Requests audio focus with |type| for the |media_session| with + // |session_info|. Media sessions should provide a |request| that will + // provide an AudioFocusRequestClient that can be used to control this + // request. The callback will resolve when audio focus has been granted. + RequestAudioFocus@0(AudioFocusRequestClient& client, + MediaSession media_session, + MediaSessionInfo session_info, + AudioFocusType type) => (); + + // Gets all the information about all |MediaSessions| that have requested + // audio focus and their current requested type. + GetFocusRequests@1() => (array<AudioFocusRequestState> requests); + + // Adds observers that receive audio focus events. + AddObserver@2(AudioFocusObserver observer); + + // Associates a name with this binding. This will be associated with all + // audio focus requests made with this binding. It will also be used for + // associating metrics to a source. If the source name is updated then + // the audio focus requests will retain the previous source name. + [MinVersion=2] SetSourceName@3(string name); +}; + +// Provides debug information about audio focus requests. +interface AudioFocusManagerDebug { + // Gets debugging information for a |MediaSession| with |request_id|. + GetDebugInfoForRequest(mojo_base.mojom.UnguessableToken request_id) + => (MediaSessionDebugInfo debug_info); }; diff --git a/chromium/services/media_session/public/mojom/constants.mojom b/chromium/services/media_session/public/mojom/constants.mojom new file mode 100644 index 00000000000..a2f262ded5e --- /dev/null +++ b/chromium/services/media_session/public/mojom/constants.mojom @@ -0,0 +1,7 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +module media_session.mojom; + +const string kServiceName = "media_session"; diff --git a/chromium/services/media_session/public/mojom/media_session.mojom b/chromium/services/media_session/public/mojom/media_session.mojom index 5624599df2c..25901a221a1 100644 --- a/chromium/services/media_session/public/mojom/media_session.mojom +++ b/chromium/services/media_session/public/mojom/media_session.mojom @@ -4,8 +4,89 @@ module media_session.mojom; +// Next MinVersion: 1 + +// Contains state information about a MediaSession. +struct MediaSessionInfo { + [Extensible] + enum SessionState { + // The MediaSession is currently playing media. + kActive, + + // The MediaSession is currently playing at a reduced volume (ducking). + kDucking, + + // The MediaSession is currently paused. + kSuspended, + + // The MediaSession is not currently playing media. + kInactive, + }; + + // The current state of the MediaSession. + SessionState state; + + // If true then we will always duck this MediaSession instead of suspending. + bool force_duck; +}; + +// Contains debugging information about a MediaSession. This will be displayed +// on the Media Internals WebUI. +struct MediaSessionDebugInfo { + // A unique name for the MediaSession. + string name; + + // The owner of the MediaSession. + string owner; + + // State information stored in a string e.g. Ducked. + string state; +}; + +// The observer for observing media session events. +// Next Method ID: 1 +interface MediaSessionObserver { + // The info associated with the session changed. + MediaSessionInfoChanged@0(MediaSessionInfo info); +}; + // A MediaSession manages the media session and audio focus for a given // WebContents or ARC app. // TODO(https://crbug.com/875004): migrate media session from content/public // to mojo. -interface MediaSession {}; +// Next Method ID: 6 +interface MediaSession { + [Extensible] + enum SuspendType { + // Suspended by the system because a transient sound needs to be played. + kSystem, + // Suspended by the UI. + kUI, + // Suspended by the page via script or user interaction. + kContent, + }; + + // Returns information about the MediaSession. + GetMediaSessionInfo@0() => (MediaSessionInfo info); + + // Returns debug information about the MediaSession. + GetDebugInfo@1() => (MediaSessionDebugInfo info); + + // Let the media session start ducking such that the volume multiplier + // is reduced. + StartDucking@2(); + + // Let the media session stop ducking such that the volume multiplier is + // recovered. + StopDucking@3(); + + // Suspend the media session. + // |type| represents the origin of the request. + Suspend@4(SuspendType suspend_type); + + // Resume the media session. + // |type| represents the origin of the request. + Resume@5(SuspendType suspend_type); + + AddObserver@6(MediaSessionObserver observer); +}; diff --git a/chromium/services/metrics/BUILD.gn b/chromium/services/metrics/BUILD.gn index a3984985ee9..a6c81a755f9 100644 --- a/chromium/services/metrics/BUILD.gn +++ b/chromium/services/metrics/BUILD.gn @@ -17,6 +17,8 @@ source_set("metrics") { "ukm_recorder_interface.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ "//mojo/public/cpp/bindings", "//services/metrics/public/cpp:metrics_cpp", diff --git a/chromium/services/metrics/public/cpp/BUILD.gn b/chromium/services/metrics/public/cpp/BUILD.gn index be5c845f2bf..154b9dcb1f9 100644 --- a/chromium/services/metrics/public/cpp/BUILD.gn +++ b/chromium/services/metrics/public/cpp/BUILD.gn @@ -25,6 +25,8 @@ component("metrics_cpp") { "ukm_source_id.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + defines = [ "METRICS_IMPLEMENTATION" ] public_deps = [ diff --git a/chromium/services/network/BUILD.gn b/chromium/services/network/BUILD.gn index b6e3c13ad9e..89977bd7320 100644 --- a/chromium/services/network/BUILD.gn +++ b/chromium/services/network/BUILD.gn @@ -2,13 +2,14 @@ # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. +import("//build/config/jumbo.gni") import("//mojo/public/tools/bindings/mojom.gni") import("//services/catalog/public/tools/catalog.gni") import("//services/service_manager/public/cpp/service.gni") import("//services/service_manager/public/service_manifest.gni") import("//services/service_manager/public/tools/test/service_test.gni") -component("network_service") { +jumbo_component("network_service") { sources = [ "cert_verifier_config_type_converter.cc", "cert_verifier_config_type_converter.h", @@ -50,8 +51,10 @@ component("network_service") { "keepalive_statistics_recorder.h", "loader_util.cc", "loader_util.h", - "mojo_net_log.cc", - "mojo_net_log.h", + "net_log_capture_mode_type_converter.cc", + "net_log_capture_mode_type_converter.h", + "net_log_exporter.cc", + "net_log_exporter.h", "network_change_manager.cc", "network_change_manager.h", "network_context.cc", @@ -66,6 +69,8 @@ component("network_service") { "network_service.h", "network_service_network_delegate.cc", "network_service_network_delegate.h", + "network_service_proxy_delegate.cc", + "network_service_proxy_delegate.h", "network_usage_accumulator.cc", "network_usage_accumulator.h", "p2p/socket.cc", @@ -114,6 +119,8 @@ component("network_service") { "ssl_config_service_mojo.h", "ssl_config_type_converter.cc", "ssl_config_type_converter.h", + "tcp_bound_socket.cc", + "tcp_bound_socket.h", "tcp_connected_socket.cc", "tcp_connected_socket.h", "tcp_server_socket.cc", @@ -134,6 +141,8 @@ component("network_service") { "throttling/throttling_upload_data_stream.h", "tls_client_socket.cc", "tls_client_socket.h", + "tls_socket_factory.cc", + "tls_socket_factory.h", "transitional_url_loader_factory_owner.cc", "transitional_url_loader_factory_owner.h", "udp_socket.cc", @@ -161,6 +170,8 @@ component("network_service") { ] } + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ "//base", "//components/certificate_transparency", @@ -182,6 +193,7 @@ component("network_service") { "//services/service_manager/sandbox:sandbox", "//third_party/webrtc/media:rtc_media_base", "//third_party/webrtc/rtc_base", + "//third_party/webrtc/rtc_base:timeutils", "//third_party/webrtc_overrides", "//third_party/webrtc_overrides:init_webrtc", "//url", @@ -205,17 +217,13 @@ component("network_service") { deps += [ "//sandbox/win:sandbox" ] } - # TODO(sdefresne): This depends on net's enable_net_mojo getting turned on for - # iOS, which depends on net_with_v8 as well. http://crbug.com/803149 - if (!is_ios) { - sources += [ - "proxy_resolver_factory_mojo.cc", - "proxy_resolver_factory_mojo.h", - "proxy_service_mojo.cc", - "proxy_service_mojo.h", - ] - deps += [ "//net/dns:mojo_service" ] - } + sources += [ + "proxy_resolver_factory_mojo.cc", + "proxy_resolver_factory_mojo.h", + "proxy_service_mojo.cc", + "proxy_service_mojo.h", + ] + deps += [ "//net/dns:mojo_service" ] defines = [ "IS_NETWORK_SERVICE_IMPL" ] @@ -244,7 +252,7 @@ source_set("tests") { "network_change_manager_unittest.cc", "network_context_unittest.cc", "network_quality_estimator_manager_unittest.cc", - "network_service_network_delegate_unittest.cc", + "network_service_proxy_delegate_unittest.cc", "network_service_unittest.cc", "network_usage_accumulator_unittest.cc", "p2p/socket_tcp_server_unittest.cc", @@ -262,6 +270,7 @@ source_set("tests") { "session_cleanup_cookie_store_unittest.cc", "socket_data_pump_unittest.cc", "ssl_config_service_mojo_unittest.cc", + "tcp_bound_socket_unittest.cc", "tcp_socket_unittest.cc", "test/test_url_loader_factory_unittest.cc", "test_chunked_data_pipe_getter.cc", @@ -311,7 +320,7 @@ source_set("tests") { ] } -source_set("test_support") { +jumbo_source_set("test_support") { testonly = true sources = [ diff --git a/chromium/services/network/OWNERS b/chromium/services/network/OWNERS index fe217b595ed..418278801fa 100644 --- a/chromium/services/network/OWNERS +++ b/chromium/services/network/OWNERS @@ -8,7 +8,6 @@ tsepez@chromium.org yhirano@chromium.org per-file cross_origin_read_blocking*=creis@chromium.org -per-file cross_origin_read_blocking*=nick@chromium.org per-file cross_origin_read_blocking*=lukasza@chromium.org per-file expect_ct_reporter*=estark@chromium.org diff --git a/chromium/services/network/chunked_data_pipe_upload_data_stream_unittest.cc b/chromium/services/network/chunked_data_pipe_upload_data_stream_unittest.cc index 3c2cd638478..77aaf02c07b 100644 --- a/chromium/services/network/chunked_data_pipe_upload_data_stream_unittest.cc +++ b/chromium/services/network/chunked_data_pipe_upload_data_stream_unittest.cc @@ -76,8 +76,8 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, ReadBeforeDataReady) { std::string read_data; while (read_data.size() < kData.size()) { net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer( - new net::IOBufferWithSize(consumer_read_size)); + auto io_buffer = + base::MakeRefCounted<net::IOBufferWithSize>(consumer_read_size); int result = chunked_upload_stream_->Read( io_buffer.get(), io_buffer->size(), callback.callback()); if (read_data.size() == 0) @@ -91,8 +91,8 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, ReadBeforeDataReady) { EXPECT_EQ(read_data, kData); } - scoped_refptr<net::IOBufferWithSize> io_buffer( - new net::IOBufferWithSize(consumer_read_size)); + auto io_buffer = + base::MakeRefCounted<net::IOBufferWithSize>(consumer_read_size); net::TestCompletionCallback callback; int result = chunked_upload_stream_->Read( io_buffer.get(), io_buffer->size(), callback.callback()); @@ -121,8 +121,8 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, ReadAfterDataReady) { std::string read_data; while (read_data.size() < kData.size()) { net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer( - new net::IOBufferWithSize(consumer_read_size)); + auto io_buffer = + base::MakeRefCounted<net::IOBufferWithSize>(consumer_read_size); int result = chunked_upload_stream_->Read( io_buffer.get(), io_buffer->size(), callback.callback()); ASSERT_LT(0, result); @@ -137,8 +137,8 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, ReadAfterDataReady) { base::RunLoop().RunUntilIdle(); net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer( - new net::IOBufferWithSize(consumer_read_size)); + auto io_buffer = + base::MakeRefCounted<net::IOBufferWithSize>(consumer_read_size); EXPECT_EQ(net::OK, chunked_upload_stream_->Read(io_buffer.get(), io_buffer->size(), callback.callback())); @@ -174,8 +174,8 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, MultipleReadThrough) { std::string read_data; while (read_data.size() < kData.size()) { net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer( - new net::IOBufferWithSize(kData.size())); + auto io_buffer = + base::MakeRefCounted<net::IOBufferWithSize>(kData.size()); int result = chunked_upload_stream_->Read( io_buffer.get(), io_buffer->size(), callback.callback()); result = callback.GetResult(result); @@ -191,8 +191,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, MultipleReadThrough) { EXPECT_EQ(kData, read_data); net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer( - new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(net::OK, callback.GetResult(result)); @@ -232,8 +231,8 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, std::string read_data; while (read_data.size() < num_bytes_to_read) { net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer( - new net::IOBufferWithSize(kData.size())); + auto io_buffer = + base::MakeRefCounted<net::IOBufferWithSize>(kData.size()); int result = chunked_upload_stream_->Read( io_buffer.get(), io_buffer->size(), callback.callback()); result = callback.GetResult(result); @@ -253,7 +252,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, } net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(net::OK, callback.GetResult(result)); @@ -286,8 +285,8 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, std::string read_data; while (read_data.size() < num_bytes_to_read) { net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer( - new net::IOBufferWithSize(kData.size())); + auto io_buffer = + base::MakeRefCounted<net::IOBufferWithSize>(kData.size()); int result = chunked_upload_stream_->Read( io_buffer.get(), io_buffer->size(), callback.callback()); result = callback.GetResult(result); @@ -305,7 +304,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, base::RunLoop().RunUntilIdle(); net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(net::OK, callback.GetResult(result)); @@ -339,8 +338,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, GetSizeSucceedsBeforeInit) { while (read_data.size() < kData.size()) { net::TestCompletionCallback callback; int read_size = kData.size() - read_data.size(); - scoped_refptr<net::IOBufferWithSize> io_buffer( - new net::IOBufferWithSize(read_size)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(read_size); int result = chunked_upload_stream_->Read( io_buffer.get(), io_buffer->size(), callback.callback()); result = callback.GetResult(result); @@ -362,8 +360,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, GetSizeSucceedsAfterReset) { while (read_data.size() < kData.size()) { net::TestCompletionCallback callback; int read_size = kData.size() - read_data.size(); - scoped_refptr<net::IOBufferWithSize> io_buffer( - new net::IOBufferWithSize(read_size)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(read_size); int result = chunked_upload_stream_->Read( io_buffer.get(), io_buffer->size(), callback.callback()); result = callback.GetResult(result); @@ -390,8 +387,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, GetSizeSucceedsAfterReset) { while (read_data.size() < kData.size()) { net::TestCompletionCallback callback; int read_size = kData.size() - read_data.size(); - scoped_refptr<net::IOBufferWithSize> io_buffer( - new net::IOBufferWithSize(read_size)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(read_size); int result = chunked_upload_stream_->Read( io_buffer.get(), io_buffer->size(), callback.callback()); result = callback.GetResult(result); @@ -433,8 +429,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, GetSizeFailsAfterReset) { while (read_data.size() < kData.size()) { net::TestCompletionCallback callback; int read_size = kData.size() - read_data.size(); - scoped_refptr<net::IOBufferWithSize> io_buffer( - new net::IOBufferWithSize(read_size)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(read_size); int result = chunked_upload_stream_->Read( io_buffer.get(), io_buffer->size(), callback.callback()); result = callback.GetResult(result); @@ -467,7 +462,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeSuccess1) { base::RunLoop().RunUntilIdle(); net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(result, net::ERR_IO_PENDING); @@ -483,7 +478,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeSuccess1) { // GetSizeCallback is invoked. TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeSuccess2) { net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(result, net::ERR_IO_PENDING); @@ -510,7 +505,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeSuccess3) { EXPECT_FALSE(chunked_upload_stream_->IsEOF()); net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(net::OK, callback.GetResult(result)); @@ -525,7 +520,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeTruncation1) { base::RunLoop().RunUntilIdle(); net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(result, net::ERR_IO_PENDING); @@ -536,7 +531,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeTruncation1) { TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeTruncation2) { net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(result, net::ERR_IO_PENDING); @@ -556,7 +551,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeTruncation3) { std::move(get_size_callback_).Run(net::OK, 1); net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(net::ERR_FAILED, callback.GetResult(result)); @@ -569,7 +564,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeFailure1) { base::RunLoop().RunUntilIdle(); net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(result, net::ERR_IO_PENDING); @@ -580,7 +575,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeFailure1) { TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeFailure2) { net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(result, net::ERR_IO_PENDING); @@ -600,7 +595,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeFailure3) { std::move(get_size_callback_).Run(net::ERR_ACCESS_DENIED, 0); net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(net::ERR_ACCESS_DENIED, callback.GetResult(result)); @@ -614,7 +609,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeCloseGetter1) { base::RunLoop().RunUntilIdle(); net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(result, net::ERR_IO_PENDING); @@ -625,7 +620,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeCloseGetter1) { TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeCloseGetter2) { net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(result, net::ERR_IO_PENDING); @@ -645,7 +640,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, CloseBodyPipeBeforeCloseGetter3) { chunked_data_pipe_getter_.reset(); net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(net::ERR_FAILED, callback.GetResult(result)); @@ -660,8 +655,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, ExtraBytes1) { std::string read_data; while (read_data.size() < kData.size()) { net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer( - new net::IOBufferWithSize(kData.size())); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(kData.size()); int result = chunked_upload_stream_->Read( io_buffer.get(), io_buffer->size(), callback.callback()); result = callback.GetResult(result); @@ -672,7 +666,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, ExtraBytes1) { } net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(result, net::ERR_IO_PENDING); @@ -689,7 +683,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, ExtraBytes2) { // Read first byte. mojo::BlockingCopyFromString(kData.substr(0, 1), write_pipe_); net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), io_buffer->size(), callback.callback()); result = callback.GetResult(result); @@ -737,7 +731,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, base::RunLoop().RunUntilIdle(); net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(net::ERR_FAILED, callback.GetResult(result)); @@ -746,7 +740,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, TEST_F(ChunkedDataPipeUploadDataStreamTest, ClosePipeGetterWithoutCallingGetSizeCallbackPendingRead) { net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(1)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(1); int result = chunked_upload_stream_->Read(io_buffer.get(), 1, callback.callback()); EXPECT_EQ(net::ERR_IO_PENDING, result); @@ -762,8 +756,7 @@ TEST_F(ChunkedDataPipeUploadDataStreamTest, const char kData[] = "1"; const int kDataLen = strlen(kData); net::TestCompletionCallback callback; - scoped_refptr<net::IOBufferWithSize> io_buffer( - new net::IOBufferWithSize(kDataLen)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(kDataLen); std::move(get_size_callback_).Run(net::OK, kDataLen); // Destroy the DataPipeGetter pipe, which is the pipe used for // GetSizeCallback. diff --git a/chromium/services/network/cookie_manager.cc b/chromium/services/network/cookie_manager.cc index 256b9eec57c..3e66d18b3ba 100644 --- a/chromium/services/network/cookie_manager.cc +++ b/chromium/services/network/cookie_manager.cc @@ -72,6 +72,12 @@ CookieManager::CookieManager( cookie_settings_.set_block_third_party_cookies( params->block_third_party_cookies); cookie_settings_.set_content_settings(params->settings); + cookie_settings_.set_secure_origin_cookies_allowed_schemes( + params->secure_origin_cookies_allowed_schemes); + cookie_settings_.set_matching_scheme_cookies_allowed_schemes( + params->matching_scheme_cookies_allowed_schemes); + cookie_settings_.set_third_party_cookies_allowed_schemes( + params->third_party_cookies_allowed_schemes); } } diff --git a/chromium/services/network/cookie_settings.cc b/chromium/services/network/cookie_settings.cc index 06eb302c43e..82fc2693afb 100644 --- a/chromium/services/network/cookie_settings.cc +++ b/chromium/services/network/cookie_settings.cc @@ -31,9 +31,25 @@ void CookieSettings::GetCookieSetting(const GURL& url, const GURL& first_party_url, content_settings::SettingSource* source, ContentSetting* cookie_setting) const { + if (base::ContainsKey(secure_origin_cookies_allowed_schemes_, + first_party_url.scheme()) && + url.SchemeIsCryptographic()) { + *cookie_setting = CONTENT_SETTING_ALLOW; + return; + } + + if (base::ContainsKey(matching_scheme_cookies_allowed_schemes_, + url.scheme()) && + url.SchemeIs(first_party_url.scheme_piece())) { + *cookie_setting = CONTENT_SETTING_ALLOW; + return; + } + // Default to allowing cookies. *cookie_setting = CONTENT_SETTING_ALLOW; - bool block_third = block_third_party_cookies_; + bool block_third = block_third_party_cookies_ && + !base::ContainsKey(third_party_cookies_allowed_schemes_, + first_party_url.scheme()); for (const auto& entry : content_settings_) { if (entry.primary_pattern.Matches(url) && entry.secondary_pattern.Matches(first_party_url)) { diff --git a/chromium/services/network/cookie_settings.h b/chromium/services/network/cookie_settings.h index 9f63fcc54af..578af6b3722 100644 --- a/chromium/services/network/cookie_settings.h +++ b/chromium/services/network/cookie_settings.h @@ -29,6 +29,30 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) CookieSettings block_third_party_cookies_ = block_third_party_cookies; } + void set_secure_origin_cookies_allowed_schemes( + const std::vector<std::string>& secure_origin_cookies_allowed_schemes) { + secure_origin_cookies_allowed_schemes_.clear(); + secure_origin_cookies_allowed_schemes_.insert( + secure_origin_cookies_allowed_schemes.begin(), + secure_origin_cookies_allowed_schemes.end()); + } + + void set_matching_scheme_cookies_allowed_schemes( + const std::vector<std::string>& matching_scheme_cookies_allowed_schemes) { + matching_scheme_cookies_allowed_schemes_.clear(); + matching_scheme_cookies_allowed_schemes_.insert( + matching_scheme_cookies_allowed_schemes.begin(), + matching_scheme_cookies_allowed_schemes.end()); + } + + void set_third_party_cookies_allowed_schemes( + const std::vector<std::string>& third_party_cookies_allowed_schemes) { + third_party_cookies_allowed_schemes_.clear(); + third_party_cookies_allowed_schemes_.insert( + third_party_cookies_allowed_schemes.begin(), + third_party_cookies_allowed_schemes.end()); + } + // Returns a predicate that takes the domain of a cookie and a bool whether // the cookie is secure and returns true if the cookie should be deleted on // exit. @@ -47,6 +71,9 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) CookieSettings ContentSettingsForOneType content_settings_; bool block_third_party_cookies_ = false; + std::set<std::string> secure_origin_cookies_allowed_schemes_; + std::set<std::string> matching_scheme_cookies_allowed_schemes_; + std::set<std::string> third_party_cookies_allowed_schemes_; DISALLOW_COPY_AND_ASSIGN(CookieSettings); }; diff --git a/chromium/services/network/cookie_settings_unittest.cc b/chromium/services/network/cookie_settings_unittest.cc index 20040d2fa20..d2e131eab81 100644 --- a/chromium/services/network/cookie_settings_unittest.cc +++ b/chromium/services/network/cookie_settings_unittest.cc @@ -117,5 +117,74 @@ TEST(CookieSettingsTest, CreateDeleteCookieOnExitPredicateAllow) { EXPECT_FALSE(settings.CreateDeleteCookieOnExitPredicate().Run(kURL, false)); } +TEST(CookieSettingsTest, GetCookieSettingSecureOriginCookiesAllowed) { + CookieSettings settings; + settings.set_secure_origin_cookies_allowed_schemes({"chrome"}); + settings.set_block_third_party_cookies(true); + + ContentSetting setting; + settings.GetCookieSetting(GURL("https://foo.com") /* url */, + GURL("chrome://foo") /* first_party_url */, + nullptr /* source */, &setting); + EXPECT_EQ(setting, CONTENT_SETTING_ALLOW); + + settings.GetCookieSetting(GURL("chrome://foo") /* url */, + GURL("https://foo.com") /* first_party_url */, + nullptr /* source */, &setting); + EXPECT_EQ(setting, CONTENT_SETTING_BLOCK); + + settings.GetCookieSetting(GURL("http://foo.com") /* url */, + GURL("chrome://foo") /* first_party_url */, + nullptr /* source */, &setting); + EXPECT_EQ(setting, CONTENT_SETTING_BLOCK); +} + +TEST(CookieSettingsTest, GetCookieSettingWithThirdPartyCookiesAllowedScheme) { + CookieSettings settings; + settings.set_third_party_cookies_allowed_schemes({"chrome-extension"}); + settings.set_block_third_party_cookies(true); + + ContentSetting setting; + settings.GetCookieSetting( + GURL("http://foo.com") /* url */, + GURL("chrome-extension://foo") /* first_party_url */, + nullptr /* source */, &setting); + EXPECT_EQ(setting, CONTENT_SETTING_ALLOW); + + settings.GetCookieSetting(GURL("http://foo.com") /* url */, + GURL("other-scheme://foo") /* first_party_url */, + nullptr /* source */, &setting); + EXPECT_EQ(setting, CONTENT_SETTING_BLOCK); + + settings.GetCookieSetting(GURL("chrome-extension://foo") /* url */, + GURL("http://foo.com") /* first_party_url */, + nullptr /* source */, &setting); + EXPECT_EQ(setting, CONTENT_SETTING_BLOCK); +} + +TEST(CookieSettingsTest, GetCookieSettingMatchingSchemeCookiesAllowed) { + CookieSettings settings; + settings.set_matching_scheme_cookies_allowed_schemes({"chrome-extension"}); + settings.set_block_third_party_cookies(true); + + ContentSetting setting; + settings.GetCookieSetting( + GURL("chrome-extension://bar") /* url */, + GURL("chrome-extension://foo") /* first_party_url */, + nullptr /* source */, &setting); + EXPECT_EQ(setting, CONTENT_SETTING_ALLOW); + + settings.GetCookieSetting( + GURL("http://foo.com") /* url */, + GURL("chrome-extension://foo") /* first_party_url */, + nullptr /* source */, &setting); + EXPECT_EQ(setting, CONTENT_SETTING_BLOCK); + + settings.GetCookieSetting(GURL("chrome-extension://foo") /* url */, + GURL("http://foo.com") /* first_party_url */, + nullptr /* source */, &setting); + EXPECT_EQ(setting, CONTENT_SETTING_BLOCK); +} + } // namespace } // namespace network diff --git a/chromium/services/network/cors/cors_url_loader.cc b/chromium/services/network/cors/cors_url_loader.cc index 4df63705d65..b7b56a5e4f9 100644 --- a/chromium/services/network/cors/cors_url_loader.cc +++ b/chromium/services/network/cors/cors_url_loader.cc @@ -8,6 +8,7 @@ #include "net/base/load_flags.h" #include "services/network/cors/preflight_controller.h" #include "services/network/public/cpp/cors/cors.h" +#include "services/network/public/cpp/cors/origin_access_list.h" #include "url/url_util.h" namespace network { @@ -16,30 +17,8 @@ namespace cors { namespace { -bool CalculateCORSFlag(const ResourceRequest& request) { - if (request.fetch_request_mode == mojom::FetchRequestMode::kNavigate || - request.fetch_request_mode == mojom::FetchRequestMode::kNoCORS) { - return false; - } - // CORS needs a proper origin (including a unique opaque origin). If the - // request doesn't have one, CORS should not work. - DCHECK(request.request_initiator); - url::Origin url_origin = url::Origin::Create(request.url); - url::Origin security_origin(request.request_initiator.value()); - return !security_origin.IsSameOriginWith(url_origin); -} - -base::Optional<std::string> GetHeaderString( - const scoped_refptr<net::HttpResponseHeaders>& headers, - const std::string& header_name) { - std::string header_value; - if (!headers->GetNormalizedHeader(header_name, &header_value)) - return base::nullopt; - return header_value; -} - bool NeedsPreflight(const ResourceRequest& request) { - if (!cors::IsCORSEnabledRequestMode(request.fetch_request_mode)) + if (!IsCORSEnabledRequestMode(request.fetch_request_mode)) return false; if (request.is_external_request) @@ -58,13 +37,9 @@ bool NeedsPreflight(const ResourceRequest& request) { if (!IsCORSSafelistedMethod(request.method)) return true; - for (const auto& header : request.headers.GetHeaderVector()) { - if (!IsCORSSafelistedHeader(header.key, header.value) && - !IsForbiddenHeader(header.key)) { - return true; - } - } - return false; + return !CORSUnsafeNotForbiddenRequestHeaderNames( + request.headers.GetHeaderVector(), request.is_revalidating) + .empty(); } } // namespace @@ -79,7 +54,8 @@ CORSURLLoader::CORSURLLoader( mojom::URLLoaderClientPtr client, const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, mojom::URLLoaderFactory* network_loader_factory, - const base::RepeatingCallback<void(int)>& request_finalizer) + const base::RepeatingCallback<void(int)>& request_finalizer, + const OriginAccessList* origin_access_list) : binding_(this, std::move(loader_request)), routing_id_(routing_id), request_id_(request_id), @@ -89,27 +65,26 @@ CORSURLLoader::CORSURLLoader( network_client_binding_(this), request_(resource_request), forwarding_client_(std::move(client)), - fetch_cors_flag_(CalculateCORSFlag(resource_request)), request_finalizer_(request_finalizer), traffic_annotation_(traffic_annotation), + origin_access_list_(origin_access_list), weak_factory_(this) { binding_.set_connection_error_handler(base::BindOnce( &CORSURLLoader::OnConnectionError, base::Unretained(this))); DCHECK(network_loader_factory_); + DCHECK(origin_access_list_); + SetCORSFlagIfNeeded(); } -CORSURLLoader::~CORSURLLoader() = default; +CORSURLLoader::~CORSURLLoader() { + // Close pipes first to ignore possible subsequent callback invocations + // cased by |network_loader_| + network_client_binding_.Close(); +} void CORSURLLoader::Start() { if (fetch_cors_flag_ && - request_.fetch_request_mode == mojom::FetchRequestMode::kSameOrigin) { - HandleComplete(URLLoaderCompletionStatus( - CORSErrorStatus(mojom::CORSError::kDisallowedByMode))); - return; - } - - if (fetch_cors_flag_ && - cors::IsCORSEnabledRequestMode(request_.fetch_request_mode)) { + IsCORSEnabledRequestMode(request_.fetch_request_mode)) { // Username and password should be stripped in a CORS-enabled request. if (request_.url.has_username() || request_.url.has_password()) { GURL::Replacements replacements; @@ -149,8 +124,14 @@ void CORSURLLoader::FollowRedirect( request_.method = redirect_info_.new_method; request_.referrer = GURL(redirect_info_.new_referrer); request_.referrer_policy = redirect_info_.new_referrer_policy; + + // The request method can be changed to "GET". In this case we need to + // reset the request body manually. + if (request_.method == net::HttpRequestHeaders::kGetMethod) + request_.request_body = nullptr; + const bool original_fetch_cors_flag = fetch_cors_flag_; - fetch_cors_flag_ = fetch_cors_flag_ || CalculateCORSFlag(request_); + SetCORSFlagIfNeeded(); // We cannot use FollowRedirect for a request with preflight (i.e., when both // |fetch_cors_flag_| and |NeedsPreflight(request_)| are true). @@ -164,6 +145,9 @@ void CORSURLLoader::FollowRedirect( // in net/url_request/redirect_util.cc). if ((original_fetch_cors_flag && !NeedsPreflight(request_)) || !fetch_cors_flag_) { + response_tainting_ = + CalculateResponseTainting(request_.url, request_.fetch_request_mode, + request_.request_initiator, fetch_cors_flag_); network_loader_->FollowRedirect(to_be_removed_request_headers, modified_request_headers); return; @@ -174,15 +158,6 @@ void CORSURLLoader::FollowRedirect( request_finalizer_.Run(request_id_); network_client_binding_.Unbind(); - if (request_.fetch_credentials_mode == - mojom::FetchCredentialsMode::kSameOrigin) { - // If the credentials mode is "same-origin" and CORS flag is set, we must - // not send credentials. As network::URLLoaderImpl doesn't understand - // |fetch_credentials_mode|, we need to set the load flags here. - request_.load_flags |= net::LOAD_DO_NOT_SAVE_COOKIES; - request_.load_flags |= net::LOAD_DO_NOT_SEND_COOKIES; - request_.load_flags |= net::LOAD_DO_NOT_SEND_AUTH_DATA; - } StartRequest(); } @@ -197,13 +172,11 @@ void CORSURLLoader::SetPriority(net::RequestPriority priority, } void CORSURLLoader::PauseReadingBodyFromNet() { - DCHECK(!is_waiting_follow_redirect_call_); if (network_loader_) network_loader_->PauseReadingBodyFromNet(); } void CORSURLLoader::ResumeReadingBodyFromNet() { - DCHECK(!is_waiting_follow_redirect_call_); if (network_loader_) network_loader_->ResumeReadingBodyFromNet(); } @@ -213,14 +186,14 @@ void CORSURLLoader::OnReceiveResponse( DCHECK(network_loader_); DCHECK(forwarding_client_); DCHECK(!is_waiting_follow_redirect_call_); - if (fetch_cors_flag_ && - IsCORSEnabledRequestMode(request_.fetch_request_mode)) { - // TODO(toyoshim): Reflect --allow-file-access-from-files flag. + + const bool is_304_for_revalidation = + request_.is_revalidating && response_head.headers->response_code() == 304; + if (fetch_cors_flag_ && !is_304_for_revalidation) { const auto error_status = CheckAccess( request_.url, response_head.headers->response_code(), - GetHeaderString(response_head.headers, - header_names::kAccessControlAllowOrigin), - GetHeaderString(response_head.headers, + GetHeaderString(response_head, header_names::kAccessControlAllowOrigin), + GetHeaderString(response_head, header_names::kAccessControlAllowCredentials), request_.fetch_credentials_mode, tainted_ ? url::Origin() : *request_.request_initiator); @@ -229,7 +202,10 @@ void CORSURLLoader::OnReceiveResponse( return; } } - forwarding_client_->OnReceiveResponse(response_head); + + ResourceResponseHead response_head_to_pass = response_head; + response_head_to_pass.response_type = response_tainting_; + forwarding_client_->OnReceiveResponse(response_head_to_pass); } void CORSURLLoader::OnReceiveRedirect( @@ -249,12 +225,10 @@ void CORSURLLoader::OnReceiveRedirect( // failure, then return a network error. if (fetch_cors_flag_ && IsCORSEnabledRequestMode(request_.fetch_request_mode)) { - // TODO(toyoshim): Reflect --allow-file-access-from-files flag. const auto error_status = CheckAccess( request_.url, response_head.headers->response_code(), - GetHeaderString(response_head.headers, - header_names::kAccessControlAllowOrigin), - GetHeaderString(response_head.headers, + GetHeaderString(response_head, header_names::kAccessControlAllowOrigin), + GetHeaderString(response_head, header_names::kAccessControlAllowCredentials), request_.fetch_credentials_mode, tainted_ ? url::Origin() : *request_.request_initiator); @@ -313,7 +287,15 @@ void CORSURLLoader::OnReceiveRedirect( redirect_info_ = redirect_info; is_waiting_follow_redirect_call_ = true; - forwarding_client_->OnReceiveRedirect(redirect_info, response_head); + + auto response_head_to_pass = response_head; + if (request_.fetch_redirect_mode == mojom::FetchRedirectMode::kManual) { + response_head_to_pass.response_type = + mojom::FetchResponseType::kOpaqueRedirect; + } else { + response_head_to_pass.response_type = response_tainting_; + } + forwarding_client_->OnReceiveRedirect(redirect_info, response_head_to_pass); } void CORSURLLoader::OnUploadProgress(int64_t current_position, @@ -352,7 +334,11 @@ void CORSURLLoader::OnComplete(const URLLoaderCompletionStatus& status) { DCHECK(network_loader_); DCHECK(forwarding_client_); DCHECK(!is_waiting_follow_redirect_call_); - HandleComplete(status); + + URLLoaderCompletionStatus modified_status(status); + if (status.error_code == net::OK) + modified_status.cors_preflight_timing_info.swap(preflight_timing_info_); + HandleComplete(modified_status); } void CORSURLLoader::StartRequest() { @@ -367,30 +353,42 @@ void CORSURLLoader::StartRequest() { // `HEAD`, or |httpRequest|’s mode is "websocket", then append // `Origin`/the result of serializing a request origin with |httpRequest|, to // |httpRequest|’s header list. - if (fetch_cors_flag_ || - (request_.method != "GET" && request_.method != "HEAD")) { - if (request_.request_initiator) { - request_.headers.SetHeader( - net::HttpRequestHeaders::kOrigin, - (tainted_ ? url::Origin() : *request_.request_initiator).Serialize()); - } + // + // We exclude navigation requests to keep the existing behavior. + // TODO(yhirano): Reconsider this. + if (request_.fetch_request_mode != mojom::FetchRequestMode::kNavigate && + request_.request_initiator && + (fetch_cors_flag_ || + (request_.method != "GET" && request_.method != "HEAD"))) { + request_.headers.SetHeader( + net::HttpRequestHeaders::kOrigin, + (tainted_ ? url::Origin() : *request_.request_initiator).Serialize()); } - if (request_.fetch_request_mode == mojom::FetchRequestMode::kSameOrigin) { + if (fetch_cors_flag_ && + request_.fetch_request_mode == mojom::FetchRequestMode::kSameOrigin) { DCHECK(request_.request_initiator); - if (!request_.request_initiator->IsSameOriginWith( - url::Origin::Create(request_.url))) { - HandleComplete(URLLoaderCompletionStatus( - CORSErrorStatus(mojom::CORSError::kDisallowedByMode))); - return; - } + HandleComplete(URLLoaderCompletionStatus( + CORSErrorStatus(mojom::CORSError::kDisallowedByMode))); + return; + } + + response_tainting_ = + CalculateResponseTainting(request_.url, request_.fetch_request_mode, + request_.request_initiator, fetch_cors_flag_); + + if (!CalculateCredentialsFlag(request_.fetch_credentials_mode, + response_tainting_)) { + request_.load_flags |= net::LOAD_DO_NOT_SAVE_COOKIES; + request_.load_flags |= net::LOAD_DO_NOT_SEND_COOKIES; + request_.load_flags |= net::LOAD_DO_NOT_SEND_AUTH_DATA; } // Note that even when |NeedsPreflight(request_)| holds we don't make a // preflight request when |fetch_cors_flag_| is false (e.g., when the origin // of the url is equal to the origin of the request. if (!fetch_cors_flag_ || !NeedsPreflight(request_)) { - StartNetworkRequest(net::OK, base::nullopt); + StartNetworkRequest(net::OK, base::nullopt, base::nullopt); return; } @@ -408,14 +406,18 @@ void CORSURLLoader::StartRequest() { void CORSURLLoader::StartNetworkRequest( int error_code, - base::Optional<CORSErrorStatus> status) { + base::Optional<CORSErrorStatus> status, + base::Optional<PreflightTimingInfo> preflight_timing_info) { if (error_code != net::OK) { HandleComplete(status ? URLLoaderCompletionStatus(*status) : URLLoaderCompletionStatus(error_code)); return; } - DCHECK(!status); + + if (preflight_timing_info) + preflight_timing_info_.push_back(*preflight_timing_info); + mojom::URLLoaderClientPtr network_client; network_client_binding_.Bind(mojo::MakeRequest(&network_client)); // Binding |this| as an unretained pointer is safe because @@ -429,19 +431,69 @@ void CORSURLLoader::StartNetworkRequest( void CORSURLLoader::HandleComplete(const URLLoaderCompletionStatus& status) { forwarding_client_->OnComplete(status); - - // Close pipes to ignore possible subsequent callback invocations. - network_client_binding_.Close(); - - forwarding_client_.reset(); - network_loader_.reset(); - std::move(delete_callback_).Run(this); // |this| is deleted here. } void CORSURLLoader::OnConnectionError() { - HandleComplete(URLLoaderCompletionStatus(net::ERR_FAILED)); + HandleComplete(URLLoaderCompletionStatus(net::ERR_ABORTED)); +} + +// This should be identical to CalculateCORSFlag defined in +// //third_party/blink/renderer/platform/loader/cors/cors.cc. +void CORSURLLoader::SetCORSFlagIfNeeded() { + if (fetch_cors_flag_) + return; + + if (request_.fetch_request_mode == mojom::FetchRequestMode::kNavigate || + request_.fetch_request_mode == mojom::FetchRequestMode::kNoCORS) { + return; + } + + if (request_.url.SchemeIs(url::kDataScheme)) + return; + + // CORS needs a proper origin (including a unique opaque origin). If the + // request doesn't have one, CORS should not work. + DCHECK(request_.request_initiator); + + // The source origin and destination URL pair may be in the allow list. + if (origin_access_list_->IsAllowed(*request_.request_initiator, + request_.url)) { + return; + } + + // When a request is initiated in a unique opaque origin (e.g., in a sandboxed + // iframe) and the blob is also created in the context, |request_initiator| + // is a unique opaque origin and url::Origin::Create(request_.url) is another + // unique opaque origin. url::Origin::IsSameOriginWith(p, q) returns false + // when both |p| and |q| are opaque, but in this case we want to say that the + // request is a same-origin request. Hence we don't set |fetch_cors_flag_|, + // assuming the request comes from a renderer and the origin is checked there + // (in BaseFetchContext::CanRequest). + // In the future blob URLs will not come here because there will be a + // separate URLLoaderFactory for blobs. + // TODO(yhirano): Remove this logic at the time. + if (request_.url.SchemeIsBlob() && request_.request_initiator->opaque() && + url::Origin::Create(request_.url).opaque()) { + return; + } + + if (request_.request_initiator->IsSameOriginWith( + url::Origin::Create(request_.url))) { + return; + } + + fetch_cors_flag_ = true; +} + +base::Optional<std::string> CORSURLLoader::GetHeaderString( + const ResourceResponseHead& response, + const std::string& header_name) { + std::string header_value; + if (!response.headers->GetNormalizedHeader(header_name, &header_value)) + return base::nullopt; + return header_value; } } // namespace cors diff --git a/chromium/services/network/cors/cors_url_loader.h b/chromium/services/network/cors/cors_url_loader.h index 835cec52c2d..5d67f8401a8 100644 --- a/chromium/services/network/cors/cors_url_loader.h +++ b/chromium/services/network/cors/cors_url_loader.h @@ -10,6 +10,7 @@ #include "mojo/public/cpp/bindings/binding.h" #include "net/traffic_annotation/network_traffic_annotation.h" #include "services/network/public/cpp/cors/cors_error_status.h" +#include "services/network/public/cpp/cors/preflight_timing_info.h" #include "services/network/public/mojom/fetch_api.mojom.h" #include "services/network/public/mojom/url_loader_factory.mojom.h" #include "url/gurl.h" @@ -19,6 +20,8 @@ namespace network { namespace cors { +class OriginAccessList; + // Wrapper class that adds cross-origin resource sharing capabilities // (https://fetch.spec.whatwg.org/#http-cors-protocol), delegating requests as // well as potential preflight requests to the supplied @@ -43,7 +46,8 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) CORSURLLoader mojom::URLLoaderClientPtr client, const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, mojom::URLLoaderFactory* network_loader_factory, - const base::RepeatingCallback<void(int)>& request_finalizer); + const base::RepeatingCallback<void(int)>& request_finalizer, + const OriginAccessList* origin_access_list); ~CORSURLLoader() override; @@ -77,8 +81,10 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) CORSURLLoader private: void StartRequest(); - void StartNetworkRequest(int net_error, - base::Optional<CORSErrorStatus> status); + void StartNetworkRequest( + int net_error, + base::Optional<CORSErrorStatus> status, + base::Optional<PreflightTimingInfo> preflight_timing_info); // Called when there is a connection error on the upstream pipe used for the // actual request. @@ -89,6 +95,12 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) CORSURLLoader void OnConnectionError(); + void SetCORSFlagIfNeeded(); + + static base::Optional<std::string> GetHeaderString( + const ResourceResponseHead& response, + const std::string& header_name); + mojo::Binding<mojom::URLLoader> binding_; // We need to save these for redirect. @@ -114,12 +126,18 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) CORSURLLoader // different if redirects happen. GURL last_response_url_; + // https://fetch.spec.whatwg.org/#concept-request-response-tainting + // As "response tainting" is subset of "response type", we use + // mojom::FetchResponseType for convenience. + mojom::FetchResponseType response_tainting_ = + mojom::FetchResponseType::kBasic; + // A flag to indicate that the instance is waiting for that forwarding_client_ // calls FollowRedirect. bool is_waiting_follow_redirect_call_ = false; // Corresponds to the Fetch spec, https://fetch.spec.whatwg.org/. - bool fetch_cors_flag_; + bool fetch_cors_flag_ = false; net::RedirectInfo redirect_info_; @@ -136,6 +154,12 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) CORSURLLoader // We need to save this for redirect. net::MutableNetworkTrafficAnnotationTag traffic_annotation_; + // Holds timing info if a preflight was made. + std::vector<PreflightTimingInfo> preflight_timing_info_; + + // Outlives |this|. + const OriginAccessList* const origin_access_list_; + // Used to run asynchronous class instance bound callbacks safely. base::WeakPtrFactory<CORSURLLoader> weak_factory_; diff --git a/chromium/services/network/cors/cors_url_loader_factory.cc b/chromium/services/network/cors/cors_url_loader_factory.cc index 47e085480a3..5786ab2c56e 100644 --- a/chromium/services/network/cors/cors_url_loader_factory.cc +++ b/chromium/services/network/cors/cors_url_loader_factory.cc @@ -23,15 +23,18 @@ CORSURLLoaderFactory::CORSURLLoaderFactory( NetworkContext* context, mojom::URLLoaderFactoryParamsPtr params, scoped_refptr<ResourceSchedulerClient> resource_scheduler_client, - mojom::URLLoaderFactoryRequest request) + mojom::URLLoaderFactoryRequest request, + const OriginAccessList* origin_access_list) : context_(context), disable_web_security_(params && params->disable_web_security), network_loader_factory_(std::make_unique<network::URLLoaderFactory>( context, std::move(params), std::move(resource_scheduler_client), - this)) { + this)), + origin_access_list_(origin_access_list) { DCHECK(context_); + DCHECK(origin_access_list_); bindings_.AddBinding(this, std::move(request)); bindings_.set_connection_error_handler(base::BindRepeating( &CORSURLLoaderFactory::DeleteIfNeeded, base::Unretained(this))); @@ -40,10 +43,14 @@ CORSURLLoaderFactory::CORSURLLoaderFactory( CORSURLLoaderFactory::CORSURLLoaderFactory( bool disable_web_security, std::unique_ptr<mojom::URLLoaderFactory> network_loader_factory, - const base::RepeatingCallback<void(int)>& preflight_finalizer) + const base::RepeatingCallback<void(int)>& preflight_finalizer, + const OriginAccessList* origin_access_list) : disable_web_security_(disable_web_security), network_loader_factory_(std::move(network_loader_factory)), - preflight_finalizer_(preflight_finalizer) {} + preflight_finalizer_(preflight_finalizer), + origin_access_list_(origin_access_list) { + DCHECK(origin_access_list_); +} CORSURLLoaderFactory::~CORSURLLoaderFactory() = default; @@ -80,7 +87,8 @@ void CORSURLLoaderFactory::CreateLoaderAndStart( base::BindOnce(&CORSURLLoaderFactory::DestroyURLLoader, base::Unretained(this)), resource_request, std::move(client), traffic_annotation, - network_loader_factory_.get(), preflight_finalizer_); + network_loader_factory_.get(), preflight_finalizer_, + origin_access_list_); auto* raw_loader = loader.get(); OnLoaderCreated(std::move(loader)); raw_loader->Start(); diff --git a/chromium/services/network/cors/cors_url_loader_factory.h b/chromium/services/network/cors/cors_url_loader_factory.h index 61db987e15d..0e5d71d4f64 100644 --- a/chromium/services/network/cors/cors_url_loader_factory.h +++ b/chromium/services/network/cors/cors_url_loader_factory.h @@ -12,6 +12,7 @@ #include "base/macros.h" #include "mojo/public/cpp/bindings/strong_binding_set.h" #include "net/traffic_annotation/network_traffic_annotation.h" +#include "services/network/public/cpp/cors/origin_access_list.h" #include "services/network/public/mojom/network_context.mojom.h" #include "services/network/public/mojom/url_loader_factory.mojom.h" @@ -31,18 +32,21 @@ namespace cors { class COMPONENT_EXPORT(NETWORK_SERVICE) CORSURLLoaderFactory final : public mojom::URLLoaderFactory { public: + // |origin_access_list| should always outlive this factory instance. // Used by network::NetworkContext. CORSURLLoaderFactory( NetworkContext* context, mojom::URLLoaderFactoryParamsPtr params, scoped_refptr<ResourceSchedulerClient> resource_scheduler_client, - mojom::URLLoaderFactoryRequest request); + mojom::URLLoaderFactoryRequest request, + const OriginAccessList* origin_access_list); // Used by content::ResourceMessageFilter. // TODO(yhirano): Remove this once when the network service is fully enabled. CORSURLLoaderFactory( bool disable_web_security, std::unique_ptr<mojom::URLLoaderFactory> network_loader_factory, - const base::RepeatingCallback<void(int)>& preflight_finalizer); + const base::RepeatingCallback<void(int)>& preflight_finalizer, + const OriginAccessList* origin_access_list); ~CORSURLLoaderFactory() override; void OnLoaderCreated(std::unique_ptr<mojom::URLLoader> loader); @@ -83,6 +87,10 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) CORSURLLoaderFactory final // Used when constructed by ResourceMessageFilter. base::RepeatingCallback<void(int)> preflight_finalizer_; + // Accessed by instances in |loaders_| too. Since the factory outlives them, + // it's safe. + const OriginAccessList* const origin_access_list_; + DISALLOW_COPY_AND_ASSIGN(CORSURLLoaderFactory); }; diff --git a/chromium/services/network/cors/cors_url_loader_unittest.cc b/chromium/services/network/cors/cors_url_loader_unittest.cc index 16ef0cabc8e..bc0164b312e 100644 --- a/chromium/services/network/cors/cors_url_loader_unittest.cc +++ b/chromium/services/network/cors/cors_url_loader_unittest.cc @@ -18,6 +18,7 @@ #include "net/url_request/url_request.h" #include "services/network/cors/cors_url_loader_factory.h" #include "services/network/public/cpp/features.h" +#include "services/network/public/mojom/cors.mojom.h" #include "services/network/public/mojom/url_loader.mojom.h" #include "services/network/public/mojom/url_loader_factory.mojom.h" #include "services/network/test/test_url_loader_client.h" @@ -39,12 +40,13 @@ class TestURLLoaderFactory : public mojom::URLLoaderFactory { } void NotifyClientOnReceiveResponse( + int status_code, const std::vector<std::string>& extra_headers) { - DCHECK(client_ptr_); ResourceResponseHead response; response.headers = new net::HttpResponseHeaders( - "HTTP/1.1 200 OK\n" - "Content-Type: image/png\n"); + base::StringPrintf("HTTP/1.1 %d OK\n" + "Content-Type: image/png\n", + status_code)); for (const auto& header : extra_headers) response.headers->AddHeader(header); @@ -121,7 +123,8 @@ class CORSURLLoaderTest : public testing::Test { std::make_unique<TestURLLoaderFactory>(); test_url_loader_factory_ = factory->GetWeakPtr(); cors_url_loader_factory_ = std::make_unique<CORSURLLoaderFactory>( - false, std::move(factory), base::RepeatingCallback<void(int)>()); + false, std::move(factory), base::RepeatingCallback<void(int)>(), + &origin_access_list_); } protected: @@ -161,7 +164,15 @@ class CORSURLLoaderTest : public testing::Test { void NotifyLoaderClientOnReceiveResponse( const std::vector<std::string>& extra_headers = {}) { DCHECK(test_url_loader_factory_); - test_url_loader_factory_->NotifyClientOnReceiveResponse(extra_headers); + test_url_loader_factory_->NotifyClientOnReceiveResponse(200, extra_headers); + } + + void NotifyLoaderClientOnReceiveResponse( + int status_code, + const std::vector<std::string>& extra_headers = {}) { + DCHECK(test_url_loader_factory_); + test_url_loader_factory_->NotifyClientOnReceiveResponse(status_code, + extra_headers); } void NotifyLoaderClientOnReceiveRedirect( @@ -214,6 +225,15 @@ class CORSURLLoaderTest : public testing::Test { test_cors_loader_client_.RunUntilRedirectReceived(); } + void AddAllowListEntryForOrigin(const url::Origin& source_origin, + const std::string& protocol, + const std::string& domain, + bool allow_subdomains) { + origin_access_list_.AddAllowListEntryForOrigin( + source_origin, protocol, domain, allow_subdomains, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); + } + static net::RedirectInfo CreateRedirectInfo( int status_code, base::StringPiece method, @@ -248,6 +268,9 @@ class CORSURLLoaderTest : public testing::Test { // TestURLLoaderClient that records callback activities. TestURLLoaderClient test_cors_loader_client_; + // Holds for allowed origin access lists. + OriginAccessList origin_access_list_; + DISALLOW_COPY_AND_ASSIGN(CORSURLLoaderTest); }; @@ -980,6 +1003,175 @@ TEST_F(CORSURLLoaderTest, FollowErrorRedirect) { EXPECT_EQ(net::ERR_FAILED, client().completion_status().error_code); } +// Tests if OriginAccessList is actually used to decide the cors flag. +// Does not verify detailed functionalities that are verified in +// OriginAccessListTest. +TEST_F(CORSURLLoaderTest, OriginAccessList) { + const GURL origin("http://example.com"); + const GURL url("http://other.com/foo.png"); + + // Adds an entry to allow the cross origin request beyond the CORS + // rules. + AddAllowListEntryForOrigin(url::Origin::Create(origin), url.scheme(), + url.host(), false); + + CreateLoaderAndStart(origin, url, mojom::FetchRequestMode::kCORS); + + NotifyLoaderClientOnReceiveResponse(); + NotifyLoaderClientOnComplete(net::OK); + + RunUntilComplete(); + + EXPECT_TRUE(IsNetworkLoaderStarted()); + EXPECT_FALSE(client().has_received_redirect()); + EXPECT_TRUE(client().has_received_response()); + EXPECT_TRUE(client().has_received_completion()); + EXPECT_EQ(net::OK, client().completion_status().error_code); +} + +TEST_F(CORSURLLoaderTest, 304ForSimpleRevalidation) { + const GURL origin("https://example.com"); + const GURL url("https://other.example.com/foo.png"); + const GURL new_url("https://other2.example.com/bar.png"); + + ResourceRequest request; + request.fetch_request_mode = mojom::FetchRequestMode::kCORS; + request.fetch_credentials_mode = mojom::FetchCredentialsMode::kOmit; + request.load_flags |= net::LOAD_DO_NOT_SAVE_COOKIES; + request.load_flags |= net::LOAD_DO_NOT_SEND_COOKIES; + request.load_flags |= net::LOAD_DO_NOT_SEND_AUTH_DATA; + request.method = "GET"; + request.url = url; + request.request_initiator = url::Origin::Create(origin); + request.headers.SetHeader("If-Modified-Since", "x"); + request.headers.SetHeader("If-None-Match", "y"); + request.headers.SetHeader("Cache-Control", "z"); + request.is_revalidating = true; + CreateLoaderAndStart(request); + + // No preflight, no CORS response headers. + NotifyLoaderClientOnReceiveResponse(304, {}); + NotifyLoaderClientOnComplete(net::OK); + RunUntilComplete(); + + EXPECT_TRUE(IsNetworkLoaderStarted()); + EXPECT_FALSE(client().has_received_redirect()); + EXPECT_TRUE(client().has_received_response()); + EXPECT_TRUE(client().has_received_completion()); + EXPECT_EQ(net::OK, client().completion_status().error_code); +} + +TEST_F(CORSURLLoaderTest, 304ForSimpleGet) { + const GURL origin("https://example.com"); + const GURL url("https://other.example.com/foo.png"); + const GURL new_url("https://other2.example.com/bar.png"); + + ResourceRequest request; + request.fetch_request_mode = mojom::FetchRequestMode::kCORS; + request.fetch_credentials_mode = mojom::FetchCredentialsMode::kOmit; + request.load_flags |= net::LOAD_DO_NOT_SAVE_COOKIES; + request.load_flags |= net::LOAD_DO_NOT_SEND_COOKIES; + request.load_flags |= net::LOAD_DO_NOT_SEND_AUTH_DATA; + request.method = "GET"; + request.url = url; + request.request_initiator = url::Origin::Create(origin); + CreateLoaderAndStart(request); + + // No preflight, no CORS response headers. + NotifyLoaderClientOnReceiveResponse(304, {}); + NotifyLoaderClientOnComplete(net::OK); + RunUntilComplete(); + + EXPECT_TRUE(IsNetworkLoaderStarted()); + EXPECT_FALSE(client().has_received_redirect()); + EXPECT_FALSE(client().has_received_response()); + EXPECT_TRUE(client().has_received_completion()); + EXPECT_EQ(net::ERR_FAILED, client().completion_status().error_code); +} + +TEST_F(CORSURLLoaderTest, 200ForSimpleRevalidation) { + const GURL origin("https://example.com"); + const GURL url("https://other.example.com/foo.png"); + const GURL new_url("https://other2.example.com/bar.png"); + + ResourceRequest request; + request.fetch_request_mode = mojom::FetchRequestMode::kCORS; + request.fetch_credentials_mode = mojom::FetchCredentialsMode::kOmit; + request.load_flags |= net::LOAD_DO_NOT_SAVE_COOKIES; + request.load_flags |= net::LOAD_DO_NOT_SEND_COOKIES; + request.load_flags |= net::LOAD_DO_NOT_SEND_AUTH_DATA; + request.method = "GET"; + request.url = url; + request.request_initiator = url::Origin::Create(origin); + request.headers.SetHeader("If-Modified-Since", "x"); + request.headers.SetHeader("If-None-Match", "y"); + request.headers.SetHeader("Cache-Control", "z"); + request.is_revalidating = true; + CreateLoaderAndStart(request); + + // No preflight, no CORS response headers. + NotifyLoaderClientOnReceiveResponse(200, {}); + NotifyLoaderClientOnComplete(net::OK); + RunUntilComplete(); + + EXPECT_TRUE(IsNetworkLoaderStarted()); + EXPECT_FALSE(client().has_received_redirect()); + EXPECT_FALSE(client().has_received_response()); + EXPECT_TRUE(client().has_received_completion()); + EXPECT_EQ(net::ERR_FAILED, client().completion_status().error_code); +} + +TEST_F(CORSURLLoaderTest, RevalidationAndPreflight) { + const GURL origin("https://example.com"); + const GURL url("https://other.example.com/foo.png"); + const GURL new_url("https://other2.example.com/bar.png"); + + ResourceRequest original_request; + original_request.fetch_request_mode = mojom::FetchRequestMode::kCORS; + original_request.fetch_credentials_mode = mojom::FetchCredentialsMode::kOmit; + original_request.load_flags |= net::LOAD_DO_NOT_SAVE_COOKIES; + original_request.load_flags |= net::LOAD_DO_NOT_SEND_COOKIES; + original_request.load_flags |= net::LOAD_DO_NOT_SEND_AUTH_DATA; + original_request.method = "GET"; + original_request.url = url; + original_request.request_initiator = url::Origin::Create(origin); + original_request.headers.SetHeader("If-Modified-Since", "x"); + original_request.headers.SetHeader("If-None-Match", "y"); + original_request.headers.SetHeader("Cache-Control", "z"); + original_request.headers.SetHeader("foo", "bar"); + original_request.is_revalidating = true; + CreateLoaderAndStart(original_request); + + // preflight request + EXPECT_EQ(1, num_created_loaders()); + EXPECT_EQ(GetRequest().url, url); + EXPECT_EQ(GetRequest().method, "OPTIONS"); + std::string preflight_request_headers; + EXPECT_TRUE(GetRequest().headers.GetHeader("access-control-request-headers", + &preflight_request_headers)); + EXPECT_EQ(preflight_request_headers, "foo"); + + NotifyLoaderClientOnReceiveResponse( + {"Access-Control-Allow-Origin: https://example.com", + "Access-Control-Allow-Headers: foo"}); + RunUntilCreateLoaderAndStartCalled(); + + // the actual request + EXPECT_EQ(2, num_created_loaders()); + EXPECT_EQ(GetRequest().url, url); + EXPECT_EQ(GetRequest().method, "GET"); + + NotifyLoaderClientOnReceiveResponse( + {"Access-Control-Allow-Origin: https://example.com"}); + NotifyLoaderClientOnComplete(net::OK); + RunUntilComplete(); + + EXPECT_FALSE(client().has_received_redirect()); + EXPECT_TRUE(client().has_received_response()); + ASSERT_TRUE(client().has_received_completion()); + EXPECT_EQ(net::OK, client().completion_status().error_code); +} + } // namespace } // namespace cors diff --git a/chromium/services/network/cors/preflight_controller.cc b/chromium/services/network/cors/preflight_controller.cc index 629fd22e6aa..37712042364 100644 --- a/chromium/services/network/cors/preflight_controller.cc +++ b/chromium/services/network/cors/preflight_controller.cc @@ -8,8 +8,10 @@ #include <vector> #include "base/bind.h" +#include "base/no_destructor.h" #include "base/strings/string_util.h" #include "base/strings/stringprintf.h" +#include "base/time/time.h" #include "net/base/load_flags.h" #include "net/http/http_request_headers.h" #include "services/network/public/cpp/cors/cors.h" @@ -40,19 +42,14 @@ base::Optional<std::string> GetHeaderString( // - sorted lexicographically // - byte-lowercased std::string CreateAccessControlRequestHeadersHeader( - const net::HttpRequestHeaders& headers) { - std::vector<std::string> filtered_headers; - for (const auto& header : headers.GetHeaderVector()) { - // Exclude CORS-safelisted headers. - if (cors::IsCORSSafelistedHeader(header.key, header.value)) - continue; - // Exclude the forbidden headers because they may be added by the user - // agent. They must be checked separately and rejected for - // JavaScript-initiated requests. - if (cors::IsForbiddenHeader(header.key)) - continue; - filtered_headers.push_back(base::ToLowerASCII(header.key)); - } + const net::HttpRequestHeaders& headers, + bool is_revalidating) { + // Exclude the forbidden headers because they may be added by the user + // agent. They must be checked separately and rejected for + // JavaScript-initiated requests. + std::vector<std::string> filtered_headers = + CORSUnsafeNotForbiddenRequestHeaderNames(headers.GetHeaderVector(), + is_revalidating); if (filtered_headers.empty()) return std::string(); @@ -88,18 +85,18 @@ std::unique_ptr<ResourceRequest> CreatePreflightRequest( preflight_request->load_flags |= net::LOAD_DO_NOT_SEND_AUTH_DATA; preflight_request->headers.SetHeader( - cors::header_names::kAccessControlRequestMethod, request.method); + header_names::kAccessControlRequestMethod, request.method); - std::string request_headers = - CreateAccessControlRequestHeadersHeader(request.headers); + std::string request_headers = CreateAccessControlRequestHeadersHeader( + request.headers, request.is_revalidating); if (!request_headers.empty()) { preflight_request->headers.SetHeader( - cors::header_names::kAccessControlRequestHeaders, request_headers); + header_names::kAccessControlRequestHeaders, request_headers); } if (request.is_external_request) { preflight_request->headers.SetHeader( - cors::header_names::kAccessControlRequestExternal, "true"); + header_names::kAccessControlRequestExternal, "true"); } DCHECK(request.request_initiator); @@ -127,16 +124,13 @@ std::unique_ptr<PreflightResult> CreatePreflightResult( base::Optional<CORSErrorStatus>* detected_error_status) { DCHECK(detected_error_status); - // TODO(toyoshim): Reflect --allow-file-access-from-files flag. *detected_error_status = CheckPreflightAccess( final_url, head.headers->response_code(), + GetHeaderString(head.headers, header_names::kAccessControlAllowOrigin), GetHeaderString(head.headers, - cors::header_names::kAccessControlAllowOrigin), - GetHeaderString(head.headers, - cors::header_names::kAccessControlAllowCredentials), + header_names::kAccessControlAllowCredentials), original_request.fetch_credentials_mode, - tainted ? url::Origin() : *original_request.request_initiator, - false /* allow_file_origin */); + tainted ? url::Origin() : *original_request.request_initiator); if (*detected_error_status) return nullptr; @@ -174,7 +168,8 @@ base::Optional<CORSErrorStatus> CheckPreflightResult( if (status) return status; - return result->EnsureAllowedCrossOriginHeaders(original_request.headers); + return result->EnsureAllowedCrossOriginHeaders( + original_request.headers, original_request.is_revalidating); } // TODO(toyoshim): Remove this class once the Network Service is enabled. @@ -184,8 +179,8 @@ base::Optional<CORSErrorStatus> CheckPreflightResult( class WrappedLegacyURLLoaderFactory final : public mojom::URLLoaderFactory { public: static WrappedLegacyURLLoaderFactory* GetSharedInstance() { - static WrappedLegacyURLLoaderFactory factory; - return &factory; + static base::NoDestructor<WrappedLegacyURLLoaderFactory> factory; + return &*factory; } ~WrappedLegacyURLLoaderFactory() override = default; @@ -277,7 +272,8 @@ class PreflightController::PreflightLoader final { std::move(completion_callback_) .Run(net::ERR_FAILED, - CORSErrorStatus(mojom::CORSError::kPreflightDisallowedRedirect)); + CORSErrorStatus(mojom::CORSError::kPreflightDisallowedRedirect), + base::nullopt); RemoveFromController(); // |this| is deleted here. @@ -287,6 +283,14 @@ class PreflightController::PreflightLoader final { const ResourceResponseHead& head) { FinalizeLoader(); + timing_info_.start_time = head.request_start; + timing_info_.finish_time = base::TimeTicks::Now(); + timing_info_.alpn_negotiated_protocol = head.alpn_negotiated_protocol; + timing_info_.connection_info = head.connection_info; + head.headers->GetNormalizedHeader("Timing-Allow-Origin", + &timing_info_.timing_allow_origin); + timing_info_.transfer_size = head.encoded_data_length; + base::Optional<CORSErrorStatus> detected_error_status; std::unique_ptr<PreflightResult> result = CreatePreflightResult( final_url, head, original_request_, tainted_, &detected_error_status); @@ -305,9 +309,12 @@ class PreflightController::PreflightLoader final { original_request_.url, std::move(result)); } + base::Optional<PreflightTimingInfo> timing_info; + if (!detected_error_status) + timing_info = std::move(timing_info_); std::move(completion_callback_) .Run(detected_error_status ? net::ERR_FAILED : net::OK, - detected_error_status); + detected_error_status, std::move(timing_info)); RemoveFromController(); // |this| is deleted here. @@ -321,7 +328,7 @@ class PreflightController::PreflightLoader final { const int error = loader_->NetError(); DCHECK_NE(error, net::OK); FinalizeLoader(); - std::move(completion_callback_).Run(error, base::nullopt); + std::move(completion_callback_).Run(error, base::nullopt, base::nullopt); RemoveFromController(); // |this| is deleted here. } @@ -343,6 +350,8 @@ class PreflightController::PreflightLoader final { // Holds SimpleURLLoader instance for the CORS-preflight request. std::unique_ptr<SimpleURLLoader> loader_; + PreflightTimingInfo timing_info_; + // Holds caller's information. PreflightController::CompletionCallback completion_callback_; const ResourceRequest original_request_; @@ -367,8 +376,8 @@ PreflightController::CreatePreflightRequestForTesting( // static PreflightController* PreflightController::GetDefaultController() { - static PreflightController controller; - return &controller; + static base::NoDestructor<PreflightController> controller; + return &*controller; } PreflightController::PreflightController() = default; @@ -388,8 +397,9 @@ void PreflightController::PerformPreflightCheck( if (!request.is_external_request && cache_.CheckIfRequestCanSkipPreflight( request.request_initiator->Serialize(), request.url, - request.fetch_credentials_mode, request.method, request.headers)) { - std::move(callback).Run(net::OK, base::nullopt); + request.fetch_credentials_mode, request.method, request.headers, + request.is_revalidating)) { + std::move(callback).Run(net::OK, base::nullopt, base::nullopt); return; } diff --git a/chromium/services/network/cors/preflight_controller.h b/chromium/services/network/cors/preflight_controller.h index 9aaa178d44e..779c8f11342 100644 --- a/chromium/services/network/cors/preflight_controller.h +++ b/chromium/services/network/cors/preflight_controller.h @@ -17,6 +17,7 @@ #include "services/network/public/cpp/cors/cors_error_status.h" #include "services/network/public/cpp/cors/preflight_cache.h" #include "services/network/public/cpp/cors/preflight_result.h" +#include "services/network/public/cpp/cors/preflight_timing_info.h" #include "services/network/public/cpp/resource_request.h" #include "services/network/public/mojom/fetch_api.mojom.h" #include "services/network/public/mojom/url_loader_factory.mojom.h" @@ -32,8 +33,11 @@ namespace cors { // See also https://crbug.com/803766 to check a design doc. class COMPONENT_EXPORT(NETWORK_SERVICE) PreflightController final { public: + // PreflightTimingInfo is provided only when a preflight request was made. using CompletionCallback = - base::OnceCallback<void(int, base::Optional<CORSErrorStatus>)>; + base::OnceCallback<void(int net_error, + base::Optional<CORSErrorStatus>, + base::Optional<PreflightTimingInfo>)>; // Creates a CORS-preflight ResourceRequest for a specified |request| for a // URL that is originally requested. static std::unique_ptr<ResourceRequest> CreatePreflightRequestForTesting( diff --git a/chromium/services/network/cors/preflight_controller_unittest.cc b/chromium/services/network/cors/preflight_controller_unittest.cc index a5cb266f80d..a3002222863 100644 --- a/chromium/services/network/cors/preflight_controller_unittest.cc +++ b/chromium/services/network/cors/preflight_controller_unittest.cc @@ -18,6 +18,7 @@ #include "net/traffic_annotation/network_traffic_annotation_test_helper.h" #include "services/network/network_service.h" #include "services/network/public/cpp/cors/cors.h" +#include "services/network/public/cpp/cors/preflight_timing_info.h" #include "services/network/public/mojom/network_service.mojom.h" #include "services/network/public/mojom/url_loader_factory.mojom.h" #include "testing/gtest/include/gtest/gtest.h" @@ -50,7 +51,7 @@ TEST(PreflightControllerCreatePreflightRequestTest, LexicographicalOrder) { EXPECT_EQ("null", header); EXPECT_TRUE(preflight->headers.GetHeader( - cors::header_names::kAccessControlRequestHeaders, &header)); + header_names::kAccessControlRequestHeaders, &header)); EXPECT_EQ("apple,content-type,kiwifruit,orange,strawberry", header); } @@ -73,7 +74,7 @@ TEST(PreflightControllerCreatePreflightRequestTest, ExcludeSimpleHeaders) { // left out in the preflight request. std::string header; EXPECT_FALSE(preflight->headers.GetHeader( - cors::header_names::kAccessControlRequestHeaders, &header)); + header_names::kAccessControlRequestHeaders, &header)); } TEST(PreflightControllerCreatePreflightRequestTest, Credentials) { @@ -108,7 +109,7 @@ TEST(PreflightControllerCreatePreflightRequestTest, // Empty list also; see comment in test above. std::string header; EXPECT_FALSE(preflight->headers.GetHeader( - cors::header_names::kAccessControlRequestHeaders, &header)); + header_names::kAccessControlRequestHeaders, &header)); } TEST(PreflightControllerCreatePreflightRequestTest, IncludeNonSimpleHeader) { @@ -123,7 +124,7 @@ TEST(PreflightControllerCreatePreflightRequestTest, IncludeNonSimpleHeader) { std::string header; EXPECT_TRUE(preflight->headers.GetHeader( - cors::header_names::kAccessControlRequestHeaders, &header)); + header_names::kAccessControlRequestHeaders, &header)); EXPECT_EQ("x-custom-header", header); } @@ -141,7 +142,7 @@ TEST(PreflightControllerCreatePreflightRequestTest, std::string header; EXPECT_TRUE(preflight->headers.GetHeader( - cors::header_names::kAccessControlRequestHeaders, &header)); + header_names::kAccessControlRequestHeaders, &header)); EXPECT_EQ("content-type", header); } @@ -157,7 +158,7 @@ TEST(PreflightControllerCreatePreflightRequestTest, ExcludeForbiddenHeaders) { std::string header; EXPECT_FALSE(preflight->headers.GetHeader( - cors::header_names::kAccessControlRequestHeaders, &header)); + header_names::kAccessControlRequestHeaders, &header)); } TEST(PreflightControllerCreatePreflightRequestTest, Tainted) { @@ -199,8 +200,10 @@ class PreflightControllerTest : public testing::Test { } protected: - void HandleRequestCompletion(int net_error, - base::Optional<CORSErrorStatus> status) { + void HandleRequestCompletion( + int net_error, + base::Optional<CORSErrorStatus> status, + base::Optional<PreflightTimingInfo> timing_info) { net_error_ = net_error; status_ = status; run_loop_->Quit(); @@ -256,7 +259,7 @@ class PreflightControllerTest : public testing::Test { net::test_server::ShouldHandle(request, "/tainted") ? url::Origin() : url::Origin::Create(test_server_.base_url()); - response->AddCustomHeader(cors::header_names::kAccessControlAllowOrigin, + response->AddCustomHeader(header_names::kAccessControlAllowOrigin, origin.Serialize()); response->AddCustomHeader(header_names::kAccessControlAllowMethods, "GET, OPTIONS"); diff --git a/chromium/services/network/cross_origin_read_blocking.cc b/chromium/services/network/cross_origin_read_blocking.cc index 084b6c95688..02b7e5f6ecc 100644 --- a/chromium/services/network/cross_origin_read_blocking.cc +++ b/chromium/services/network/cross_origin_read_blocking.cc @@ -819,8 +819,8 @@ bool CrossOriginReadBlocking::ResponseAnalyzer::ShouldReportBlockedResponse() void CrossOriginReadBlocking::ResponseAnalyzer::LogBytesReadForSniffing() { if (bytes_read_for_sniffing_ >= 0) { - UMA_HISTOGRAM_COUNTS("SiteIsolation.XSD.Browser.BytesReadForSniffing", - bytes_read_for_sniffing_); + UMA_HISTOGRAM_COUNTS_1M("SiteIsolation.XSD.Browser.BytesReadForSniffing", + bytes_read_for_sniffing_); } } diff --git a/chromium/services/network/cross_origin_read_blocking.h b/chromium/services/network/cross_origin_read_blocking.h index ea227c38c51..6344ecab626 100644 --- a/chromium/services/network/cross_origin_read_blocking.h +++ b/chromium/services/network/cross_origin_read_blocking.h @@ -234,6 +234,8 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) CrossOriginReadBlocking { // not allowed by actual CORS rules by ignoring 1) credentials and 2) // methods. Preflight requests don't matter here since they are not used to // decide whether to block a response or not on the client side. + // TODO(crbug.com/736308) Remove this check once the kOutOfBlinkCORS feature + // is shipped. static bool IsValidCorsHeaderSet(const url::Origin& frame_origin, const std::string& access_control_origin); FRIEND_TEST_ALL_PREFIXES(CrossOriginReadBlockingTest, IsValidCorsHeaderSet); diff --git a/chromium/services/network/data_pipe_element_reader_unittest.cc b/chromium/services/network/data_pipe_element_reader_unittest.cc index eae5fc4178e..cc13e67b7b7 100644 --- a/chromium/services/network/data_pipe_element_reader_unittest.cc +++ b/chromium/services/network/data_pipe_element_reader_unittest.cc @@ -143,7 +143,7 @@ TEST_F(DataPipeElementReaderTest, InitInterruptsInit) { EXPECT_FALSE(element_reader_.IsInMemory()); // Try to read from the body. - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(10)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(10); net::TestCompletionCallback read_callback; EXPECT_EQ(net::ERR_IO_PENDING, element_reader_.Read(io_buffer.get(), io_buffer->size(), @@ -177,8 +177,7 @@ TEST_F(DataPipeElementReaderTest, InitInterruptsRead) { ASSERT_EQ(net::OK, first_init_callback.WaitForResult()); - scoped_refptr<net::IOBufferWithSize> first_io_buffer( - new net::IOBufferWithSize(10)); + auto first_io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(10); net::TestCompletionCallback first_read_callback; EXPECT_EQ(net::ERR_IO_PENDING, element_reader_.Read(first_io_buffer.get(), first_io_buffer->size(), @@ -209,7 +208,7 @@ TEST_F(DataPipeElementReaderTest, InitInterruptsRead) { EXPECT_FALSE(element_reader_.IsInMemory()); // Try to read from the body. - scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(10)); + auto io_buffer = base::MakeRefCounted<net::IOBufferWithSize>(10); net::TestCompletionCallback second_read_callback; EXPECT_EQ(net::ERR_IO_PENDING, element_reader_.Read(io_buffer.get(), io_buffer->size(), diff --git a/chromium/services/network/expect_ct_reporter.h b/chromium/services/network/expect_ct_reporter.h index e8bb0fcb24b..e8220a3c5b7 100644 --- a/chromium/services/network/expect_ct_reporter.h +++ b/chromium/services/network/expect_ct_reporter.h @@ -41,7 +41,7 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) ExpectCTReporter const base::Closure& failure_callback); ~ExpectCTReporter() override; - // net::ExpectCTReporter: + // net::TransportSecurityState::ExpectCTReporter: void OnExpectCTFailed(const net::HostPortPair& host_port_pair, const GURL& report_uri, base::Time expiration, diff --git a/chromium/services/network/host_resolver.cc b/chromium/services/network/host_resolver.cc index ddc12783456..cf52398226a 100644 --- a/chromium/services/network/host_resolver.cc +++ b/chromium/services/network/host_resolver.cc @@ -12,6 +12,7 @@ #include "net/base/host_port_pair.h" #include "net/base/net_errors.h" #include "net/dns/host_resolver.h" +#include "net/dns/host_resolver_source.h" #include "net/log/net_log.h" #include "services/network/resolve_host_request.h" @@ -31,6 +32,8 @@ ConvertOptionalParameters( net::HostResolver::ResolveHostParameters parameters; parameters.dns_query_type = mojo_parameters->dns_query_type; parameters.initial_priority = mojo_parameters->initial_priority; + parameters.source = mojo_parameters->source; + parameters.allow_cached_response = mojo_parameters->allow_cached_response; parameters.include_canonical_name = mojo_parameters->include_canonical_name; parameters.loopback_only = mojo_parameters->loopback_only; parameters.is_speculative = mojo_parameters->is_speculative; @@ -66,6 +69,13 @@ void HostResolver::ResolveHost( const net::HostPortPair& host, mojom::ResolveHostParametersPtr optional_parameters, mojom::ResolveHostClientPtr response_client) { +#if !BUILDFLAG(ENABLE_MDNS) + // TODO(crbug.com/821021): Handle without crashing if we create restricted + // HostResolvers for passing to untrusted processes. + DCHECK(!optional_parameters || + optional_parameters->source != net::HostResolverSource::MULTICAST_DNS); +#endif // !BUILDFLAG(ENABLE_MDNS) + if (resolve_host_callback.Get()) resolve_host_callback.Get().Run(host.host()); diff --git a/chromium/services/network/host_resolver_unittest.cc b/chromium/services/network/host_resolver_unittest.cc index 11f9965ab0c..501a0429dd8 100644 --- a/chromium/services/network/host_resolver_unittest.cc +++ b/chromium/services/network/host_resolver_unittest.cc @@ -11,6 +11,7 @@ #include "base/run_loop.h" #include "base/test/bind_test_util.h" #include "base/test/scoped_task_environment.h" +#include "base/test/simple_test_tick_clock.h" #include "mojo/public/cpp/bindings/binding.h" #include "mojo/public/cpp/bindings/interface_request.h" #include "net/base/address_list.h" @@ -198,6 +199,280 @@ TEST_F(HostResolverTest, InitialPriority) { EXPECT_EQ(net::HIGHEST, inner_resolver->last_request_priority()); } +// Make requests specifying a source for host resolution and ensure the correct +// source is requested from the inner resolver. +TEST_F(HostResolverTest, Source) { + constexpr char kDomain[] = "example.com"; + constexpr char kAnyResult[] = "1.2.3.4"; + constexpr char kSystemResult[] = "127.0.0.1"; + constexpr char kDnsResult[] = "168.100.12.23"; + constexpr char kMdnsResult[] = "200.1.2.3"; + auto inner_resolver = std::make_unique<net::MockHostResolver>(); + inner_resolver->rules_map()[net::HostResolverSource::ANY]->AddRule( + kDomain, kAnyResult); + inner_resolver->rules_map()[net::HostResolverSource::SYSTEM]->AddRule( + kDomain, kSystemResult); + inner_resolver->rules_map()[net::HostResolverSource::DNS]->AddRule( + kDomain, kDnsResult); + inner_resolver->rules_map()[net::HostResolverSource::MULTICAST_DNS]->AddRule( + kDomain, kMdnsResult); + + net::NetLog net_log; + HostResolver resolver(inner_resolver.get(), &net_log); + + base::RunLoop any_run_loop; + mojom::ResolveHostClientPtr any_client_ptr; + TestResolveHostClient any_client(&any_client_ptr, &any_run_loop); + mojom::ResolveHostParametersPtr any_parameters = + mojom::ResolveHostParameters::New(); + any_parameters->source = net::HostResolverSource::ANY; + resolver.ResolveHost(net::HostPortPair(kDomain, 80), + std::move(any_parameters), std::move(any_client_ptr)); + + base::RunLoop system_run_loop; + mojom::ResolveHostClientPtr system_client_ptr; + TestResolveHostClient system_client(&system_client_ptr, &system_run_loop); + mojom::ResolveHostParametersPtr system_parameters = + mojom::ResolveHostParameters::New(); + system_parameters->source = net::HostResolverSource::SYSTEM; + resolver.ResolveHost(net::HostPortPair(kDomain, 80), + std::move(system_parameters), + std::move(system_client_ptr)); + + base::RunLoop dns_run_loop; + mojom::ResolveHostClientPtr dns_client_ptr; + TestResolveHostClient dns_client(&dns_client_ptr, &dns_run_loop); + mojom::ResolveHostParametersPtr dns_parameters = + mojom::ResolveHostParameters::New(); + dns_parameters->source = net::HostResolverSource::DNS; + resolver.ResolveHost(net::HostPortPair(kDomain, 80), + std::move(dns_parameters), std::move(dns_client_ptr)); + + any_run_loop.Run(); + system_run_loop.Run(); + dns_run_loop.Run(); + + EXPECT_EQ(net::OK, any_client.result_error()); + EXPECT_THAT(any_client.result_addresses().value().endpoints(), + testing::ElementsAre(CreateExpectedEndPoint(kAnyResult, 80))); + EXPECT_EQ(net::OK, system_client.result_error()); + EXPECT_THAT(system_client.result_addresses().value().endpoints(), + testing::ElementsAre(CreateExpectedEndPoint(kSystemResult, 80))); + EXPECT_EQ(net::OK, dns_client.result_error()); + EXPECT_THAT(dns_client.result_addresses().value().endpoints(), + testing::ElementsAre(CreateExpectedEndPoint(kDnsResult, 80))); + +#if BUILDFLAG(ENABLE_MDNS) + base::RunLoop mdns_run_loop; + mojom::ResolveHostClientPtr mdns_client_ptr; + TestResolveHostClient mdns_client(&mdns_client_ptr, &mdns_run_loop); + mojom::ResolveHostParametersPtr mdns_parameters = + mojom::ResolveHostParameters::New(); + mdns_parameters->source = net::HostResolverSource::MULTICAST_DNS; + resolver.ResolveHost(net::HostPortPair(kDomain, 80), + std::move(mdns_parameters), std::move(mdns_client_ptr)); + + mdns_run_loop.Run(); + + EXPECT_EQ(net::OK, mdns_client.result_error()); + EXPECT_THAT(mdns_client.result_addresses().value().endpoints(), + testing::ElementsAre(CreateExpectedEndPoint(kMdnsResult, 80))); +#endif // BUILDFLAG(ENABLE_MDNS) +} + +// Test that cached results are properly keyed by requested source. +TEST_F(HostResolverTest, SeparateCacheBySource) { + constexpr char kDomain[] = "example.com"; + constexpr char kAnyResultOriginal[] = "1.2.3.4"; + constexpr char kSystemResultOriginal[] = "127.0.0.1"; + auto inner_resolver = std::make_unique<net::MockCachingHostResolver>(); + inner_resolver->rules_map()[net::HostResolverSource::ANY]->AddRule( + kDomain, kAnyResultOriginal); + inner_resolver->rules_map()[net::HostResolverSource::SYSTEM]->AddRule( + kDomain, kSystemResultOriginal); + base::SimpleTestTickClock test_clock; + inner_resolver->set_tick_clock(&test_clock); + + net::NetLog net_log; + HostResolver resolver(inner_resolver.get(), &net_log); + + // Load SYSTEM result into cache. + base::RunLoop system_run_loop; + mojom::ResolveHostClientPtr system_client_ptr; + TestResolveHostClient system_client(&system_client_ptr, &system_run_loop); + mojom::ResolveHostParametersPtr system_parameters = + mojom::ResolveHostParameters::New(); + system_parameters->source = net::HostResolverSource::SYSTEM; + resolver.ResolveHost(net::HostPortPair(kDomain, 80), + std::move(system_parameters), + std::move(system_client_ptr)); + system_run_loop.Run(); + ASSERT_EQ(net::OK, system_client.result_error()); + EXPECT_THAT( + system_client.result_addresses().value().endpoints(), + testing::ElementsAre(CreateExpectedEndPoint(kSystemResultOriginal, 80))); + + // Change |inner_resolver| rules to ensure results are coming from cache or + // not based on whether they resolve to the old or new value. + constexpr char kAnyResultFresh[] = "111.222.1.1"; + constexpr char kSystemResultFresh[] = "111.222.1.2"; + inner_resolver->rules_map()[net::HostResolverSource::ANY]->ClearRules(); + inner_resolver->rules_map()[net::HostResolverSource::ANY]->AddRule( + kDomain, kAnyResultFresh); + inner_resolver->rules_map()[net::HostResolverSource::SYSTEM]->ClearRules(); + inner_resolver->rules_map()[net::HostResolverSource::SYSTEM]->AddRule( + kDomain, kSystemResultFresh); + + base::RunLoop cached_run_loop; + mojom::ResolveHostClientPtr cached_client_ptr; + TestResolveHostClient cached_client(&cached_client_ptr, &cached_run_loop); + mojom::ResolveHostParametersPtr cached_parameters = + mojom::ResolveHostParameters::New(); + cached_parameters->source = net::HostResolverSource::SYSTEM; + resolver.ResolveHost(net::HostPortPair(kDomain, 80), + std::move(cached_parameters), + std::move(cached_client_ptr)); + + base::RunLoop uncached_run_loop; + mojom::ResolveHostClientPtr uncached_client_ptr; + TestResolveHostClient uncached_client(&uncached_client_ptr, + &uncached_run_loop); + mojom::ResolveHostParametersPtr uncached_parameters = + mojom::ResolveHostParameters::New(); + uncached_parameters->source = net::HostResolverSource::ANY; + resolver.ResolveHost(net::HostPortPair(kDomain, 80), + std::move(uncached_parameters), + std::move(uncached_client_ptr)); + + cached_run_loop.Run(); + uncached_run_loop.Run(); + + EXPECT_EQ(net::OK, cached_client.result_error()); + EXPECT_THAT( + cached_client.result_addresses().value().endpoints(), + testing::ElementsAre(CreateExpectedEndPoint(kSystemResultOriginal, 80))); + EXPECT_EQ(net::OK, uncached_client.result_error()); + EXPECT_THAT( + uncached_client.result_addresses().value().endpoints(), + testing::ElementsAre(CreateExpectedEndPoint(kAnyResultFresh, 80))); +} + +TEST_F(HostResolverTest, CacheDisabled) { + constexpr char kDomain[] = "example.com"; + constexpr char kResultOriginal[] = "1.2.3.4"; + auto inner_resolver = std::make_unique<net::MockCachingHostResolver>(); + inner_resolver->rules()->AddRule(kDomain, kResultOriginal); + base::SimpleTestTickClock test_clock; + inner_resolver->set_tick_clock(&test_clock); + + net::NetLog net_log; + HostResolver resolver(inner_resolver.get(), &net_log); + + // Load result into cache. + base::RunLoop run_loop; + mojom::ResolveHostClientPtr client_ptr; + TestResolveHostClient client(&client_ptr, &run_loop); + resolver.ResolveHost(net::HostPortPair(kDomain, 80), nullptr, + std::move(client_ptr)); + run_loop.Run(); + ASSERT_EQ(net::OK, client.result_error()); + EXPECT_THAT( + client.result_addresses().value().endpoints(), + testing::ElementsAre(CreateExpectedEndPoint(kResultOriginal, 80))); + + // Change |inner_resolver| rules to ensure results are coming from cache or + // not based on whether they resolve to the old or new value. + constexpr char kResultFresh[] = "111.222.1.1"; + inner_resolver->rules()->ClearRules(); + inner_resolver->rules()->AddRule(kDomain, kResultFresh); + + base::RunLoop cached_run_loop; + mojom::ResolveHostClientPtr cached_client_ptr; + TestResolveHostClient cached_client(&cached_client_ptr, &cached_run_loop); + mojom::ResolveHostParametersPtr cached_parameters = + mojom::ResolveHostParameters::New(); + cached_parameters->allow_cached_response = true; + resolver.ResolveHost(net::HostPortPair(kDomain, 80), + std::move(cached_parameters), + std::move(cached_client_ptr)); + cached_run_loop.Run(); + + EXPECT_EQ(net::OK, cached_client.result_error()); + EXPECT_THAT( + cached_client.result_addresses().value().endpoints(), + testing::ElementsAre(CreateExpectedEndPoint(kResultOriginal, 80))); + + base::RunLoop uncached_run_loop; + mojom::ResolveHostClientPtr uncached_client_ptr; + TestResolveHostClient uncached_client(&uncached_client_ptr, + &uncached_run_loop); + mojom::ResolveHostParametersPtr uncached_parameters = + mojom::ResolveHostParameters::New(); + uncached_parameters->allow_cached_response = false; + resolver.ResolveHost(net::HostPortPair(kDomain, 80), + std::move(uncached_parameters), + std::move(uncached_client_ptr)); + uncached_run_loop.Run(); + + EXPECT_EQ(net::OK, uncached_client.result_error()); + EXPECT_THAT(uncached_client.result_addresses().value().endpoints(), + testing::ElementsAre(CreateExpectedEndPoint(kResultFresh, 80))); +} + +// Test for a resolve with a result only in the cache and error if the cache is +// disabled. +TEST_F(HostResolverTest, CacheDisabled_ErrorResults) { + constexpr char kDomain[] = "example.com"; + constexpr char kResult[] = "1.2.3.4"; + auto inner_resolver = std::make_unique<net::MockCachingHostResolver>(); + inner_resolver->rules()->AddRule(kDomain, kResult); + base::SimpleTestTickClock test_clock; + inner_resolver->set_tick_clock(&test_clock); + + net::NetLog net_log; + HostResolver resolver(inner_resolver.get(), &net_log); + + // Load initial result into cache. + base::RunLoop run_loop; + mojom::ResolveHostClientPtr client_ptr; + TestResolveHostClient client(&client_ptr, &run_loop); + resolver.ResolveHost(net::HostPortPair(kDomain, 80), nullptr, + std::move(client_ptr)); + run_loop.Run(); + ASSERT_EQ(net::OK, client.result_error()); + + // Change |inner_resolver| rules to an error. + inner_resolver->rules()->ClearRules(); + inner_resolver->rules()->AddSimulatedFailure(kDomain); + + // Resolves for |kFreshErrorDomain| should result in error only when cache is + // disabled because success was cached. + base::RunLoop cached_run_loop; + mojom::ResolveHostClientPtr cached_client_ptr; + TestResolveHostClient cached_client(&cached_client_ptr, &cached_run_loop); + mojom::ResolveHostParametersPtr cached_parameters = + mojom::ResolveHostParameters::New(); + cached_parameters->allow_cached_response = true; + resolver.ResolveHost(net::HostPortPair(kDomain, 80), + std::move(cached_parameters), + std::move(cached_client_ptr)); + cached_run_loop.Run(); + EXPECT_EQ(net::OK, cached_client.result_error()); + + base::RunLoop uncached_run_loop; + mojom::ResolveHostClientPtr uncached_client_ptr; + TestResolveHostClient uncached_client(&uncached_client_ptr, + &uncached_run_loop); + mojom::ResolveHostParametersPtr uncached_parameters = + mojom::ResolveHostParameters::New(); + uncached_parameters->allow_cached_response = false; + resolver.ResolveHost(net::HostPortPair(kDomain, 80), + std::move(uncached_parameters), + std::move(uncached_client_ptr)); + uncached_run_loop.Run(); + EXPECT_EQ(net::ERR_NAME_NOT_RESOLVED, uncached_client.result_error()); +} + TEST_F(HostResolverTest, IncludeCanonicalName) { auto inner_resolver = std::make_unique<net::MockHostResolver>(); inner_resolver->rules()->AddRuleWithFlags("example.com", "123.0.12.24", diff --git a/chromium/services/network/http_cache_data_counter.cc b/chromium/services/network/http_cache_data_counter.cc index 1300b639927..0d2d558441c 100644 --- a/chromium/services/network/http_cache_data_counter.cc +++ b/chromium/services/network/http_cache_data_counter.cc @@ -78,7 +78,7 @@ void HttpCacheDataCounter::GotBackend( return; } - int rv; + int64_t rv; disk_cache::Backend* cache = *backend; // Handle this here since some backends would DCHECK on this. @@ -106,7 +106,7 @@ void HttpCacheDataCounter::GotBackend( } void HttpCacheDataCounter::PostResult(bool is_upper_limit, - int result_or_error) { + int64_t result_or_error) { base::SequencedTaskRunnerHandle::Get()->PostTask( FROM_HERE, base::BindOnce(std::move(callback_), this, is_upper_limit, result_or_error)); diff --git a/chromium/services/network/http_cache_data_counter.h b/chromium/services/network/http_cache_data_counter.h index 686435d4667..96cd201d972 100644 --- a/chromium/services/network/http_cache_data_counter.h +++ b/chromium/services/network/http_cache_data_counter.h @@ -57,7 +57,7 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) HttpCacheDataCounter { void GotBackend(std::unique_ptr<disk_cache::Backend*> backend, int error_code); - void PostResult(bool is_upper_limit, int result_or_error); + void PostResult(bool is_upper_limit, int64_t result_or_error); base::WeakPtr<HttpCacheDataCounter> GetWeakPtr() { return weak_factory_.GetWeakPtr(); diff --git a/chromium/services/network/net_log_capture_mode_type_converter.cc b/chromium/services/network/net_log_capture_mode_type_converter.cc new file mode 100644 index 00000000000..b9447f40bdc --- /dev/null +++ b/chromium/services/network/net_log_capture_mode_type_converter.cc @@ -0,0 +1,24 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/network/net_log_capture_mode_type_converter.h" + +namespace mojo { + +net::NetLogCaptureMode +TypeConverter<net::NetLogCaptureMode, network::mojom::NetLogCaptureMode>:: + Convert(const network::mojom::NetLogCaptureMode capture_mode) { + switch (capture_mode) { + case network::mojom::NetLogCaptureMode::DEFAULT: + return net::NetLogCaptureMode::Default(); + case network::mojom::NetLogCaptureMode::INCLUDE_COOKIES_AND_CREDENTIALS: + return net::NetLogCaptureMode::IncludeCookiesAndCredentials(); + case network::mojom::NetLogCaptureMode::INCLUDE_SOCKET_BYTES: + return net::NetLogCaptureMode::IncludeSocketBytes(); + } + NOTREACHED(); + return net::NetLogCaptureMode::Default(); +} + +} // namespace mojo diff --git a/chromium/services/network/net_log_capture_mode_type_converter.h b/chromium/services/network/net_log_capture_mode_type_converter.h new file mode 100644 index 00000000000..3a10c80e468 --- /dev/null +++ b/chromium/services/network/net_log_capture_mode_type_converter.h @@ -0,0 +1,24 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_NETWORK_NET_LOG_CAPTURE_MODE_TYPE_CONVERTER_H_ +#define SERVICES_NETWORK_NET_LOG_CAPTURE_MODE_TYPE_CONVERTER_H_ + +#include "mojo/public/cpp/bindings/type_converter.h" +#include "net/log/net_log_capture_mode.h" +#include "services/network/public/mojom/net_log.mojom.h" + +namespace mojo { + +// Converts a network::mojom::NetLogCaptureMode to a net::NetLogCaptureMode. +template <> +struct TypeConverter<net::NetLogCaptureMode, + network::mojom::NetLogCaptureMode> { + static net::NetLogCaptureMode Convert( + network::mojom::NetLogCaptureMode capture_mode); +}; + +} // namespace mojo + +#endif // SERVICES_NETWORK_NET_LOG_CAPTURE_MODE_TYPE_CONVERTER_H_ diff --git a/chromium/services/network/mojo_net_log.cc b/chromium/services/network/net_log_exporter.cc index 42f6962bcac..7f993e23d26 100644 --- a/chromium/services/network/mojo_net_log.cc +++ b/chromium/services/network/net_log_exporter.cc @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "services/network/mojo_net_log.h" +#include "services/network/net_log_exporter.h" #include "base/callback.h" #include "base/command_line.h" @@ -12,36 +12,17 @@ #include "base/task/post_task.h" #include "base/task/task_traits.h" #include "base/values.h" +#include "mojo/public/cpp/bindings/type_converter.h" #include "net/log/file_net_log_observer.h" #include "net/log/net_log_util.h" #include "net/url_request/url_request_context.h" +#include "services/network/net_log_capture_mode_type_converter.h" #include "services/network/network_context.h" #include "services/network/network_service.h" #include "services/network/public/cpp/network_switches.h" namespace network { -MojoNetLog::MojoNetLog() = default; -MojoNetLog::~MojoNetLog() = default; - -void MojoNetLog::ShutDown() { - if (file_net_log_observer_) { - file_net_log_observer_->StopObserving(nullptr /*polled_data*/, - base::OnceClosure()); - } -} - -void MojoNetLog::ObserveFileWithConstants(base::File file, - base::Value constants) { - // TODO(eroman): Should get capture mode from the command line. - net::NetLogCaptureMode capture_mode = - net::NetLogCaptureMode::IncludeCookiesAndCredentials(); - - file_net_log_observer_ = net::FileNetLogObserver::CreateUnboundedPreExisting( - std::move(file), std::make_unique<base::Value>(std::move(constants))); - file_net_log_observer_->StartObserving(this, capture_mode); -} - NetLogExporter::NetLogExporter(NetworkContext* network_context) : network_context_(network_context), state_(STATE_IDLE) {} @@ -60,7 +41,7 @@ NetLogExporter::~NetLogExporter() { void NetLogExporter::Start(base::File destination, base::Value extra_constants, - NetLogExporter::CaptureMode capture_mode, + mojom::NetLogCaptureMode capture_mode, uint64_t max_file_size, StartCallback callback) { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); @@ -76,18 +57,8 @@ void NetLogExporter::Start(base::File destination, // be carefully controlled. destination_ = std::move(destination); - net::NetLogCaptureMode net_capture_mode; - switch (capture_mode) { - case NetLogExporter::CaptureMode::DEFAULT: - net_capture_mode = net::NetLogCaptureMode::Default(); - break; - case NetLogExporter::CaptureMode::INCLUDE_COOKIES_AND_CREDENTIALS: - net_capture_mode = net::NetLogCaptureMode::IncludeCookiesAndCredentials(); - break; - case NetLogExporter::CaptureMode::INCLUDE_SOCKET_BYTES: - net_capture_mode = net::NetLogCaptureMode::IncludeSocketBytes(); - break; - } + net::NetLogCaptureMode net_capture_mode = + mojo::ConvertTo<net::NetLogCaptureMode>(capture_mode); state_ = STATE_WAITING_DIR; static_assert(kUnlimitedFileSize == net::FileNetLogObserver::kNoLimit, diff --git a/chromium/services/network/mojo_net_log.h b/chromium/services/network/net_log_exporter.h index 2753457f339..ce73afa1a5c 100644 --- a/chromium/services/network/mojo_net_log.h +++ b/chromium/services/network/net_log_exporter.h @@ -2,14 +2,17 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef SERVICES_NETWORK_MOJO_NET_LOG_H_ -#define SERVICES_NETWORK_MOJO_NET_LOG_H_ +#ifndef SERVICES_NETWORK_NET_LOG_EXPORTER_H_ +#define SERVICES_NETWORK_NET_LOG_EXPORTER_H_ #include <memory> +#include "base/files/file.h" #include "base/macros.h" #include "base/threading/thread_checker.h" +#include "base/values.h" #include "net/log/net_log.h" +#include "services/network/public/mojom/net_log.mojom.h" #include "services/network/public/mojom/network_service.mojom.h" namespace net { @@ -20,28 +23,6 @@ namespace network { class NetworkContext; -// NetLog used by NetworkService when it owns the NetLog, rather than when a -// pre-existing one is passed in to its constructor. -// -// Currently only provides --log-net-log support. -class MojoNetLog : public net::NetLog { - public: - MojoNetLog(); - ~MojoNetLog() override; - - // Finalizes the logfile created by any call to ObserveFileWithConstants(). - void ShutDown(); - - // If specified by the command line, stream network events (NetLog) to a - // file on disk. This will last for the duration of the process. - void ObserveFileWithConstants(base::File file, base::Value constants); - - private: - std::unique_ptr<net::FileNetLogObserver> file_net_log_observer_; - - DISALLOW_COPY_AND_ASSIGN(MojoNetLog); -}; - // API implementation for exporting ongoing netlogs. class COMPONENT_EXPORT(NETWORK_SERVICE) NetLogExporter : public mojom::NetLogExporter, @@ -54,7 +35,7 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetLogExporter void Start(base::File destination, base::Value extra_constants, - NetLogExporter::CaptureMode capture_mode, + network::mojom::NetLogCaptureMode capture_mode, uint64_t max_file_size, StartCallback callback) override; void Stop(base::Value polled_data, StopCallback callback) override; @@ -105,4 +86,4 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetLogExporter } // namespace network -#endif // SERVICES_NETWORK_MOJO_NET_LOG_H_ +#endif // SERVICES_NETWORK_NET_LOG_EXPORTER_H_ diff --git a/chromium/services/network/network_context.cc b/chromium/services/network/network_context.cc index 91540a3a6b1..8934556a91d 100644 --- a/chromium/services/network/network_context.cc +++ b/chromium/services/network/network_context.cc @@ -7,6 +7,7 @@ #include <memory> #include <utility> +#include "base/base64.h" #include "base/command_line.h" #include "base/containers/unique_ptr_adapters.h" #include "base/debug/dump_without_crashing.h" @@ -18,6 +19,7 @@ #include "base/sequenced_task_runner.h" #include "base/stl_util.h" #include "base/strings/string_number_conversions.h" +#include "base/strings/utf_string_conversions.h" #include "base/task/post_task.h" #include "base/task/task_traits.h" #include "build/build_config.h" @@ -70,9 +72,10 @@ #include "services/network/host_resolver.h" #include "services/network/http_server_properties_pref_delegate.h" #include "services/network/ignore_errors_cert_verifier.h" -#include "services/network/mojo_net_log.h" +#include "services/network/net_log_exporter.h" #include "services/network/network_service.h" #include "services/network/network_service_network_delegate.h" +#include "services/network/network_service_proxy_delegate.h" #include "services/network/p2p/socket_manager.h" #include "services/network/proxy_config_service_mojo.h" #include "services/network/proxy_lookup_request.h" @@ -112,10 +115,39 @@ #include "net/cert_net/cert_net_fetcher_impl.h" #endif +#if defined(OS_ANDROID) +#include "base/android/application_status_listener.h" +#endif + namespace network { namespace { +// A Base-64 encoded DER certificate for use in test Expect-CT reports. The +// contents of the certificate don't matter. +const char kTestReportCert[] = + "MIIDvzCCAqegAwIBAgIBAzANBgkqhkiG9w0BAQsFADBjMQswCQYDVQQGEwJVUzET" + "MBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNTW91bnRhaW4gVmlldzEQMA4G" + "A1UECgwHVGVzdCBDQTEVMBMGA1UEAwwMVGVzdCBSb290IENBMB4XDTE3MDYwNTE3" + "MTA0NloXDTI3MDYwMzE3MTA0NlowYDELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNh" + "bGlmb3JuaWExFjAUBgNVBAcMDU1vdW50YWluIFZpZXcxEDAOBgNVBAoMB1Rlc3Qg" + "Q0ExEjAQBgNVBAMMCTEyNy4wLjAuMTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC" + "AQoCggEBALS/0pcz5RNbd2W9cxp1KJtHWea3MOhGM21YW9ofCv/k5C3yHfiJ6GQu" + "9sPN16OO1/fN59gOEMPnVtL85ebTTuL/gk0YY4ewo97a7wo3e6y1t0PO8gc53xTp" + "w6RBPn5oRzSbe2HEGOYTzrO0puC6A+7k6+eq9G2+l1uqBpdQAdB4uNaSsOTiuUOI" + "ta4UZH1ScNQFHAkl1eJPyaiC20Exw75EbwvU/b/B7tlivzuPtQDI0d9dShOtceRL" + "X9HZckyD2JNAv2zNL2YOBNa5QygkySX9WXD+PfKpCk7Cm8TenldeXRYl5ni2REkp" + "nfa/dPuF1g3xZVjyK9aPEEnIAC2I4i0CAwEAAaOBgDB+MAwGA1UdEwEB/wQCMAAw" + "HQYDVR0OBBYEFODc4C8HiHQ6n9Mwo3GK+dal5aZTMB8GA1UdIwQYMBaAFJsmC4qY" + "qbsduR8c4xpAM+2OF4irMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAP" + "BgNVHREECDAGhwR/AAABMA0GCSqGSIb3DQEBCwUAA4IBAQB6FEQuUDRcC5jkX3aZ" + "uuTeZEqMVL7JXgvgFqzXsPb8zIdmxr/tEDfwXx2qDf2Dpxts7Fq4vqUwimK4qV3K" + "7heLnWV2+FBvV1eeSfZ7AQj+SURkdlyo42r41+t13QUf+Z0ftR9266LSWLKrukeI" + "Mxk73hOkm/u8enhTd00dy/FN9dOFBFHseVMspWNxIkdRILgOmiyfQNRgxNYdOf0e" + "EfELR8Hn6WjZ8wAbvO4p7RTrzu1c/RZ0M+NLkID56Brbl70GC2h5681LPwAOaZ7/" + "mWQ5kekSyJjmLfF12b+h9RVAt5MrXZgk2vNujssgGf4nbWh4KZyQ6qrs778ZdDLm" + "yfUn"; + net::CertVerifier* g_cert_verifier_for_testing = nullptr; // A CertVerifier that forwards all requests to |g_cert_verifier_for_testing|. @@ -224,6 +256,37 @@ void OnClearedChannelIds(net::SSLConfigService* ssl_config_service, std::move(callback).Run(); } +#if defined(OS_ANDROID) +class NetworkContextApplicationStatusListener + : public base::android::ApplicationStatusListener { + public: + // base::android::ApplicationStatusListener implementation: + void SetCallback(const ApplicationStateChangeCallback& callback) override { + DCHECK(!callback_); + DCHECK(callback); + callback_ = callback; + } + + void Notify(base::android::ApplicationState state) override { + if (callback_) + callback_.Run(state); + } + + private: + ApplicationStateChangeCallback callback_; +}; +#endif + +std::string HashesToBase64String(const net::HashValueVector& hashes) { + std::string str; + for (size_t i = 0; i != hashes.size(); ++i) { + if (i != 0) + str += ","; + str += hashes[i].ToString(); + } + return str; +} + } // namespace constexpr bool NetworkContext::enable_resource_scheduler_; @@ -242,12 +305,16 @@ class NetworkContext::ContextNetworkDelegate std::unique_ptr<net::NetworkDelegate> nested_network_delegate, bool enable_referrers, bool validate_referrer_policy_on_initial_request, + mojom::ProxyErrorClientPtrInfo proxy_error_client_info, NetworkContext* network_context) : LayeredNetworkDelegate(std::move(nested_network_delegate)), enable_referrers_(enable_referrers), validate_referrer_policy_on_initial_request_( validate_referrer_policy_on_initial_request), - network_context_(network_context) {} + network_context_(network_context) { + if (proxy_error_client_info) + proxy_error_client_.Bind(std::move(proxy_error_client_info)); + } ~ContextNetworkDelegate() override {} @@ -284,6 +351,8 @@ class NetworkContext::ContextNetworkDelegate "Net.HttpRequestCompletionErrorCodes.MainFrame", -net_error); } } + + ForwardProxyErrors(net_error); } bool OnCancelURLRequestWithPolicyViolatingReferrerHeaderInternal( @@ -309,13 +378,72 @@ class NetworkContext::ContextNetworkDelegate return true; } + bool OnCanGetCookiesInternal(const net::URLRequest& request, + const net::CookieList& cookie_list, + bool allowed_from_caller) override { + return allowed_from_caller && + network_context_->cookie_manager() + ->cookie_settings() + .IsCookieAccessAllowed(request.url(), + request.site_for_cookies()); + } + + bool OnCanSetCookieInternal(const net::URLRequest& request, + const net::CanonicalCookie& cookie, + net::CookieOptions* options, + bool allowed_from_caller) override { + return allowed_from_caller && + network_context_->cookie_manager() + ->cookie_settings() + .IsCookieAccessAllowed(request.url(), + request.site_for_cookies()); + } + + bool OnCanEnablePrivacyModeInternal( + const GURL& url, + const GURL& site_for_cookies) const override { + return !network_context_->cookie_manager() + ->cookie_settings() + .IsCookieAccessAllowed(url, site_for_cookies); + } + + void OnResponseStartedInternal(net::URLRequest* request, + int net_error) override { + ForwardProxyErrors(net_error); + } + + void OnPACScriptErrorInternal(int line_number, + const base::string16& error) override { + if (!proxy_error_client_) + return; + + proxy_error_client_->OnPACScriptError(line_number, + base::UTF16ToUTF8(error)); + } + void set_enable_referrers(bool enable_referrers) { enable_referrers_ = enable_referrers; } private: + void ForwardProxyErrors(int net_error) { + if (!proxy_error_client_) + return; + + // TODO(https://crbug.com/876848): Provide justification for the currently + // enumerated errors. + switch (net_error) { + case net::ERR_PROXY_AUTH_UNSUPPORTED: + case net::ERR_PROXY_CONNECTION_FAILED: + case net::ERR_TUNNEL_CONNECTION_FAILED: + proxy_error_client_->OnRequestMaybeFailedDueToProxySettings(net_error); + break; + } + } + bool enable_referrers_; bool validate_referrer_policy_on_initial_request_; + mojom::ProxyErrorClientPtr proxy_error_client_; NetworkContext* network_context_; DISALLOW_COPY_AND_ASSIGN(ContextNetworkDelegate); @@ -327,8 +455,13 @@ NetworkContext::NetworkContext( mojom::NetworkContextParamsPtr params, OnConnectionCloseCallback on_connection_close_callback) : network_service_(network_service), + url_request_context_(nullptr), params_(std::move(params)), on_connection_close_callback_(std::move(on_connection_close_callback)), +#if defined(OS_ANDROID) + app_status_listener_( + std::make_unique<NetworkContextApplicationStatusListener>()), +#endif binding_(this, std::move(request)) { url_request_context_owner_ = MakeURLRequestContext(); url_request_context_ = url_request_context_owner_.url_request_context.get(); @@ -357,7 +490,12 @@ NetworkContext::NetworkContext( mojom::NetworkContextParamsPtr params, std::unique_ptr<URLRequestContextBuilderMojo> builder) : network_service_(network_service), + url_request_context_(nullptr), params_(std::move(params)), +#if defined(OS_ANDROID) + app_status_listener_( + std::make_unique<NetworkContextApplicationStatusListener>()), +#endif binding_(this, std::move(request)) { url_request_context_owner_ = ApplyContextParamsToBuilder(builder.get()); url_request_context_ = url_request_context_owner_.url_request_context.get(); @@ -374,6 +512,10 @@ NetworkContext::NetworkContext(NetworkService* network_service, net::URLRequestContext* url_request_context) : network_service_(network_service), url_request_context_(url_request_context), +#if defined(OS_ANDROID) + app_status_listener_( + std::make_unique<NetworkContextApplicationStatusListener>()), +#endif binding_(this, std::move(request)), cookie_manager_( std::make_unique<CookieManager>(url_request_context->cookie_store(), @@ -455,7 +597,7 @@ void NetworkContext::CreateURLLoaderFactory( scoped_refptr<ResourceSchedulerClient> resource_scheduler_client) { url_loader_factories_.emplace(std::make_unique<cors::CORSURLLoaderFactory>( this, std::move(params), std::move(resource_scheduler_client), - std::move(request))); + std::move(request), &cors_origin_access_list_)); } void NetworkContext::SetClient(mojom::NetworkContextClientPtr client) { @@ -686,6 +828,17 @@ void NetworkContext::CloseAllConnections(CloseAllConnectionsCallback callback) { std::move(callback).Run(); } +void NetworkContext::CloseIdleConnections( + CloseIdleConnectionsCallback callback) { + net::HttpNetworkSession* http_session = + url_request_context_->http_transaction_factory()->GetSession(); + DCHECK(http_session); + + http_session->CloseIdleConnections(); + + std::move(callback).Run(); +} + void NetworkContext::SetNetworkConditions( const base::UnguessableToken& throttling_profile_id, mojom::NetworkConditionsPtr conditions) { @@ -725,6 +878,111 @@ void NetworkContext::SetCTPolicy( excluded_spkis, excluded_legacy_spkis); } +void NetworkContext::AddExpectCT(const std::string& domain, + base::Time expiry, + bool enforce, + const GURL& report_uri, + AddExpectCTCallback callback) { + net::TransportSecurityState* transport_security_state = + url_request_context()->transport_security_state(); + if (!transport_security_state) { + std::move(callback).Run(false); + return; + } + + transport_security_state->AddExpectCT(domain, expiry, enforce, report_uri); + std::move(callback).Run(true); +} + +void NetworkContext::SetExpectCTTestReport( + const GURL& report_uri, + SetExpectCTTestReportCallback callback) { + std::string decoded_dummy_cert; + DCHECK(base::Base64Decode(kTestReportCert, &decoded_dummy_cert)); + scoped_refptr<net::X509Certificate> dummy_cert = + net::X509Certificate::CreateFromBytes(decoded_dummy_cert.data(), + decoded_dummy_cert.size()); + + LazyCreateExpectCTReporter(url_request_context()); + + // We need to save |callback| into a queue because this implementation is + // relying on the success/failed observer methods of network::ExpectCTReporter + // which can be called at any time, and for other reasons. It's unlikely + // but it is possible that |callback| could be called for some other event + // other than the one initiated below when calling OnExpectCTFailed. + outstanding_set_expect_ct_callbacks_.push(std::move(callback)); + + // Send a test report with dummy data. + net::SignedCertificateTimestampAndStatusList dummy_sct_list; + expect_ct_reporter_->OnExpectCTFailed( + net::HostPortPair("expect-ct-report.test", 443), report_uri, + base::Time::Now(), dummy_cert.get(), dummy_cert.get(), dummy_sct_list); +} + +void NetworkContext::LazyCreateExpectCTReporter( + net::URLRequestContext* url_request_context) { + if (expect_ct_reporter_) + return; + + // This instance owns owns and outlives expect_ct_reporter_, so safe to + // pass |this|. + expect_ct_reporter_ = std::make_unique<network::ExpectCTReporter>( + url_request_context, + base::BindRepeating(&NetworkContext::OnSetExpectCTTestReportSuccess, + base::Unretained(this)), + base::BindRepeating(&NetworkContext::OnSetExpectCTTestReportFailure, + base::Unretained(this))); +} + +void NetworkContext::OnSetExpectCTTestReportSuccess() { + if (outstanding_set_expect_ct_callbacks_.empty()) + return; + std::move(outstanding_set_expect_ct_callbacks_.front()).Run(true); + outstanding_set_expect_ct_callbacks_.pop(); +} + +void NetworkContext::OnSetExpectCTTestReportFailure() { + if (outstanding_set_expect_ct_callbacks_.empty()) + return; + std::move(outstanding_set_expect_ct_callbacks_.front()).Run(false); + outstanding_set_expect_ct_callbacks_.pop(); +} + +void NetworkContext::GetExpectCTState(const std::string& domain, + GetExpectCTStateCallback callback) { + base::DictionaryValue result; + if (base::IsStringASCII(domain)) { + net::TransportSecurityState* transport_security_state = + url_request_context()->transport_security_state(); + if (transport_security_state) { + net::TransportSecurityState::ExpectCTState dynamic_expect_ct_state; + bool found = transport_security_state->GetDynamicExpectCTState( + domain, &dynamic_expect_ct_state); + + // TODO(estark): query static Expect-CT state as well. + if (found) { + result.SetString("dynamic_expect_ct_domain", domain); + result.SetDouble("dynamic_expect_ct_observed", + dynamic_expect_ct_state.last_observed.ToDoubleT()); + result.SetDouble("dynamic_expect_ct_expiry", + dynamic_expect_ct_state.expiry.ToDoubleT()); + result.SetBoolean("dynamic_expect_ct_enforce", + dynamic_expect_ct_state.enforce); + result.SetString("dynamic_expect_ct_report_uri", + dynamic_expect_ct_state.report_uri.spec()); + } + + result.SetBoolean("result", found); + } else { + result.SetString("error", "no Expect-CT state active"); + } + } else { + result.SetString("error", "non-ASCII domain name"); + } + + std::move(callback).Run(std::move(result)); +} + void NetworkContext::CreateUDPSocket(mojom::UDPSocketRequest request, mojom::UDPSocketReceiverPtr receiver) { socket_factory_->CreateUDPSocket(std::move(request), std::move(receiver)); @@ -745,16 +1003,28 @@ void NetworkContext::CreateTCPServerSocket( void NetworkContext::CreateTCPConnectedSocket( const base::Optional<net::IPEndPoint>& local_addr, const net::AddressList& remote_addr_list, + mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options, const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, mojom::TCPConnectedSocketRequest request, mojom::SocketObserverPtr observer, CreateTCPConnectedSocketCallback callback) { socket_factory_->CreateTCPConnectedSocket( - local_addr, remote_addr_list, + local_addr, remote_addr_list, std::move(tcp_connected_socket_options), static_cast<net::NetworkTrafficAnnotationTag>(traffic_annotation), std::move(request), std::move(observer), std::move(callback)); } +void NetworkContext::CreateTCPBoundSocket( + const net::IPEndPoint& local_addr, + const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, + mojom::TCPBoundSocketRequest request, + CreateTCPBoundSocketCallback callback) { + socket_factory_->CreateTCPBoundSocket( + local_addr, + static_cast<net::NetworkTrafficAnnotationTag>(traffic_annotation), + std::move(request), std::move(callback)); +} + void NetworkContext::CreateProxyResolvingSocketFactory( mojom::ProxyResolvingSocketFactoryRequest request) { proxy_resolving_socket_factories_.AddBinding( @@ -876,16 +1146,116 @@ void NetworkContext::IsHSTSActiveForHost(const std::string& host, std::move(callback).Run(security_state->ShouldUpgradeToSSL(host)); } -void NetworkContext::AddHSTSForTesting(const std::string& host, - base::Time expiry, - bool include_subdomains, - AddHSTSForTestingCallback callback) { +void NetworkContext::SetCorsOriginAccessListsForOrigin( + const url::Origin& source_origin, + std::vector<mojom::CorsOriginPatternPtr> allow_patterns, + std::vector<mojom::CorsOriginPatternPtr> block_patterns, + SetCorsOriginAccessListsForOriginCallback callback) { + cors_origin_access_list_.SetAllowListForOrigin(source_origin, allow_patterns); + cors_origin_access_list_.SetBlockListForOrigin(source_origin, block_patterns); + std::move(callback).Run(); +} + +void NetworkContext::AddHSTS(const std::string& host, + base::Time expiry, + bool include_subdomains, + AddHSTSCallback callback) { net::TransportSecurityState* state = url_request_context_->transport_security_state(); state->AddHSTS(host, expiry, include_subdomains); std::move(callback).Run(); } +void NetworkContext::GetHSTSState(const std::string& domain, + GetHSTSStateCallback callback) { + base::DictionaryValue result; + + if (base::IsStringASCII(domain)) { + net::TransportSecurityState* transport_security_state = + url_request_context()->transport_security_state(); + if (transport_security_state) { + net::TransportSecurityState::STSState static_sts_state; + net::TransportSecurityState::PKPState static_pkp_state; + bool found_static = transport_security_state->GetStaticDomainState( + domain, &static_sts_state, &static_pkp_state); + if (found_static) { + result.SetInteger("static_upgrade_mode", + static_cast<int>(static_sts_state.upgrade_mode)); + result.SetBoolean("static_sts_include_subdomains", + static_sts_state.include_subdomains); + result.SetDouble("static_sts_observed", + static_sts_state.last_observed.ToDoubleT()); + result.SetDouble("static_sts_expiry", + static_sts_state.expiry.ToDoubleT()); + result.SetBoolean("static_pkp_include_subdomains", + static_pkp_state.include_subdomains); + result.SetDouble("static_pkp_observed", + static_pkp_state.last_observed.ToDoubleT()); + result.SetDouble("static_pkp_expiry", + static_pkp_state.expiry.ToDoubleT()); + result.SetString("static_spki_hashes", + HashesToBase64String(static_pkp_state.spki_hashes)); + result.SetString("static_sts_domain", static_sts_state.domain); + result.SetString("static_pkp_domain", static_pkp_state.domain); + } + + net::TransportSecurityState::STSState dynamic_sts_state; + net::TransportSecurityState::PKPState dynamic_pkp_state; + bool found_sts_dynamic = transport_security_state->GetDynamicSTSState( + domain, &dynamic_sts_state); + + bool found_pkp_dynamic = transport_security_state->GetDynamicPKPState( + domain, &dynamic_pkp_state); + if (found_sts_dynamic) { + result.SetInteger("dynamic_upgrade_mode", + static_cast<int>(dynamic_sts_state.upgrade_mode)); + result.SetBoolean("dynamic_sts_include_subdomains", + dynamic_sts_state.include_subdomains); + result.SetDouble("dynamic_sts_observed", + dynamic_sts_state.last_observed.ToDoubleT()); + result.SetDouble("dynamic_sts_expiry", + dynamic_sts_state.expiry.ToDoubleT()); + result.SetString("dynamic_sts_domain", dynamic_sts_state.domain); + } + + if (found_pkp_dynamic) { + result.SetBoolean("dynamic_pkp_include_subdomains", + dynamic_pkp_state.include_subdomains); + result.SetDouble("dynamic_pkp_observed", + dynamic_pkp_state.last_observed.ToDoubleT()); + result.SetDouble("dynamic_pkp_expiry", + dynamic_pkp_state.expiry.ToDoubleT()); + result.SetString("dynamic_spki_hashes", + HashesToBase64String(dynamic_pkp_state.spki_hashes)); + result.SetString("dynamic_pkp_domain", dynamic_pkp_state.domain); + } + + result.SetBoolean("result", + found_static || found_sts_dynamic || found_pkp_dynamic); + } else { + result.SetString("error", "no TransportSecurityState active"); + } + } else { + result.SetString("error", "non-ASCII domain name"); + } + + std::move(callback).Run(std::move(result)); +} + +void NetworkContext::DeleteDynamicDataForHost( + const std::string& host, + DeleteDynamicDataForHostCallback callback) { + net::TransportSecurityState* transport_security_state = + url_request_context()->transport_security_state(); + if (!transport_security_state) { + std::move(callback).Run(false); + return; + } + + std::move(callback).Run( + transport_security_state->DeleteDynamicDataForHost(host)); +} + void NetworkContext::SetFailingHttpTransactionForTesting( int32_t error_code, SetFailingHttpTransactionForTestingCallback callback) { @@ -975,15 +1345,14 @@ URLRequestContextOwner NetworkContext::ApplyContextParamsToBuilder( network_service_->network_quality_estimator()); } - scoped_refptr<network::SessionCleanupCookieStore> - session_cleanup_cookie_store; + scoped_refptr<SessionCleanupCookieStore> session_cleanup_cookie_store; scoped_refptr<SessionCleanupChannelIDStore> session_cleanup_channel_id_store; if (params_->cookie_path) { scoped_refptr<base::SequencedTaskRunner> client_task_runner = base::MessageLoopCurrent::Get()->task_runner(); scoped_refptr<base::SequencedTaskRunner> background_task_runner = base::CreateSequencedTaskRunnerWithTraits( - {base::MayBlock(), base::TaskPriority::BEST_EFFORT, + {base::MayBlock(), net::GetCookieStoreBackgroundSequencePriority(), base::TaskShutdownBehavior::BLOCK_SHUTDOWN}); std::unique_ptr<net::ChannelIDService> channel_id_service; @@ -1012,7 +1381,7 @@ URLRequestContextOwner NetworkContext::ApplyContextParamsToBuilder( crypto_delegate)); session_cleanup_cookie_store = - base::MakeRefCounted<network::SessionCleanupCookieStore>(sqlite_store); + base::MakeRefCounted<SessionCleanupCookieStore>(sqlite_store); std::unique_ptr<net::CookieMonster> cookie_store = std::make_unique<net::CookieMonster>(session_cleanup_cookie_store.get(), @@ -1061,6 +1430,9 @@ URLRequestContextOwner NetworkContext::ApplyContextParamsToBuilder( *base::CommandLine::ForCurrentProcess()); } +#if defined(OS_ANDROID) + cache_params.app_status_listener = app_status_listener(); +#endif builder->EnableHttpCache(cache_params); } @@ -1177,6 +1549,7 @@ URLRequestContextOwner NetworkContext::ApplyContextParamsToBuilder( network_context_params->enable_referrers, network_context_params ->validate_referrer_policy_on_initial_request, + std::move(network_context_params->proxy_error_client), network_context); if (out_context_network_delegate) *out_context_network_delegate = context_network_delegate.get(); @@ -1250,8 +1623,7 @@ URLRequestContextOwner NetworkContext::ApplyContextParamsToBuilder( } if (params_->enable_expect_ct_reporting) { - expect_ct_reporter_ = std::make_unique<ExpectCTReporter>( - result.url_request_context.get(), base::Closure(), base::Closure()); + LazyCreateExpectCTReporter(result.url_request_context.get()); result.url_request_context->transport_security_state()->SetExpectCTReporter( expect_ct_reporter_.get()); } @@ -1351,6 +1723,12 @@ URLRequestContextOwner NetworkContext::MakeURLRequestContext() { std::make_unique<NetworkServiceNetworkDelegate>(this); builder.set_network_delegate(std::move(network_delegate)); + if (params_->custom_proxy_config_client_request) { + proxy_delegate_ = std::make_unique<NetworkServiceProxyDelegate>( + std::move(params_->custom_proxy_config_client_request)); + builder.set_shared_proxy_delegate(proxy_delegate_.get()); + } + // |network_service_| may be nullptr in tests. auto result = ApplyContextParamsToBuilder(&builder); @@ -1467,4 +1845,16 @@ void NetworkContext::OnCertVerifyForSignedExchangeComplete(int cert_verify_id, .Run(result, *pending_cert_verify->result.get(), ct_verify_result); } +void NetworkContext::ForceReloadProxyConfig( + ForceReloadProxyConfigCallback callback) { + url_request_context()->proxy_resolution_service()->ForceReloadProxyConfig(); + std::move(callback).Run(); +} + +void NetworkContext::ClearBadProxiesCache( + ClearBadProxiesCacheCallback callback) { + url_request_context()->proxy_resolution_service()->ClearBadProxiesCache(); + std::move(callback).Run(); +} + } // namespace network diff --git a/chromium/services/network/network_context.h b/chromium/services/network/network_context.h index dd10a642afa..e25f3045425 100644 --- a/chromium/services/network/network_context.h +++ b/chromium/services/network/network_context.h @@ -26,6 +26,7 @@ #include "net/cert/cert_verify_result.h" #include "services/network/http_cache_data_counter.h" #include "services/network/http_cache_data_remover.h" +#include "services/network/public/cpp/cors/origin_access_list.h" #include "services/network/public/mojom/host_resolver.mojom.h" #include "services/network/public/mojom/network_context.mojom.h" #include "services/network/public/mojom/proxy_lookup_client.mojom.h" @@ -60,6 +61,7 @@ class CookieManager; class ExpectCTReporter; class HostResolver; class NetworkService; +class NetworkServiceProxyDelegate; class P2PSocketManager; class ProxyLookupRequest; class ResourceScheduler; @@ -134,6 +136,12 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkContext CookieManager* cookie_manager() { return cookie_manager_.get(); } +#if defined(OS_ANDROID) + base::android::ApplicationStatusListener* app_status_listener() const { + return app_status_listener_.get(); + } +#endif + // Creates a URLLoaderFactory with a ResourceSchedulerClient specified. This // is used to reuse the existing ResourceSchedulerClient for cloned // URLLoaderFactory. @@ -177,6 +185,7 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkContext mojom::ClearDataFilterPtr filter, ClearNetworkErrorLoggingCallback callback) override; void CloseAllConnections(CloseAllConnectionsCallback callback) override; + void CloseIdleConnections(CloseIdleConnectionsCallback callback) override; void SetNetworkConditions(const base::UnguessableToken& throttling_profile_id, mojom::NetworkConditionsPtr conditions) override; void SetAcceptLanguage(const std::string& new_accept_language) override; @@ -186,6 +195,15 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkContext const std::vector<std::string>& excluded_hosts, const std::vector<std::string>& excluded_spkis, const std::vector<std::string>& excluded_legacy_spkis) override; + void AddExpectCT(const std::string& domain, + base::Time expiry, + bool enforce, + const GURL& report_uri, + AddExpectCTCallback callback) override; + void SetExpectCTTestReport(const GURL& report_uri, + SetExpectCTTestReportCallback callback) override; + void GetExpectCTState(const std::string& domain, + GetExpectCTStateCallback callback) override; void CreateUDPSocket(mojom::UDPSocketRequest request, mojom::UDPSocketReceiverPtr receiver) override; void CreateTCPServerSocket( @@ -197,10 +215,16 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkContext void CreateTCPConnectedSocket( const base::Optional<net::IPEndPoint>& local_addr, const net::AddressList& remote_addr_list, + mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options, const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, mojom::TCPConnectedSocketRequest request, mojom::SocketObserverPtr observer, CreateTCPConnectedSocketCallback callback) override; + void CreateTCPBoundSocket( + const net::IPEndPoint& local_addr, + const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, + mojom::TCPBoundSocketRequest request, + CreateTCPBoundSocketCallback callback) override; void CreateProxyResolvingSocketFactory( mojom::ProxyResolvingSocketFactoryRequest request) override; void CreateWebSocket(mojom::WebSocketRequest request, @@ -211,6 +235,8 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkContext void LookUpProxyForURL( const GURL& url, mojom::ProxyLookupClientPtr proxy_lookup_client) override; + void ForceReloadProxyConfig(ForceReloadProxyConfigCallback callback) override; + void ClearBadProxiesCache(ClearBadProxiesCacheCallback callback) override; void CreateNetLogExporter(mojom::NetLogExporterRequest request) override; void ResolveHost(const net::HostPortPair& host, mojom::ResolveHostParametersPtr optional_parameters, @@ -228,10 +254,20 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkContext VerifyCertForSignedExchangeCallback callback) override; void IsHSTSActiveForHost(const std::string& host, IsHSTSActiveForHostCallback callback) override; - void AddHSTSForTesting(const std::string& host, - base::Time expiry, - bool include_subdomains, - AddHSTSForTestingCallback callback) override; + void AddHSTS(const std::string& host, + base::Time expiry, + bool include_subdomains, + AddHSTSCallback callback) override; + void GetHSTSState(const std::string& domain, + GetHSTSStateCallback callback) override; + void DeleteDynamicDataForHost( + const std::string& host, + DeleteDynamicDataForHostCallback callback) override; + void SetCorsOriginAccessListsForOrigin( + const url::Origin& source_origin, + std::vector<mojom::CorsOriginPatternPtr> allow_patterns, + std::vector<mojom::CorsOriginPatternPtr> block_patterns, + SetCorsOriginAccessListsForOriginCallback callback) override; void SetFailingHttpTransactionForTesting( int32_t rv, SetFailingHttpTransactionForTestingCallback callback) override; @@ -261,6 +297,10 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkContext return proxy_lookup_requests_.size(); } + NetworkServiceProxyDelegate* proxy_delegate() const { + return proxy_delegate_.get(); + } + private: class ContextNetworkDelegate; @@ -293,6 +333,12 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkContext void OnCertVerifyForSignedExchangeComplete(int cert_verify_id, int result); + void OnSetExpectCTTestReportSuccess(); + + void LazyCreateExpectCTReporter(net::URLRequestContext* url_request_context); + + void OnSetExpectCTTestReportFailure(); + NetworkService* const network_service_; mojom::NetworkContextClientPtr client_; @@ -314,6 +360,11 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkContext // If non-null, called when the mojo pipe for the NetworkContext is closed. OnConnectionCloseCallback on_connection_close_callback_; +#if defined(OS_ANDROID) + std::unique_ptr<base::android::ApplicationStatusListener> + app_status_listener_; +#endif + mojo::Binding<mojom::NetworkContext> binding_; std::unique_ptr<CookieManager> cookie_manager_; @@ -371,6 +422,8 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkContext std::set<std::unique_ptr<HostResolver>, base::UniquePtrComparator> host_resolvers_; + std::unique_ptr<NetworkServiceProxyDelegate> proxy_delegate_; + // Used for Signed Exchange certificate verification. int next_cert_verify_id_ = 0; struct PendingCertVerify { @@ -388,6 +441,12 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkContext }; std::map<int, std::unique_ptr<PendingCertVerify>> cert_verifier_requests_; + // Manages allowed origin access lists. + cors::OriginAccessList cors_origin_access_list_; + + std::queue<SetExpectCTTestReportCallback> + outstanding_set_expect_ct_callbacks_; + DISALLOW_COPY_AND_ASSIGN(NetworkContext); }; diff --git a/chromium/services/network/network_context_unittest.cc b/chromium/services/network/network_context_unittest.cc index fb9b4d04b13..d22064b2fec 100644 --- a/chromium/services/network/network_context_unittest.cc +++ b/chromium/services/network/network_context_unittest.cc @@ -20,6 +20,7 @@ #include "base/optional.h" #include "base/run_loop.h" #include "base/stl_util.h" +#include "base/strings/strcat.h" #include "base/strings/string_split.h" #include "base/strings/utf_string_conversions.h" #include "base/synchronization/waitable_event.h" @@ -38,6 +39,7 @@ #include "components/network_session_configurator/browser/network_session_configurator.h" #include "components/network_session_configurator/common/network_switches.h" #include "mojo/public/cpp/bindings/interface_request.h" +#include "mojo/public/cpp/bindings/strong_binding.h" #include "mojo/public/cpp/system/data_pipe_utils.h" #include "net/base/cache_type.h" #include "net/base/hash_value.h" @@ -58,7 +60,6 @@ #include "net/http/http_server_properties_manager.h" #include "net/http/http_transaction_factory.h" #include "net/http/http_transaction_test_util.h" -#include "net/log/net_log_with_source.h" #include "net/proxy_resolution/proxy_config.h" #include "net/proxy_resolution/proxy_info.h" #include "net/proxy_resolution/proxy_resolution_service.h" @@ -68,20 +69,24 @@ #include "net/ssl/channel_id_store.h" #include "net/test/cert_test_util.h" #include "net/test/embedded_test_server/controllable_http_response.h" +#include "net/test/embedded_test_server/default_handlers.h" #include "net/test/embedded_test_server/embedded_test_server.h" #include "net/test/embedded_test_server/embedded_test_server_connection_listener.h" +#include "net/test/gtest_util.h" #include "net/test/test_data_directory.h" #include "net/traffic_annotation/network_traffic_annotation_test_helper.h" #include "net/url_request/http_user_agent_settings.h" #include "net/url_request/url_request_context.h" #include "net/url_request/url_request_context_builder.h" #include "net/url_request/url_request_job_factory.h" +#include "net/url_request/url_request_test_util.h" #include "services/network/cookie_manager.h" -#include "services/network/mojo_net_log.h" +#include "services/network/net_log_exporter.h" #include "services/network/network_context.h" #include "services/network/network_service.h" #include "services/network/public/cpp/features.h" #include "services/network/public/mojom/host_resolver.mojom.h" +#include "services/network/public/mojom/net_log.mojom.h" #include "services/network/public/mojom/network_service.mojom.h" #include "services/network/public/mojom/proxy_config.mojom.h" #include "services/network/test/test_url_loader_client.h" @@ -104,6 +109,10 @@ namespace network { namespace { +const GURL kURL("http://foo.com"); +const GURL kOtherURL("http://other.com"); +constexpr char kMockHost[] = "mock.host"; + // Sends an HttpResponse for requests for "/" that result in sending an HPKP // report. Ignores other paths to avoid catching the subsequent favicon // request. @@ -126,6 +135,18 @@ std::unique_ptr<net::test_server::HttpResponse> SendReportHttpResponse( return nullptr; } +void StoreBool(bool* result, const base::Closure& callback, bool value) { + *result = value; + callback.Run(); +} + +void StoreValue(base::Value* result, + const base::Closure& callback, + base::Value value) { + *result = std::move(value); + callback.Run(); +} + mojom::NetworkContextParamsPtr CreateContextParams() { mojom::NetworkContextParamsPtr params = mojom::NetworkContextParams::New(); // Use a fixed proxy config, to avoid dependencies on local network @@ -134,6 +155,47 @@ mojom::NetworkContextParamsPtr CreateContextParams() { return params; } +void SetContentSetting(const GURL& primary_pattern, + const GURL& secondary_pattern, + ContentSetting setting, + NetworkContext* network_context) { + network_context->cookie_manager()->SetContentSettings( + {ContentSettingPatternSource( + ContentSettingsPattern::FromURL(primary_pattern), + ContentSettingsPattern::FromURL(secondary_pattern), + base::Value(setting), std::string(), false)}); +} + +void SetDefaultContentSetting(ContentSetting setting, + NetworkContext* network_context) { + network_context->cookie_manager()->SetContentSettings( + {ContentSettingPatternSource(ContentSettingsPattern::Wildcard(), + ContentSettingsPattern::Wildcard(), + base::Value(setting), std::string(), + false)}); +} + +std::unique_ptr<TestURLLoaderClient> FetchRequest( + const ResourceRequest& request, + NetworkContext* network_context) { + mojom::URLLoaderFactoryPtr loader_factory; + auto params = mojom::URLLoaderFactoryParams::New(); + params->process_id = mojom::kBrowserProcessId; + params->is_corb_enabled = false; + network_context->CreateURLLoaderFactory(mojo::MakeRequest(&loader_factory), + std::move(params)); + + auto client = std::make_unique<TestURLLoaderClient>(); + mojom::URLLoaderPtr loader; + loader_factory->CreateLoaderAndStart( + mojo::MakeRequest(&loader), 0 /* routing_id */, 0 /* request_id */, + 0 /* options */, request, client->CreateInterfacePtr(), + net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS)); + + client->RunUntilComplete(); + return client; +} + // ProxyLookupClient that drives proxy lookups and can wait for the responses to // be received. class TestProxyLookupClient : public mojom::ProxyLookupClient { @@ -227,7 +289,7 @@ class NetworkContextTest : public testing::Test, // Looks up a value with the given name from the NetworkContext's // TransportSocketPool info dictionary. int GetSocketPoolInfo(NetworkContext* context, base::StringPiece name) { - int value; + int value = -1; context->url_request_context() ->http_transaction_factory() ->GetSession() @@ -1732,6 +1794,8 @@ TEST_F(NetworkContextTest, ClearEmptyReportingCacheReports) { } TEST_F(NetworkContextTest, ClearReportingCacheReportsWithNoService) { + base::test::ScopedFeatureList scoped_feature_list_; + scoped_feature_list_.InitAndDisableFeature(features::kReporting); std::unique_ptr<NetworkContext> network_context = CreateContextWithParams(CreateContextParams()); @@ -1842,6 +1906,8 @@ TEST_F(NetworkContextTest, ClearEmptyReportingCacheClients) { } TEST_F(NetworkContextTest, ClearReportingCacheClientsWithNoService) { + base::test::ScopedFeatureList scoped_feature_list_; + scoped_feature_list_.InitAndDisableFeature(features::kReporting); std::unique_ptr<NetworkContext> network_context = CreateContextWithParams(CreateContextParams()); @@ -1937,6 +2003,8 @@ TEST_F(NetworkContextTest, ClearEmptyNetworkErrorLogging) { } TEST_F(NetworkContextTest, ClearEmptyNetworkErrorLoggingWithNoService) { + base::test::ScopedFeatureList scoped_feature_list_; + scoped_feature_list_.InitAndDisableFeature(features::kNetworkErrorLogging); std::unique_ptr<NetworkContext> network_context = CreateContextWithParams(CreateContextParams()); @@ -2359,8 +2427,8 @@ TEST_F(NetworkContextTest, CreateNetLogExporter) { net::TestCompletionCallback cb; net_log_exporter->Start(std::move(out_file), std::move(dict_start), - mojom::NetLogExporter_CaptureMode::DEFAULT, - 100 * 1024, cb.callback()); + mojom::NetLogCaptureMode::DEFAULT, 100 * 1024, + cb.callback()); EXPECT_EQ(net::OK, cb.WaitForResult()); base::Value dict_late(base::Value::Type::DICTIONARY); @@ -2403,7 +2471,7 @@ TEST_F(NetworkContextTest, CreateNetLogExporterUnbounded) { net::TestCompletionCallback cb; net_log_exporter->Start( std::move(out_file), base::Value(base::Value::Type::DICTIONARY), - mojom::NetLogExporter::CaptureMode::DEFAULT, + mojom::NetLogCaptureMode::DEFAULT, mojom::NetLogExporter::kUnlimitedFileSize, cb.callback()); EXPECT_EQ(net::OK, cb.WaitForResult()); @@ -2443,7 +2511,7 @@ TEST_F(NetworkContextTest, CreateNetLogExporterErrors) { net_log_exporter->Start( std::move(temp_file), base::Value(base::Value::Type::DICTIONARY), - mojom::NetLogExporter_CaptureMode::DEFAULT, 100 * 1024, cb.callback()); + mojom::NetLogCaptureMode::DEFAULT, 100 * 1024, cb.callback()); EXPECT_EQ(net::OK, cb.WaitForResult()); // Can't start twice. @@ -2455,7 +2523,7 @@ TEST_F(NetworkContextTest, CreateNetLogExporterErrors) { net_log_exporter->Start( std::move(temp_file2), base::Value(base::Value::Type::DICTIONARY), - mojom::NetLogExporter_CaptureMode::DEFAULT, 100 * 1024, cb.callback()); + mojom::NetLogCaptureMode::DEFAULT, 100 * 1024, cb.callback()); EXPECT_EQ(net::ERR_UNEXPECTED, cb.WaitForResult()); base::DeleteFile(temp_path, false); @@ -2497,10 +2565,9 @@ TEST_F(NetworkContextTest, DestroyNetLogExporterWhileCreatingScratchDir) { base::File::FLAG_CREATE_ALWAYS | base::File::FLAG_WRITE); ASSERT_TRUE(temp_file.IsValid()); - net_log_exporter->Start(std::move(temp_file), - base::Value(base::Value::Type::DICTIONARY), - mojom::NetLogExporter_CaptureMode::DEFAULT, 100, - base::BindOnce([](int) {})); + net_log_exporter->Start( + std::move(temp_file), base::Value(base::Value::Type::DICTIONARY), + mojom::NetLogCaptureMode::DEFAULT, 100, base::BindOnce([](int) {})); net_log_exporter = nullptr; block_mktemp.Signal(); @@ -2973,6 +3040,133 @@ TEST_F(NetworkContextTest, CreateHostResolver_CloseContext) { EXPECT_TRUE(resolver_closed); } +TEST_F(NetworkContextTest, PrivacyModeDisabledByDefault) { + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(CreateContextParams()); + + EXPECT_FALSE(network_context->url_request_context() + ->network_delegate() + ->CanEnablePrivacyMode(kURL, kOtherURL)); +} + +TEST_F(NetworkContextTest, PrivacyModeEnabledIfCookiesBlocked) { + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(CreateContextParams()); + + SetContentSetting(kURL, kOtherURL, CONTENT_SETTING_BLOCK, + network_context.get()); + EXPECT_TRUE(network_context->url_request_context() + ->network_delegate() + ->CanEnablePrivacyMode(kURL, kOtherURL)); + EXPECT_FALSE(network_context->url_request_context() + ->network_delegate() + ->CanEnablePrivacyMode(kOtherURL, kURL)); +} + +TEST_F(NetworkContextTest, PrivacyModeDisabledIfCookiesAllowed) { + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(CreateContextParams()); + + SetContentSetting(kURL, kOtherURL, CONTENT_SETTING_ALLOW, + network_context.get()); + EXPECT_FALSE(network_context->url_request_context() + ->network_delegate() + ->CanEnablePrivacyMode(kURL, kOtherURL)); +} + +TEST_F(NetworkContextTest, PrivacyModeDisabledIfCookiesSettingForOtherURL) { + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(CreateContextParams()); + + // URLs are switched so setting should not apply. + SetContentSetting(kOtherURL, kURL, CONTENT_SETTING_BLOCK, + network_context.get()); + EXPECT_FALSE(network_context->url_request_context() + ->network_delegate() + ->CanEnablePrivacyMode(kURL, kOtherURL)); +} + +TEST_F(NetworkContextTest, PrivacyModeEnabledIfThirdPartyCookiesBlocked) { + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(CreateContextParams()); + net::NetworkDelegate* delegate = + network_context->url_request_context()->network_delegate(); + + network_context->cookie_manager()->BlockThirdPartyCookies(true); + EXPECT_TRUE(delegate->CanEnablePrivacyMode(kURL, kOtherURL)); + EXPECT_FALSE(delegate->CanEnablePrivacyMode(kURL, kURL)); + + network_context->cookie_manager()->BlockThirdPartyCookies(false); + EXPECT_FALSE(delegate->CanEnablePrivacyMode(kURL, kOtherURL)); + EXPECT_FALSE(delegate->CanEnablePrivacyMode(kURL, kURL)); +} + +TEST_F(NetworkContextTest, CanSetCookieFalseIfCookiesBlocked) { + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(CreateContextParams()); + net::URLRequestContext context; + std::unique_ptr<net::URLRequest> request = context.CreateRequest( + kURL, net::DEFAULT_PRIORITY, nullptr, TRAFFIC_ANNOTATION_FOR_TESTS); + net::CanonicalCookie cookie("TestCookie", "1", "www.test.com", "/", + base::Time(), base::Time(), base::Time(), false, + false, net::CookieSameSite::NO_RESTRICTION, + net::COOKIE_PRIORITY_LOW); + + EXPECT_TRUE( + network_context->url_request_context()->network_delegate()->CanSetCookie( + *request, cookie, nullptr, true)); + SetDefaultContentSetting(CONTENT_SETTING_BLOCK, network_context.get()); + EXPECT_FALSE( + network_context->url_request_context()->network_delegate()->CanSetCookie( + *request, cookie, nullptr, true)); +} + +TEST_F(NetworkContextTest, CanSetCookieTrueIfCookiesAllowed) { + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(CreateContextParams()); + net::URLRequestContext context; + std::unique_ptr<net::URLRequest> request = context.CreateRequest( + kURL, net::DEFAULT_PRIORITY, nullptr, TRAFFIC_ANNOTATION_FOR_TESTS); + net::CanonicalCookie cookie("TestCookie", "1", "www.test.com", "/", + base::Time(), base::Time(), base::Time(), false, + false, net::CookieSameSite::NO_RESTRICTION, + net::COOKIE_PRIORITY_LOW); + + SetDefaultContentSetting(CONTENT_SETTING_ALLOW, network_context.get()); + EXPECT_TRUE( + network_context->url_request_context()->network_delegate()->CanSetCookie( + *request, cookie, nullptr, true)); +} + +TEST_F(NetworkContextTest, CanGetCookiesFalseIfCookiesBlocked) { + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(CreateContextParams()); + net::URLRequestContext context; + std::unique_ptr<net::URLRequest> request = context.CreateRequest( + kURL, net::DEFAULT_PRIORITY, nullptr, TRAFFIC_ANNOTATION_FOR_TESTS); + + EXPECT_TRUE( + network_context->url_request_context()->network_delegate()->CanGetCookies( + *request, {}, true)); + SetDefaultContentSetting(CONTENT_SETTING_BLOCK, network_context.get()); + EXPECT_FALSE( + network_context->url_request_context()->network_delegate()->CanGetCookies( + *request, {}, true)); +} + +TEST_F(NetworkContextTest, CanGetCookiesTrueIfCookiesAllowed) { + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(CreateContextParams()); + net::URLRequestContext context; + std::unique_ptr<net::URLRequest> request = context.CreateRequest( + kURL, net::DEFAULT_PRIORITY, nullptr, TRAFFIC_ANNOTATION_FOR_TESTS); + + SetDefaultContentSetting(CONTENT_SETTING_ALLOW, network_context.get()); + EXPECT_TRUE( + network_context->url_request_context()->network_delegate()->CanGetCookies( + *request, {}, true)); +} + // Gets notified by the EmbeddedTestServer on incoming connections being // accepted or read from, keeps track of them and exposes that info to // the tests. @@ -3006,7 +3200,7 @@ class ConnectionListener // Get called from the EmbeddedTestServer thread to be notified that // a connection was read from. void ReadFromSocket(const net::StreamSocket& connection, int rv) override { - EXPECT_EQ(net::OK, rv); + EXPECT_GE(rv, net::OK); } // Wait for exactly |n| items in |sockets_|. |n| must be greater than 0. @@ -3244,6 +3438,172 @@ TEST_F(NetworkContextTest, CloseAllConnections) { EXPECT_EQ(num_sockets, 0); } +TEST_F(NetworkContextTest, CloseIdleConnections) { + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(CreateContextParams()); + + ConnectionListener connection_listener; + net::EmbeddedTestServer test_server; + test_server.AddDefaultHandlers( + base::FilePath(FILE_PATH_LITERAL("services/test/data"))); + test_server.SetConnectionListener(&connection_listener); + ASSERT_TRUE(test_server.Start()); + + // Create a hung (i.e. non-idle) socket. + net::TestDelegate delegate; + std::unique_ptr<net::URLRequest> request = + network_context->url_request_context()->CreateRequest( + test_server.GetURL("/hung"), net::DEFAULT_PRIORITY, &delegate, + TRAFFIC_ANNOTATION_FOR_TESTS); + request->Start(); + connection_listener.WaitForAcceptedConnections(1u); + EXPECT_EQ(0, GetSocketPoolInfo(network_context.get(), "idle_socket_count")); + EXPECT_EQ( + 0, GetSocketPoolInfo(network_context.get(), "connecting_socket_count")); + EXPECT_EQ( + 1, GetSocketPoolInfo(network_context.get(), "handed_out_socket_count")); + + // Create an idle socket. + network_context->PreconnectSockets(2, test_server.base_url(), + net::LOAD_NORMAL, true); + connection_listener.WaitForAcceptedConnections(2u); + EXPECT_EQ(2, GetSocketPoolInfo(network_context.get(), "idle_socket_count")); + EXPECT_EQ( + 0, GetSocketPoolInfo(network_context.get(), "connecting_socket_count")); + EXPECT_EQ( + 1, GetSocketPoolInfo(network_context.get(), "handed_out_socket_count")); + + base::RunLoop run_loop; + network_context->CloseIdleConnections(run_loop.QuitClosure()); + run_loop.Run(); + + EXPECT_EQ(0, GetSocketPoolInfo(network_context.get(), "idle_socket_count")); + EXPECT_EQ( + 0, GetSocketPoolInfo(network_context.get(), "connecting_socket_count")); + EXPECT_EQ( + 1, GetSocketPoolInfo(network_context.get(), "handed_out_socket_count")); +} + +TEST_F(NetworkContextTest, ExpectCT) { + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(CreateContextParams()); + + const char kTestDomain[] = "example.com"; + const base::Time expiry = + base::Time::Now() + base::TimeDelta::FromSeconds(1000); + const bool enforce = true; + const GURL report_uri = GURL("https://example.com/foo/bar"); + + // Assert we start with no data for the test host. + { + base::Value state; + base::RunLoop run_loop; + network_context->GetExpectCTState( + kTestDomain, + base::BindOnce(&StoreValue, &state, run_loop.QuitClosure())); + run_loop.Run(); + EXPECT_TRUE(state.is_dict()); + + const base::Value* result = + state.FindKeyOfType("result", base::Value::Type::BOOLEAN); + ASSERT_TRUE(result != nullptr); + EXPECT_FALSE(result->GetBool()); + } + + // Add the host data. + { + base::RunLoop run_loop; + bool result = false; + network_context->AddExpectCT( + kTestDomain, expiry, enforce, report_uri, + base::BindOnce(&StoreBool, &result, run_loop.QuitClosure())); + run_loop.Run(); + EXPECT_TRUE(result); + } + + // Assert added host data is returned. + { + base::Value state; + base::RunLoop run_loop; + network_context->GetExpectCTState( + kTestDomain, + base::BindOnce(&StoreValue, &state, run_loop.QuitClosure())); + run_loop.Run(); + EXPECT_TRUE(state.is_dict()); + + const base::Value* value = state.FindKeyOfType("dynamic_expect_ct_domain", + base::Value::Type::STRING); + ASSERT_TRUE(value != nullptr); + EXPECT_EQ(kTestDomain, value->GetString()); + + value = state.FindKeyOfType("dynamic_expect_ct_expiry", + base::Value::Type::DOUBLE); + ASSERT_TRUE(value != nullptr); + EXPECT_EQ(expiry.ToDoubleT(), value->GetDouble()); + + value = state.FindKeyOfType("dynamic_expect_ct_enforce", + base::Value::Type::BOOLEAN); + ASSERT_TRUE(value != nullptr); + EXPECT_EQ(enforce, value->GetBool()); + + value = state.FindKeyOfType("dynamic_expect_ct_report_uri", + base::Value::Type::STRING); + ASSERT_TRUE(value != nullptr); + EXPECT_EQ(report_uri, value->GetString()); + } + + // Delete host data. + { + bool result; + base::RunLoop run_loop; + network_context->DeleteDynamicDataForHost( + kTestDomain, + base::BindOnce(&StoreBool, &result, run_loop.QuitClosure())); + run_loop.Run(); + EXPECT_TRUE(result); + } + + // Assert data is removed. + { + base::Value state; + base::RunLoop run_loop; + network_context->GetExpectCTState( + kTestDomain, + base::BindOnce(&StoreValue, &state, run_loop.QuitClosure())); + run_loop.Run(); + EXPECT_TRUE(state.is_dict()); + + const base::Value* result = + state.FindKeyOfType("result", base::Value::Type::BOOLEAN); + ASSERT_TRUE(result != nullptr); + EXPECT_FALSE(result->GetBool()); + } +} + +TEST_F(NetworkContextTest, SetExpectCTTestReport) { + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(CreateContextParams()); + net::EmbeddedTestServer test_server; + + std::set<GURL> requested_urls; + auto monitor_callback = base::BindLambdaForTesting( + [&](const net::test_server::HttpRequest& request) { + requested_urls.insert(request.GetURL()); + }); + test_server.RegisterRequestMonitor(monitor_callback); + ASSERT_TRUE(test_server.Start()); + const GURL kReportURL = test_server.base_url().Resolve("/report/path"); + + base::RunLoop run_loop; + bool result = false; + network_context->SetExpectCTTestReport( + kReportURL, base::BindOnce(&StoreBool, &result, run_loop.QuitClosure())); + run_loop.Run(); + EXPECT_FALSE(result); + + EXPECT_TRUE(base::ContainsKey(requested_urls, kReportURL)); +} + TEST_F(NetworkContextTest, QueryHSTS) { const char kTestDomain[] = "example.com"; @@ -3259,9 +3619,11 @@ TEST_F(NetworkContextTest, QueryHSTS) { EXPECT_TRUE(got_result); EXPECT_FALSE(result); - network_context->AddHSTSForTesting( + base::RunLoop run_loop; + network_context->AddHSTS( kTestDomain, base::Time::Now() + base::TimeDelta::FromDays(1000), - false /*include_subdomains*/, base::DoNothing()); + false /*include_subdomains*/, run_loop.QuitClosure()); + run_loop.Run(); bool result2 = false, got_result2 = false; network_context->IsHSTSActiveForHost( @@ -3273,6 +3635,859 @@ TEST_F(NetworkContextTest, QueryHSTS) { EXPECT_TRUE(result2); } +TEST_F(NetworkContextTest, GetHSTSState) { + const char kTestDomain[] = "example.com"; + const base::Time expiry = + base::Time::Now() + base::TimeDelta::FromSeconds(1000); + const GURL report_uri = GURL("https://example.com/foo/bar"); + + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(CreateContextParams()); + + base::Value state; + { + base::RunLoop run_loop; + network_context->GetHSTSState( + kTestDomain, + base::BindOnce(&StoreValue, &state, run_loop.QuitClosure())); + run_loop.Run(); + } + EXPECT_TRUE(state.is_dict()); + + const base::Value* result = + state.FindKeyOfType("result", base::Value::Type::BOOLEAN); + ASSERT_TRUE(result != nullptr); + EXPECT_FALSE(result->GetBool()); + + { + base::RunLoop run_loop; + network_context->AddHSTS(kTestDomain, expiry, false /*include_subdomains*/, + run_loop.QuitClosure()); + run_loop.Run(); + } + + { + base::RunLoop run_loop; + network_context->GetHSTSState( + kTestDomain, + base::BindOnce(&StoreValue, &state, run_loop.QuitClosure())); + run_loop.Run(); + } + EXPECT_TRUE(state.is_dict()); + + result = state.FindKeyOfType("result", base::Value::Type::BOOLEAN); + ASSERT_TRUE(result != nullptr); + EXPECT_TRUE(result->GetBool()); + + // Not checking all values - only enough to ensure the underlying call + // was made. + const base::Value* value = + state.FindKeyOfType("dynamic_sts_domain", base::Value::Type::STRING); + ASSERT_TRUE(value != nullptr); + EXPECT_EQ(kTestDomain, value->GetString()); + + value = state.FindKeyOfType("dynamic_sts_expiry", base::Value::Type::DOUBLE); + ASSERT_TRUE(value != nullptr); + EXPECT_EQ(expiry.ToDoubleT(), value->GetDouble()); +} + +TEST_F(NetworkContextTest, ForceReloadProxyConfig) { + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(CreateContextParams()); + + auto net_log_exporter = + std::make_unique<network::NetLogExporter>(network_context.get()); + base::FilePath net_log_path; + ASSERT_TRUE(base::CreateTemporaryFile(&net_log_path)); + + { + base::File net_log_file( + net_log_path, base::File::FLAG_CREATE_ALWAYS | base::File::FLAG_WRITE); + EXPECT_TRUE(net_log_file.IsValid()); + base::RunLoop run_loop; + int32_t start_param = 0; + auto start_callback = base::BindLambdaForTesting([&](int32_t result) { + start_param = result; + run_loop.Quit(); + }); + net_log_exporter->Start( + std::move(net_log_file), + /*extra_constants=*/base::Value(base::Value::Type::DICTIONARY), + network::mojom::NetLogCaptureMode::DEFAULT, + network::mojom::NetLogExporter::kUnlimitedFileSize, start_callback); + run_loop.Run(); + EXPECT_EQ(net::OK, start_param); + } + + { + base::RunLoop run_loop; + network_context->ForceReloadProxyConfig(run_loop.QuitClosure()); + run_loop.Run(); + } + + { + base::RunLoop run_loop; + int32_t stop_param = 0; + auto stop_callback = base::BindLambdaForTesting([&](int32_t result) { + stop_param = result; + run_loop.Quit(); + }); + net_log_exporter->Stop( + /*polled_data=*/base::Value(base::Value::Type::DICTIONARY), + stop_callback); + run_loop.Run(); + EXPECT_EQ(net::OK, stop_param); + } + + std::string log_contents; + EXPECT_TRUE(base::ReadFileToString(net_log_path, &log_contents)); + + EXPECT_NE(std::string::npos, log_contents.find("\"new_config\"")) + << log_contents; + base::DeleteFile(net_log_path, false); +} + +TEST_F(NetworkContextTest, ClearBadProxiesCache) { + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(CreateContextParams()); + + net::ProxyResolutionService* proxy_resolution_service = + network_context->url_request_context()->proxy_resolution_service(); + + // Very starting conditions: zero bad proxies. + EXPECT_EQ(0UL, proxy_resolution_service->proxy_retry_info().size()); + + // Simulate network error to add one proxy to the bad proxy list. + net::ProxyInfo proxy_info; + proxy_info.UseNamedProxy("http://foo1.com"); + proxy_resolution_service->ReportSuccess(proxy_info); + std::vector<net::ProxyServer> proxies; + proxies.push_back(net::ProxyServer::FromURI("http://foo1.com", + net::ProxyServer::SCHEME_HTTP)); + proxy_resolution_service->MarkProxiesAsBadUntil( + proxy_info, base::TimeDelta::FromDays(1), proxies, + net::NetLogWithSource()); + base::RunLoop().RunUntilIdle(); + EXPECT_EQ(1UL, proxy_resolution_service->proxy_retry_info().size()); + + // Clear the bad proxies. + base::RunLoop run_loop; + network_context->ClearBadProxiesCache(run_loop.QuitClosure()); + run_loop.Run(); + + // Verify all cleared. + EXPECT_EQ(0UL, proxy_resolution_service->proxy_retry_info().size()); +} + +// This is a test ProxyErrorClient that records the sequence of calls made to +// OnPACScriptError() and OnRequestMaybeFailedDueToProxySettings(). +class TestProxyErrorClient final : public mojom::ProxyErrorClient { + public: + struct PacScriptError { + int line = -1; + std::string details; + }; + + TestProxyErrorClient() : binding_(this) {} + + ~TestProxyErrorClient() override {} + + void OnPACScriptError(int32_t line_number, + const std::string& details) override { + on_pac_script_error_calls_.push_back({line_number, details}); + } + + void OnRequestMaybeFailedDueToProxySettings(int32_t net_error) override { + on_request_maybe_failed_calls_.push_back(net_error); + } + + const std::vector<int>& on_request_maybe_failed_calls() const { + return on_request_maybe_failed_calls_; + } + + const std::vector<PacScriptError>& on_pac_script_error_calls() const { + return on_pac_script_error_calls_; + } + + // Creates an InterfacePtrInfo, binds it to |*this| and returns it. + mojom::ProxyErrorClientPtrInfo CreateInterfacePtrInfo() { + mojom::ProxyErrorClientPtrInfo client_ptr_info; + + binding_.Bind(mojo::MakeRequest(&client_ptr_info)); + binding_.set_connection_error_handler(base::BindOnce( + &TestProxyErrorClient::OnMojoPipeError, base::Unretained(this))); + return client_ptr_info; + } + + // Runs until the message pipe is closed due to an error. + void RunUntilMojoPipeError() { + if (has_received_mojo_pipe_error_) + return; + base::RunLoop run_loop; + quit_closure_for_on_mojo_pipe_error_ = run_loop.QuitClosure(); + run_loop.Run(); + } + + private: + void OnMojoPipeError() { + if (has_received_mojo_pipe_error_) + return; + has_received_mojo_pipe_error_ = true; + if (quit_closure_for_on_mojo_pipe_error_) + std::move(quit_closure_for_on_mojo_pipe_error_).Run(); + } + + mojo::Binding<mojom::ProxyErrorClient> binding_; + + base::OnceClosure quit_closure_for_on_mojo_pipe_error_; + bool has_received_mojo_pipe_error_ = false; + std::vector<int> on_request_maybe_failed_calls_; + std::vector<PacScriptError> on_pac_script_error_calls_; + + DISALLOW_COPY_AND_ASSIGN(TestProxyErrorClient); +}; + +// While in scope, all host resolutions will fail with ERR_NAME_NOT_RESOLVED, +// including localhost (so this precludes the use of embedded test server). +class ScopedFailAllHostResolutions { + public: + ScopedFailAllHostResolutions() + : mock_resolver_proc_(new net::RuleBasedHostResolverProc(nullptr)), + default_resolver_proc_(mock_resolver_proc_.get()) { + mock_resolver_proc_->AddSimulatedFailure("*"); + } + + private: + scoped_refptr<net::RuleBasedHostResolverProc> mock_resolver_proc_; + net::ScopedDefaultHostResolverProc default_resolver_proc_; +}; + +// Tests that when a ProxyErrorClient is provided to NetworkContextParams, this +// client's OnRequestMaybeFailedDueToProxySettings() method is called exactly +// once when a request fails due to a proxy server connectivity failure. +TEST_F(NetworkContextTest, ProxyErrorClientNotifiedOfProxyConnection) { + // Avoid the test having a network dependency on DNS. + ScopedFailAllHostResolutions fail_dns; + + // Set up the NetworkContext, such that it uses an unreachable proxy + // (proxy and is configured to send "proxy errors" to + // |proxy_error_client|. + TestProxyErrorClient proxy_error_client; + mojom::NetworkContextParamsPtr context_params = + mojom::NetworkContextParams::New(); + context_params->proxy_error_client = + proxy_error_client.CreateInterfacePtrInfo(); + net::ProxyConfig proxy_config; + // Set the proxy to an unreachable address (host resolution fails). + proxy_config.proxy_rules().ParseFromString("proxy.bad.dns"); + context_params->initial_proxy_config = net::ProxyConfigWithAnnotation( + proxy_config, TRAFFIC_ANNOTATION_FOR_TESTS); + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(std::move(context_params)); + + // Issue an HTTP request. It doesn't matter exactly what the URL is, since it + // will be sent to the proxy. + ResourceRequest request; + request.url = GURL("http://example.test"); + + mojom::URLLoaderFactoryPtr loader_factory; + mojom::URLLoaderFactoryParamsPtr loader_params = + mojom::URLLoaderFactoryParams::New(); + loader_params->process_id = mojom::kBrowserProcessId; + network_context->CreateURLLoaderFactory(mojo::MakeRequest(&loader_factory), + std::move(loader_params)); + + mojom::URLLoaderPtr loader; + TestURLLoaderClient client; + loader_factory->CreateLoaderAndStart( + mojo::MakeRequest(&loader), 0 /* routing_id */, 0 /* request_id */, + 0 /* options */, request, client.CreateInterfacePtr(), + net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS)); + + // Confirm the the resource request failed due to an unreachable proxy. + client.RunUntilComplete(); + EXPECT_THAT(client.completion_status().error_code, + net::test::IsError(net::ERR_PROXY_CONNECTION_FAILED)); + + // Tear down the network context and wait for a pipe error to ensure + // that all queued messages on |proxy_error_client| have been processed. + network_context.reset(); + proxy_error_client.RunUntilMojoPipeError(); + + // Confirm that the ProxyErrorClient received the expected calls. + const auto& request_errors = + proxy_error_client.on_request_maybe_failed_calls(); + const auto& pac_errors = proxy_error_client.on_pac_script_error_calls(); + + ASSERT_EQ(1u, request_errors.size()); + EXPECT_THAT(request_errors[0], + net::test::IsError(net::ERR_PROXY_CONNECTION_FAILED)); + EXPECT_EQ(0u, pac_errors.size()); +} + +// Tests that when a ProxyErrorClient is provided to NetworkContextParams, this +// client's OnRequestMaybeFailedDueToProxySettings() method is +// NOT called when a request fails due to a non-proxy related error (in this +// case the target host is unreachable). +TEST_F(NetworkContextTest, ProxyErrorClientNotNotifiedOfUnreachableError) { + // Avoid the test having a network dependency on DNS. + ScopedFailAllHostResolutions fail_dns; + + // Set up the NetworkContext that uses the default DIRECT proxy + // configuration. + TestProxyErrorClient proxy_error_client; + mojom::NetworkContextParamsPtr context_params = CreateContextParams(); + context_params->proxy_error_client = + proxy_error_client.CreateInterfacePtrInfo(); + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(std::move(context_params)); + + // Issue an HTTP request to an unreachable URL. + ResourceRequest request; + request.url = GURL("http://server.bad.dns/fail"); + + mojom::URLLoaderFactoryPtr loader_factory; + mojom::URLLoaderFactoryParamsPtr loader_params = + mojom::URLLoaderFactoryParams::New(); + loader_params->process_id = mojom::kBrowserProcessId; + network_context->CreateURLLoaderFactory(mojo::MakeRequest(&loader_factory), + std::move(loader_params)); + + mojom::URLLoaderPtr loader; + TestURLLoaderClient client; + loader_factory->CreateLoaderAndStart( + mojo::MakeRequest(&loader), 0 /* routing_id */, 0 /* request_id */, + 0 /* options */, request, client.CreateInterfacePtr(), + net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS)); + + // Confirm the the resource request failed. + client.RunUntilComplete(); + EXPECT_THAT(client.completion_status().error_code, + net::test::IsError(net::ERR_NAME_NOT_RESOLVED)); + + // Tear down the network context and wait for a pipe error to ensure + // that all queued messages on |proxy_error_client| have been processed. + network_context.reset(); + proxy_error_client.RunUntilMojoPipeError(); + + // Confirm that the ProxyErrorClient received no calls. + const auto& request_errors = + proxy_error_client.on_request_maybe_failed_calls(); + const auto& pac_errors = proxy_error_client.on_pac_script_error_calls(); + + EXPECT_EQ(0u, request_errors.size()); + EXPECT_EQ(0u, pac_errors.size()); +} + +// Test mojom::ProxyResolver that completes calls to GetProxyForUrl() with a +// DIRECT "proxy". It additionall emits a script error on line 42 for every call +// to GetProxyForUrl(). +class MockMojoProxyResolver : public proxy_resolver::mojom::ProxyResolver { + public: + MockMojoProxyResolver() {} + + private: + // Overridden from proxy_resolver::mojom::ProxyResolver: + void GetProxyForUrl( + const GURL& url, + proxy_resolver::mojom::ProxyResolverRequestClientPtr client) override { + // Report a Javascript error and then complete the request successfully, + // having chosen DIRECT connections. + client->OnError(42, "Failed: FindProxyForURL(url=" + url.spec() + ")"); + + net::ProxyInfo result; + result.UseDirect(); + + client->ReportResult(net::OK, result); + } + + DISALLOW_COPY_AND_ASSIGN(MockMojoProxyResolver); +}; + +// Test mojom::ProxyResolverFactory implementation that successfully completes +// any CreateResolver() requests, and binds the request to a new +// MockMojoProxyResolver. +class MockMojoProxyResolverFactory + : public proxy_resolver::mojom::ProxyResolverFactory { + public: + MockMojoProxyResolverFactory() {} + + // Binds and returns a mock ProxyResolverFactory whose lifetime is bound to + // the message pipe. + static proxy_resolver::mojom::ProxyResolverFactoryPtrInfo Create() { + proxy_resolver::mojom::ProxyResolverFactoryPtrInfo ptr_info; + mojo::MakeStrongBinding(std::make_unique<MockMojoProxyResolverFactory>(), + mojo::MakeRequest(&ptr_info)); + return ptr_info; + } + + private: + void CreateResolver( + const std::string& pac_url, + mojo::InterfaceRequest<proxy_resolver::mojom::ProxyResolver> request, + proxy_resolver::mojom::ProxyResolverFactoryRequestClientPtr client) + override { + // Bind |request| to a new MockMojoProxyResolver, and return success. + mojo::MakeStrongBinding(std::make_unique<MockMojoProxyResolver>(), + std::move(request)); + client->ReportResult(net::OK); + } + + DISALLOW_COPY_AND_ASSIGN(MockMojoProxyResolverFactory); +}; + +// Tests that when a ProxyErrorClient is provided to NetworkContextParams, this +// client's OnPACScriptError() method is called whenever the PAC script throws +// an error. +TEST_F(NetworkContextTest, ProxyErrorClientNotifiedOfPacError) { + // Avoid the test having a network dependency on DNS. + ScopedFailAllHostResolutions fail_dns; + + // Set up the NetworkContext so that it sends "proxy errors" to + // |proxy_error_client|, and uses a mock ProxyResolverFactory that emits + // script errors. + TestProxyErrorClient proxy_error_client; + mojom::NetworkContextParamsPtr context_params = + mojom::NetworkContextParams::New(); + context_params->proxy_error_client = + proxy_error_client.CreateInterfacePtrInfo(); + // The PAC URL doesn't matter, since the test is configured to use a + // mock ProxyResolverFactory which doesn't actually evaluate it. It just + // needs to be a data: URL to ensure the network fetch doesn't fail. + // + // That said, the mock PAC evalulator being used behaves similarly to the + // script embedded in the data URL below. + net::ProxyConfig proxy_config = net::ProxyConfig::CreateFromCustomPacURL( + GURL("data:,function FindProxyForURL(url,host){throw url}")); + context_params->initial_proxy_config = net::ProxyConfigWithAnnotation( + proxy_config, TRAFFIC_ANNOTATION_FOR_TESTS); + context_params->proxy_resolver_factory = + MockMojoProxyResolverFactory::Create(); + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(std::move(context_params)); + + // Issue an HTTP request. This will end up being sent DIRECT since the PAC + // script is broken. + ResourceRequest request; + request.url = GURL("http://server.bad.dns"); + + mojom::URLLoaderFactoryPtr loader_factory; + mojom::URLLoaderFactoryParamsPtr loader_params = + mojom::URLLoaderFactoryParams::New(); + loader_params->process_id = mojom::kBrowserProcessId; + network_context->CreateURLLoaderFactory(mojo::MakeRequest(&loader_factory), + std::move(loader_params)); + + mojom::URLLoaderPtr loader; + TestURLLoaderClient client; + loader_factory->CreateLoaderAndStart( + mojo::MakeRequest(&loader), 0 /* routing_id */, 0 /* request_id */, + 0 /* options */, request, client.CreateInterfacePtr(), + net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS)); + + // Confirm the the resource request failed. + client.RunUntilComplete(); + EXPECT_THAT(client.completion_status().error_code, + net::test::IsError(net::ERR_NAME_NOT_RESOLVED)); + + // Tear down the network context and wait for a pipe error to ensure + // that all queued messages on |proxy_error_client| have been processed. + network_context.reset(); + proxy_error_client.RunUntilMojoPipeError(); + + // Confirm that the ProxyErrorClient received the expected calls. + const auto& request_errors = + proxy_error_client.on_request_maybe_failed_calls(); + const auto& pac_errors = proxy_error_client.on_pac_script_error_calls(); + + EXPECT_EQ(0u, request_errors.size()); + + ASSERT_EQ(1u, pac_errors.size()); + EXPECT_EQ(pac_errors[0].line, 42); + EXPECT_EQ(pac_errors[0].details, + "Failed: FindProxyForURL(url=http://server.bad.dns/)"); +} + +// Test ensures that ProxyServer data is populated correctly across Mojo calls. +// Basically it performs a set of URLLoader network requests, whose requests +// configure proxies. Then it checks whether the expected proxy scheme is +// respected. +TEST_F(NetworkContextTest, EnsureProperProxyServerIsUsed) { + net::test_server::EmbeddedTestServer test_server; + test_server.AddDefaultHandlers( + base::FilePath(FILE_PATH_LITERAL("services/test/data"))); + ASSERT_TRUE(test_server.Start()); + + struct ProxyConfigSet { + net::ProxyConfig proxy_config; + GURL url; + net::ProxyServer::Scheme expected_proxy_config_scheme; + } proxy_config_set[2]; + + proxy_config_set[0].proxy_config.proxy_rules().ParseFromString( + base::StringPrintf("http=%s", + test_server.host_port_pair().ToString().c_str())); + // The domain here is irrelevant, and it is the path that matters. + proxy_config_set[0].url = GURL("http://does.not.matter/echo"); + proxy_config_set[0].expected_proxy_config_scheme = + net::ProxyServer::SCHEME_HTTP; + + proxy_config_set[1].proxy_config.proxy_rules().ParseFromString( + "http=direct://"); + proxy_config_set[1].url = test_server.GetURL("/echo"); + proxy_config_set[1].expected_proxy_config_scheme = + net::ProxyServer::SCHEME_DIRECT; + + for (const auto& proxy_data : proxy_config_set) { + mojom::NetworkContextParamsPtr context_params = CreateContextParams(); + context_params->initial_proxy_config = net::ProxyConfigWithAnnotation( + proxy_data.proxy_config, TRAFFIC_ANNOTATION_FOR_TESTS); + mojom::ProxyConfigClientPtr config_client; + context_params->proxy_config_client_request = + mojo::MakeRequest(&config_client); + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(std::move(context_params)); + + mojom::URLLoaderFactoryPtr loader_factory; + mojom::URLLoaderFactoryParamsPtr params = + mojom::URLLoaderFactoryParams::New(); + params->process_id = 0; + network_context->CreateURLLoaderFactory(mojo::MakeRequest(&loader_factory), + std::move(params)); + + ResourceRequest request; + // The domain here is irrelevant, and it is the path that matters. + request.url = proxy_data.url; // test_server.GetURL("/echo"); + + mojom::URLLoaderPtr loader; + TestURLLoaderClient client; + loader_factory->CreateLoaderAndStart( + mojo::MakeRequest(&loader), 0 /* routing_id */, 0 /* request_id */, + 0 /* options */, request, client.CreateInterfacePtr(), + net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS)); + + client.RunUntilComplete(); + + EXPECT_TRUE(client.has_received_completion()); + EXPECT_EQ(client.response_head().proxy_server.scheme(), + proxy_data.expected_proxy_config_scheme); + } +} + +// Custom proxy does not apply to localhost, so resolve kMockHost to localhost, +// and use that instead. +class NetworkContextMockHostTest : public NetworkContextTest { + public: + NetworkContextMockHostTest() { + auto host_resolver = std::make_unique<net::MockHostResolver>(); + host_resolver->rules()->AddRule(kMockHost, "127.0.0.1"); + network_service_->SetHostResolver(std::move(host_resolver)); + } + + protected: + GURL GetURLWithMockHost(const net::EmbeddedTestServer& server, + const std::string& relative_url) { + GURL server_base_url = server.base_url(); + GURL base_url = + GURL(base::StrCat({server_base_url.scheme(), "://", kMockHost, ":", + server_base_url.port()})); + EXPECT_TRUE(base_url.is_valid()) << base_url.possibly_invalid_spec(); + return base_url.Resolve(relative_url); + } +}; + +TEST_F(NetworkContextMockHostTest, CustomProxyAddsHeaders) { + net::EmbeddedTestServer test_server; + ASSERT_TRUE(test_server.Start()); + + net::EmbeddedTestServer proxy_test_server; + net::test_server::RegisterDefaultHandlers(&proxy_test_server); + ASSERT_TRUE(proxy_test_server.Start()); + + mojom::CustomProxyConfigClientPtr proxy_config_client; + mojom::NetworkContextParamsPtr context_params = CreateContextParams(); + context_params->custom_proxy_config_client_request = + mojo::MakeRequest(&proxy_config_client); + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(std::move(context_params)); + + auto config = mojom::CustomProxyConfig::New(); + std::string base_url = proxy_test_server.base_url().spec(); + // Remove slash from URL. + base_url.pop_back(); + config->rules.ParseFromString("http=" + base_url); + config->pre_cache_headers.SetHeader("pre_foo", "pre_foo_value"); + config->post_cache_headers.SetHeader("post_foo", "post_foo_value"); + proxy_config_client->OnCustomProxyConfigUpdated(std::move(config)); + scoped_task_environment_.RunUntilIdle(); + + ResourceRequest request; + request.custom_proxy_pre_cache_headers.SetHeader("pre_bar", "pre_bar_value"); + request.custom_proxy_post_cache_headers.SetHeader("post_bar", + "post_bar_value"); + request.url = GetURLWithMockHost( + test_server, "/echoheader?pre_foo&post_foo&pre_bar&post_bar"); + std::unique_ptr<TestURLLoaderClient> client = + FetchRequest(request, network_context.get()); + std::string response; + EXPECT_TRUE( + mojo::BlockingCopyToString(client->response_body_release(), &response)); + + EXPECT_EQ(response, base::JoinString({"post_bar_value", "post_foo_value", + "pre_bar_value", "pre_foo_value"}, + "\n")); + EXPECT_EQ(client->response_head().proxy_server, + net::ProxyServer::FromURI(base_url, net::ProxyServer::SCHEME_HTTP)); +} + +TEST_F(NetworkContextMockHostTest, + CustomProxyRequestHeadersOverrideConfigHeaders) { + net::EmbeddedTestServer test_server; + ASSERT_TRUE(test_server.Start()); + + net::EmbeddedTestServer proxy_test_server; + net::test_server::RegisterDefaultHandlers(&proxy_test_server); + ASSERT_TRUE(proxy_test_server.Start()); + + mojom::CustomProxyConfigClientPtr proxy_config_client; + mojom::NetworkContextParamsPtr context_params = CreateContextParams(); + context_params->custom_proxy_config_client_request = + mojo::MakeRequest(&proxy_config_client); + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(std::move(context_params)); + + auto config = mojom::CustomProxyConfig::New(); + std::string base_url = proxy_test_server.base_url().spec(); + // Remove slash from URL. + base_url.pop_back(); + config->rules.ParseFromString("http=" + base_url); + config->pre_cache_headers.SetHeader("foo", "bad"); + config->post_cache_headers.SetHeader("bar", "bad"); + proxy_config_client->OnCustomProxyConfigUpdated(std::move(config)); + scoped_task_environment_.RunUntilIdle(); + + ResourceRequest request; + request.custom_proxy_pre_cache_headers.SetHeader("foo", "foo_value"); + request.custom_proxy_post_cache_headers.SetHeader("bar", "bar_value"); + request.url = GetURLWithMockHost(test_server, "/echoheader?foo&bar"); + std::unique_ptr<TestURLLoaderClient> client = + FetchRequest(request, network_context.get()); + std::string response; + EXPECT_TRUE( + mojo::BlockingCopyToString(client->response_body_release(), &response)); + + EXPECT_EQ(response, base::JoinString({"bar_value", "foo_value"}, "\n")); + EXPECT_EQ(client->response_head().proxy_server, + net::ProxyServer::FromURI(base_url, net::ProxyServer::SCHEME_HTTP)); +} + +TEST_F(NetworkContextMockHostTest, CustomProxyConfigHeadersAddedBeforeCache) { + net::EmbeddedTestServer test_server; + ASSERT_TRUE(test_server.Start()); + + net::EmbeddedTestServer proxy_test_server; + net::test_server::RegisterDefaultHandlers(&proxy_test_server); + ASSERT_TRUE(proxy_test_server.Start()); + + mojom::CustomProxyConfigClientPtr proxy_config_client; + mojom::NetworkContextParamsPtr context_params = CreateContextParams(); + context_params->custom_proxy_config_client_request = + mojo::MakeRequest(&proxy_config_client); + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(std::move(context_params)); + + auto config = mojom::CustomProxyConfig::New(); + std::string base_url = proxy_test_server.base_url().spec(); + // Remove slash from URL. + base_url.pop_back(); + config->rules.ParseFromString("http=" + base_url); + config->pre_cache_headers.SetHeader("foo", "foo_value"); + config->post_cache_headers.SetHeader("bar", "bar_value"); + proxy_config_client->OnCustomProxyConfigUpdated(config->Clone()); + scoped_task_environment_.RunUntilIdle(); + + ResourceRequest request; + request.url = GetURLWithMockHost(test_server, "/echoheadercache?foo&bar"); + std::unique_ptr<TestURLLoaderClient> client = + FetchRequest(request, network_context.get()); + std::string response; + EXPECT_TRUE( + mojo::BlockingCopyToString(client->response_body_release(), &response)); + + EXPECT_EQ(response, base::JoinString({"bar_value", "foo_value"}, "\n")); + EXPECT_EQ(client->response_head().proxy_server, + net::ProxyServer::FromURI(base_url, net::ProxyServer::SCHEME_HTTP)); + EXPECT_FALSE(client->response_head().was_fetched_via_cache); + + // post_cache_headers should not break caching. + config->post_cache_headers.SetHeader("bar", "new_bar"); + proxy_config_client->OnCustomProxyConfigUpdated(config->Clone()); + scoped_task_environment_.RunUntilIdle(); + + client = FetchRequest(request, network_context.get()); + EXPECT_TRUE( + mojo::BlockingCopyToString(client->response_body_release(), &response)); + + EXPECT_EQ(response, base::JoinString({"bar_value", "foo_value"}, "\n")); + EXPECT_TRUE(client->response_head().was_fetched_via_cache); + + // pre_cache_headers should invalidate cache. + config->pre_cache_headers.SetHeader("foo", "new_foo"); + proxy_config_client->OnCustomProxyConfigUpdated(config->Clone()); + scoped_task_environment_.RunUntilIdle(); + + client = FetchRequest(request, network_context.get()); + EXPECT_TRUE( + mojo::BlockingCopyToString(client->response_body_release(), &response)); + + EXPECT_EQ(response, base::JoinString({"new_bar", "new_foo"}, "\n")); + EXPECT_EQ(client->response_head().proxy_server, + net::ProxyServer::FromURI(base_url, net::ProxyServer::SCHEME_HTTP)); + EXPECT_FALSE(client->response_head().was_fetched_via_cache); +} + +TEST_F(NetworkContextMockHostTest, CustomProxyRequestHeadersAddedBeforeCache) { + net::EmbeddedTestServer test_server; + ASSERT_TRUE(test_server.Start()); + + net::EmbeddedTestServer proxy_test_server; + net::test_server::RegisterDefaultHandlers(&proxy_test_server); + ASSERT_TRUE(proxy_test_server.Start()); + + mojom::CustomProxyConfigClientPtr proxy_config_client; + mojom::NetworkContextParamsPtr context_params = CreateContextParams(); + context_params->custom_proxy_config_client_request = + mojo::MakeRequest(&proxy_config_client); + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(std::move(context_params)); + + auto config = mojom::CustomProxyConfig::New(); + std::string base_url = proxy_test_server.base_url().spec(); + // Remove slash from URL. + base_url.pop_back(); + config->rules.ParseFromString("http=" + base_url); + proxy_config_client->OnCustomProxyConfigUpdated(std::move(config)); + scoped_task_environment_.RunUntilIdle(); + + ResourceRequest request; + request.url = GetURLWithMockHost(test_server, "/echoheadercache?foo&bar"); + request.custom_proxy_pre_cache_headers.SetHeader("foo", "foo_value"); + request.custom_proxy_post_cache_headers.SetHeader("bar", "bar_value"); + std::unique_ptr<TestURLLoaderClient> client = + FetchRequest(request, network_context.get()); + std::string response; + EXPECT_TRUE( + mojo::BlockingCopyToString(client->response_body_release(), &response)); + + EXPECT_EQ(response, base::JoinString({"bar_value", "foo_value"}, "\n")); + EXPECT_EQ(client->response_head().proxy_server, + net::ProxyServer::FromURI(base_url, net::ProxyServer::SCHEME_HTTP)); + EXPECT_FALSE(client->response_head().was_fetched_via_cache); + + // custom_proxy_post_cache_headers should not break caching. + request.custom_proxy_post_cache_headers.SetHeader("bar", "new_bar"); + + client = FetchRequest(request, network_context.get()); + EXPECT_TRUE( + mojo::BlockingCopyToString(client->response_body_release(), &response)); + + EXPECT_EQ(response, base::JoinString({"bar_value", "foo_value"}, "\n")); + EXPECT_TRUE(client->response_head().was_fetched_via_cache); + + // custom_proxy_pre_cache_headers should invalidate cache. + request.custom_proxy_pre_cache_headers.SetHeader("foo", "new_foo"); + + client = FetchRequest(request, network_context.get()); + EXPECT_TRUE( + mojo::BlockingCopyToString(client->response_body_release(), &response)); + + EXPECT_EQ(response, base::JoinString({"new_bar", "new_foo"}, "\n")); + EXPECT_EQ(client->response_head().proxy_server, + net::ProxyServer::FromURI(base_url, net::ProxyServer::SCHEME_HTTP)); + EXPECT_FALSE(client->response_head().was_fetched_via_cache); +} + +TEST_F(NetworkContextMockHostTest, + CustomProxyDoesNotAddHeadersWhenNoProxyUsed) { + net::EmbeddedTestServer test_server; + net::test_server::RegisterDefaultHandlers(&test_server); + ASSERT_TRUE(test_server.Start()); + + mojom::CustomProxyConfigClientPtr proxy_config_client; + mojom::NetworkContextParamsPtr context_params = CreateContextParams(); + context_params->custom_proxy_config_client_request = + mojo::MakeRequest(&proxy_config_client); + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(std::move(context_params)); + + auto config = mojom::CustomProxyConfig::New(); + config->pre_cache_headers.SetHeader("pre_foo", "bad"); + config->post_cache_headers.SetHeader("post_foo", "bad"); + proxy_config_client->OnCustomProxyConfigUpdated(std::move(config)); + scoped_task_environment_.RunUntilIdle(); + + ResourceRequest request; + request.custom_proxy_pre_cache_headers.SetHeader("pre_bar", "bad"); + request.custom_proxy_post_cache_headers.SetHeader("post_bar", "bad"); + request.url = GetURLWithMockHost( + test_server, "/echoheader?pre_foo&post_foo&pre_bar&post_bar"); + std::unique_ptr<TestURLLoaderClient> client = + FetchRequest(request, network_context.get()); + std::string response; + EXPECT_TRUE( + mojo::BlockingCopyToString(client->response_body_release(), &response)); + + EXPECT_EQ(response, base::JoinString({"None", "None", "None", "None"}, "\n")); + EXPECT_TRUE(client->response_head().proxy_server.is_direct()); +} + +TEST_F(NetworkContextMockHostTest, + CustomProxyDoesNotAddHeadersWhenOtherProxyUsed) { + net::EmbeddedTestServer test_server; + ASSERT_TRUE(test_server.Start()); + + net::EmbeddedTestServer proxy_test_server; + net::test_server::RegisterDefaultHandlers(&proxy_test_server); + ASSERT_TRUE(proxy_test_server.Start()); + + mojom::NetworkContextParamsPtr context_params = CreateContextParams(); + // Set up a proxy to be used by the proxy config service. + net::ProxyConfig proxy_config; + std::string base_url = proxy_test_server.base_url().spec(); + // Remove slash from URL. + base_url.pop_back(); + proxy_config.proxy_rules().ParseFromString("http=" + base_url); + context_params->initial_proxy_config = net::ProxyConfigWithAnnotation( + proxy_config, TRAFFIC_ANNOTATION_FOR_TESTS); + + mojom::CustomProxyConfigClientPtr proxy_config_client; + context_params->custom_proxy_config_client_request = + mojo::MakeRequest(&proxy_config_client); + std::unique_ptr<NetworkContext> network_context = + CreateContextWithParams(std::move(context_params)); + + auto config = mojom::CustomProxyConfig::New(); + config->pre_cache_headers.SetHeader("pre_foo", "bad"); + config->post_cache_headers.SetHeader("post_foo", "bad"); + proxy_config_client->OnCustomProxyConfigUpdated(std::move(config)); + scoped_task_environment_.RunUntilIdle(); + + ResourceRequest request; + request.custom_proxy_pre_cache_headers.SetHeader("pre_bar", "bad"); + request.custom_proxy_post_cache_headers.SetHeader("post_bar", "bad"); + request.url = GetURLWithMockHost( + test_server, "/echoheader?pre_foo&post_foo&pre_bar&post_bar"); + std::unique_ptr<TestURLLoaderClient> client = + FetchRequest(request, network_context.get()); + std::string response; + EXPECT_TRUE( + mojo::BlockingCopyToString(client->response_body_release(), &response)); + + EXPECT_EQ(response, base::JoinString({"None", "None", "None", "None"}, "\n")); + EXPECT_EQ(client->response_head().proxy_server, + net::ProxyServer::FromURI(base_url, net::ProxyServer::SCHEME_HTTP)); +} + } // namespace } // namespace network diff --git a/chromium/services/network/network_sandbox_hook_linux.cc b/chromium/services/network/network_sandbox_hook_linux.cc index 983a6dac494..f20450d437e 100644 --- a/chromium/services/network/network_sandbox_hook_linux.cc +++ b/chromium/services/network/network_sandbox_hook_linux.cc @@ -31,7 +31,7 @@ bool NetworkPreSandboxHook(service_manager::SandboxLinux::Options options) { {BrokerFilePermission::ReadWriteCreateRecursive("/")}, service_manager::SandboxLinux::PreSandboxHook(), options); - instance->EngageNamespaceSandbox(false /* from_zygote */); + instance->EngageNamespaceSandboxIfPossible(); return true; } diff --git a/chromium/services/network/network_service.cc b/chromium/services/network/network_service.cc index d2f6d27b5e9..b353f8b21a6 100644 --- a/chromium/services/network/network_service.cc +++ b/chromium/services/network/network_service.cc @@ -21,21 +21,26 @@ #include "components/certificate_transparency/sth_observer.h" #include "components/os_crypt/os_crypt.h" #include "mojo/public/cpp/bindings/strong_binding.h" +#include "mojo/public/cpp/bindings/type_converter.h" #include "net/base/logging_network_change_observer.h" #include "net/base/network_change_notifier.h" #include "net/cert/ct_log_response_parser.h" #include "net/cert/signed_tree_head.h" +#include "net/dns/dns_config_overrides.h" #include "net/dns/host_resolver.h" #include "net/dns/mapped_host_resolver.h" #include "net/http/http_auth_handler_factory.h" #include "net/log/file_net_log_observer.h" #include "net/log/net_log.h" +#include "net/log/net_log_capture_mode.h" #include "net/log/net_log_util.h" +#include "net/ssl/ssl_key_logger_impl.h" #include "net/url_request/url_request_context.h" #include "net/url_request/url_request_context_builder.h" #include "services/network/crl_set_distributor.h" #include "services/network/cross_origin_read_blocking.h" -#include "services/network/mojo_net_log.h" +#include "services/network/net_log_capture_mode_type_converter.h" +#include "services/network/net_log_exporter.h" #include "services/network/network_context.h" #include "services/network/network_usage_accumulator.h" #include "services/network/public/cpp/features.h" @@ -52,14 +57,18 @@ #include "components/os_crypt/key_storage_config_linux.h" #endif +#if defined(OS_ANDROID) +#include "base/android/application_status_listener.h" +#endif + namespace network { namespace { NetworkService* g_network_service = nullptr; -MojoNetLog* GetMojoNetLog() { - static base::NoDestructor<MojoNetLog> instance; +net::NetLog* GetNetLog() { + static base::NoDestructor<net::NetLog> instance; return instance.get(); } @@ -163,8 +172,7 @@ NetworkService::NetworkService( // per-NetworkContext basis. UMA_HISTOGRAM_BOOLEAN( "Net.Certificate.IgnoreCertificateErrorsSPKIListPresent", - command_line->HasSwitch( - network::switches::kIgnoreCertificateErrorsSPKIList)); + command_line->HasSwitch(switches::kIgnoreCertificateErrorsSPKIList)); network_change_manager_ = std::make_unique<NetworkChangeManager>( CreateNetworkChangeNotifierIfNeeded()); @@ -172,18 +180,16 @@ NetworkService::NetworkService( if (net_log) { net_log_ = net_log; } else { - network_service_net_log_ = GetMojoNetLog(); - // Note: The command line switches are only checked when not using the - // embedder's NetLog, as it may already be writing to the destination log - // file. - net_log_ = network_service_net_log_; + net_log_ = GetNetLog(); } + trace_net_log_observer_.WatchForTraceStart(net_log_); + // Add an observer that will emit network change events to the ChromeNetLog. // Assuming NetworkChangeNotifier dispatches in FIFO order, we should be // logging the network change before other IO thread consumers respond to it. - network_change_observer_.reset( - new net::LoggingNetworkChangeObserver(net_log_)); + network_change_observer_ = + std::make_unique<net::LoggingNetworkChangeObserver>(net_log_); network_quality_estimator_manager_ = std::make_unique<NetworkQualityEstimatorManager>(net_log_); @@ -206,8 +212,11 @@ NetworkService::~NetworkService() { // point. DCHECK(network_contexts_.empty()); - if (network_service_net_log_) - network_service_net_log_->ShutDown(); + if (file_net_log_observer_) { + file_net_log_observer_->StopObserving(nullptr /*polled_data*/, + base::OnceClosure()); + } + trace_net_log_observer_.StopWatchForTraceStart(); } void NetworkService::set_os_crypt_is_configured() { @@ -279,13 +288,21 @@ void NetworkService::SetClient(mojom::NetworkServiceClientPtr client) { } void NetworkService::StartNetLog(base::File file, + mojom::NetLogCaptureMode capture_mode, base::Value client_constants) { DCHECK(client_constants.is_dict()); std::unique_ptr<base::DictionaryValue> constants = net::GetNetConstants(); constants->MergeDictionary(&client_constants); - network_service_net_log_->ObserveFileWithConstants(std::move(file), - std::move(*constants)); + file_net_log_observer_ = net::FileNetLogObserver::CreateUnboundedPreExisting( + std::move(file), std::move(constants)); + file_net_log_observer_->StartObserving( + net_log_, mojo::ConvertTo<net::NetLogCaptureMode>(capture_mode)); +} + +void NetworkService::SetSSLKeyLogFile(const base::FilePath& file) { + net::SSLClientSocket::SetSSLKeyLogger( + std::make_unique<net::SSLKeyLoggerImpl>(file)); } void NetworkService::CreateNetworkContext( @@ -303,7 +320,7 @@ void NetworkService::CreateNetworkContext( void NetworkService::ConfigureStubHostResolver( bool stub_resolver_enabled, - base::Optional<std::vector<network::mojom::DnsOverHttpsServerPtr>> + base::Optional<std::vector<mojom::DnsOverHttpsServerPtr>> dns_over_https_servers) { // If the stub resolver is not enabled, |dns_over_https_servers| has no // effect. @@ -315,19 +332,25 @@ void NetworkService::ConfigureStubHostResolver( host_resolver_->SetDnsClientEnabled(stub_resolver_enabled); // Configure DNS over HTTPS. - host_resolver_->ClearDnsOverHttpsServers(); - if (!dns_over_https_servers) + if (!dns_over_https_servers || dns_over_https_servers.value().empty()) { + host_resolver_->SetDnsConfigOverrides(net::DnsConfigOverrides()); return; + } for (auto* network_context : network_contexts_) { if (!network_context->IsPrimaryNetworkContext()) continue; host_resolver_->SetRequestContext(network_context->url_request_context()); + + net::DnsConfigOverrides overrides; + overrides.dns_over_https_servers.emplace(); for (const auto& doh_server : *dns_over_https_servers) { - host_resolver_->AddDnsOverHttpsServer(doh_server->server_template, - doh_server->use_post); + overrides.dns_over_https_servers.value().emplace_back( + doh_server->server_template, doh_server->use_post); } + host_resolver_->SetDnsConfigOverrides(overrides); + return; } @@ -458,6 +481,14 @@ void NetworkService::RemoveCorbExceptionForPlugin(uint32_t process_id) { CrossOriginReadBlocking::RemoveExceptionForPlugin(process_id); } +#if defined(OS_ANDROID) +void NetworkService::OnApplicationStateChange( + base::android::ApplicationState state) { + for (auto* network_context : network_contexts_) + network_context->app_status_listener()->Notify(state); +} +#endif + net::HttpAuthHandlerFactory* NetworkService::GetHttpAuthHandlerFactory() { if (!http_auth_handler_factory_) { http_auth_handler_factory_ = net::HttpAuthHandlerFactory::CreateDefault( @@ -496,9 +527,10 @@ void NetworkService::DestroyNetworkContexts() { // If DNS over HTTPS is enabled, the HostResolver is currently using the // primary NetworkContext to do DNS lookups, so need to tell the HostResolver // to stop using DNS over HTTPS before destroying the primary NetworkContext. - // The ClearDnsOverHttpsServers() call will will fail any in-progress DNS - // lookups, but only if DNS over HTTPS is currently enabled. - host_resolver_->ClearDnsOverHttpsServers(); + // The SetDnsConfigOverrides() call will will fail any in-progress DNS + // lookups, but only if there are current config overrides (which there will + // be if DNS over HTTPS is currently enabled). + host_resolver_->SetDnsConfigOverrides(net::DnsConfigOverrides()); host_resolver_->SetRequestContext(nullptr); DCHECK_LE(owned_network_contexts_.size(), 1u); diff --git a/chromium/services/network/network_service.h b/chromium/services/network/network_service.h index a275de05dc1..ed544fe474b 100644 --- a/chromium/services/network/network_service.h +++ b/chromium/services/network/network_service.h @@ -20,9 +20,11 @@ #include "mojo/public/cpp/bindings/binding.h" #include "net/http/http_auth_preferences.h" #include "net/log/net_log.h" +#include "net/log/trace_net_log_observer.h" #include "services/network/keepalive_statistics_recorder.h" #include "services/network/network_change_manager.h" #include "services/network/network_quality_estimator_manager.h" +#include "services/network/public/mojom/net_log.mojom.h" #include "services/network/public/mojom/network_change_manager.mojom.h" #include "services/network/public/mojom/network_quality_estimator_manager.mojom.h" #include "services/network/public/mojom/network_service.mojom.h" @@ -30,6 +32,7 @@ #include "services/service_manager/public/cpp/service.h" namespace net { +class FileNetLogObserver; class HostResolver; class HttpAuthHandlerFactory; class LoggingNetworkChangeObserver; @@ -47,7 +50,6 @@ namespace network { class CRLSetDistributor; class NetworkContext; class NetworkUsageAccumulator; -class MojoNetLog; class URLRequestContextBuilderMojo; class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkService @@ -122,12 +124,15 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkService // mojom::NetworkService implementation: void SetClient(mojom::NetworkServiceClientPtr client) override; - void StartNetLog(base::File file, base::Value constants) override; + void StartNetLog(base::File file, + mojom::NetLogCaptureMode capture_mode, + base::Value constants) override; + void SetSSLKeyLogFile(const base::FilePath& file) override; void CreateNetworkContext(mojom::NetworkContextRequest request, mojom::NetworkContextParamsPtr params) override; void ConfigureStubHostResolver( bool stub_resolver_enabled, - base::Optional<std::vector<network::mojom::DnsOverHttpsServerPtr>> + base::Optional<std::vector<mojom::DnsOverHttpsServerPtr>> dns_over_https_servers) override; void DisableQuic() override; void SetUpHttpAuth( @@ -151,6 +156,9 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkService #endif void AddCorbExceptionForPlugin(uint32_t process_id) override; void RemoveCorbExceptionForPlugin(uint32_t process_id) override; +#if defined(OS_ANDROID) + void OnApplicationStateChange(base::android::ApplicationState state) override; +#endif // Returns the shared HttpAuthHandlerFactory for the NetworkService, creating // one if needed. @@ -209,10 +217,10 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkService // Starts timer call UpdateLoadInfo() again, if needed. void AckUpdateLoadInfo(); - MojoNetLog* network_service_net_log_ = nullptr; - // TODO(https://crbug.com/767450): Remove this, once Chrome no longer creates - // its own NetLog. - net::NetLog* net_log_; + net::NetLog* net_log_ = nullptr; + + std::unique_ptr<net::FileNetLogObserver> file_net_log_observer_; + net::TraceNetLogObserver trace_net_log_observer_; mojom::NetworkServiceClientPtr client_; diff --git a/chromium/services/network/network_service_network_delegate.cc b/chromium/services/network/network_service_network_delegate.cc index 1579d820de0..b9f1a1b1cfc 100644 --- a/chromium/services/network/network_service_network_delegate.cc +++ b/chromium/services/network/network_service_network_delegate.cc @@ -7,6 +7,7 @@ #include "services/network/cookie_manager.h" #include "services/network/network_context.h" #include "services/network/network_service.h" +#include "services/network/network_service_proxy_delegate.h" #include "services/network/public/cpp/features.h" #include "services/network/url_loader.h" @@ -24,6 +25,28 @@ NetworkServiceNetworkDelegate::NetworkServiceNetworkDelegate( NetworkServiceNetworkDelegate::~NetworkServiceNetworkDelegate() = default; +int NetworkServiceNetworkDelegate::OnBeforeStartTransaction( + net::URLRequest* request, + net::CompletionOnceCallback callback, + net::HttpRequestHeaders* headers) { + if (network_context_->proxy_delegate()) { + network_context_->proxy_delegate()->OnBeforeStartTransaction(request, + headers); + } + return net::OK; +} + +void NetworkServiceNetworkDelegate::OnBeforeSendHeaders( + net::URLRequest* request, + const net::ProxyInfo& proxy_info, + const net::ProxyRetryInfoMap& proxy_retry_info, + net::HttpRequestHeaders* headers) { + if (network_context_->proxy_delegate()) { + network_context_->proxy_delegate()->OnBeforeSendHeaders(request, proxy_info, + headers); + } +} + int NetworkServiceNetworkDelegate::OnHeadersReceived( net::URLRequest* request, net::CompletionOnceCallback callback, @@ -40,35 +63,31 @@ int NetworkServiceNetworkDelegate::OnHeadersReceived( bool NetworkServiceNetworkDelegate::OnCanGetCookies( const net::URLRequest& request, - const net::CookieList& cookie_list) { - bool allow = - network_context_->cookie_manager() - ->cookie_settings() - .IsCookieAccessAllowed(request.url(), request.site_for_cookies()); + const net::CookieList& cookie_list, + bool allowed_from_caller) { URLLoader* url_loader = URLLoader::ForRequest(request); if (url_loader && network_context_->network_service()->client()) { network_context_->network_service()->client()->OnCookiesRead( url_loader->GetProcessId(), url_loader->GetRenderFrameId(), - request.url(), request.site_for_cookies(), cookie_list, !allow); + request.url(), request.site_for_cookies(), cookie_list, + !allowed_from_caller); } - return allow; + return allowed_from_caller; } bool NetworkServiceNetworkDelegate::OnCanSetCookie( const net::URLRequest& request, const net::CanonicalCookie& cookie, - net::CookieOptions* options) { - bool allow = - network_context_->cookie_manager() - ->cookie_settings() - .IsCookieAccessAllowed(request.url(), request.site_for_cookies()); + net::CookieOptions* options, + bool allowed_from_caller) { URLLoader* url_loader = URLLoader::ForRequest(request); if (url_loader && network_context_->network_service()->client()) { network_context_->network_service()->client()->OnCookieChange( url_loader->GetProcessId(), url_loader->GetRenderFrameId(), - request.url(), request.site_for_cookies(), cookie, !allow); + request.url(), request.site_for_cookies(), cookie, + !allowed_from_caller); } - return allow; + return allowed_from_caller; } bool NetworkServiceNetworkDelegate::OnCanAccessFile( @@ -158,14 +177,6 @@ void NetworkServiceNetworkDelegate::FinishedClearSiteData( std::move(callback).Run(net::OK); } -bool NetworkServiceNetworkDelegate::OnCanEnablePrivacyMode( - const GURL& url, - const GURL& site_for_cookies) const { - return !network_context_->cookie_manager() - ->cookie_settings() - .IsCookieAccessAllowed(url, site_for_cookies); -} - void NetworkServiceNetworkDelegate::FinishedCanSendReportingReports( base::OnceCallback<void(std::set<url::Origin>)> result_callback, const std::vector<url::Origin>& origins) { diff --git a/chromium/services/network/network_service_network_delegate.h b/chromium/services/network/network_service_network_delegate.h index 1bfb085724b..b4b1bbc054d 100644 --- a/chromium/services/network/network_service_network_delegate.h +++ b/chromium/services/network/network_service_network_delegate.h @@ -22,6 +22,13 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkServiceNetworkDelegate private: // net::NetworkDelegateImpl implementation. + void OnBeforeSendHeaders(net::URLRequest* request, + const net::ProxyInfo& proxy_info, + const net::ProxyRetryInfoMap& proxy_retry_info, + net::HttpRequestHeaders* headers) override; + int OnBeforeStartTransaction(net::URLRequest* request, + net::CompletionOnceCallback callback, + net::HttpRequestHeaders* headers) override; int OnHeadersReceived( net::URLRequest* request, net::CompletionOnceCallback callback, @@ -29,15 +36,15 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkServiceNetworkDelegate scoped_refptr<net::HttpResponseHeaders>* override_response_headers, GURL* allowed_unsafe_redirect_url) override; bool OnCanGetCookies(const net::URLRequest& request, - const net::CookieList& cookie_list) override; + const net::CookieList& cookie_list, + bool allowed_from_caller) override; bool OnCanSetCookie(const net::URLRequest& request, const net::CanonicalCookie& cookie, - net::CookieOptions* options) override; + net::CookieOptions* options, + bool allowed_from_caller) override; bool OnCanAccessFile(const net::URLRequest& request, const base::FilePath& original_path, const base::FilePath& absolute_path) const override; - bool OnCanEnablePrivacyMode(const GURL& url, - const GURL& site_for_cookies) const override; bool OnCanQueueReportingReport(const url::Origin& origin) const override; void OnCanSendReportingReports(std::set<url::Origin> origins, diff --git a/chromium/services/network/network_service_network_delegate_unittest.cc b/chromium/services/network/network_service_network_delegate_unittest.cc deleted file mode 100644 index 86442d82413..00000000000 --- a/chromium/services/network/network_service_network_delegate_unittest.cc +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "services/network/network_service_network_delegate.h" - -#include "base/test/scoped_task_environment.h" -#include "services/network/cookie_manager.h" -#include "services/network/network_context.h" -#include "services/network/network_service.h" -#include "testing/gtest/include/gtest/gtest.h" - -namespace network { -namespace { - -const GURL kURL("http://foo.com"); -const GURL kOtherURL("http://other.com"); - -class NetworkServiceNetworkDelegateTest : public testing::Test { - public: - NetworkServiceNetworkDelegateTest() - : network_service_(NetworkService::CreateForTesting()) { - mojom::NetworkContextPtr network_context_ptr; - network_context_ = std::make_unique<NetworkContext>( - network_service_.get(), mojo::MakeRequest(&network_context_ptr), - mojom::NetworkContextParams::New()); - } - - void SetContentSetting(const GURL& primary_pattern, - const GURL& secondary_pattern, - ContentSetting setting) { - network_context_->cookie_manager()->SetContentSettings( - {ContentSettingPatternSource( - ContentSettingsPattern::FromURL(primary_pattern), - ContentSettingsPattern::FromURL(secondary_pattern), - base::Value(setting), std::string(), false)}); - } - - void SetBlockThirdParty(bool block) { - network_context_->cookie_manager()->BlockThirdPartyCookies(block); - } - - NetworkContext* network_context() const { return network_context_.get(); } - - private: - base::test::ScopedTaskEnvironment scoped_task_environment_; - std::unique_ptr<NetworkService> network_service_; - std::unique_ptr<NetworkContext> network_context_; -}; - -TEST_F(NetworkServiceNetworkDelegateTest, PrivacyModeDisabledByDefault) { - NetworkServiceNetworkDelegate delegate(network_context()); - - EXPECT_FALSE(delegate.CanEnablePrivacyMode(kURL, kOtherURL)); -} - -TEST_F(NetworkServiceNetworkDelegateTest, PrivacyModeEnabledIfCookiesBlocked) { - NetworkServiceNetworkDelegate delegate(network_context()); - - SetContentSetting(kURL, kOtherURL, CONTENT_SETTING_BLOCK); - EXPECT_TRUE(delegate.CanEnablePrivacyMode(kURL, kOtherURL)); -} - -TEST_F(NetworkServiceNetworkDelegateTest, PrivacyModeDisabledIfCookiesAllowed) { - NetworkServiceNetworkDelegate delegate(network_context()); - - SetContentSetting(kURL, kOtherURL, CONTENT_SETTING_ALLOW); - EXPECT_FALSE(delegate.CanEnablePrivacyMode(kURL, kOtherURL)); -} - -TEST_F(NetworkServiceNetworkDelegateTest, - PrivacyModeDisabledIfCookiesSettingForOtherURL) { - NetworkServiceNetworkDelegate delegate(network_context()); - - // URLs are switched so setting should not apply. - SetContentSetting(kOtherURL, kURL, CONTENT_SETTING_BLOCK); - EXPECT_FALSE(delegate.CanEnablePrivacyMode(kURL, kOtherURL)); -} - -TEST_F(NetworkServiceNetworkDelegateTest, - PrivacyModeEnabledIfThirdPartyCookiesBlocked) { - NetworkServiceNetworkDelegate delegate(network_context()); - - SetBlockThirdParty(true); - EXPECT_TRUE(delegate.CanEnablePrivacyMode(kURL, kOtherURL)); - EXPECT_FALSE(delegate.CanEnablePrivacyMode(kURL, kURL)); - - SetBlockThirdParty(false); - EXPECT_FALSE(delegate.CanEnablePrivacyMode(kURL, kOtherURL)); - EXPECT_FALSE(delegate.CanEnablePrivacyMode(kURL, kURL)); -} - -} // namespace -} // namespace network diff --git a/chromium/services/network/network_service_proxy_delegate.cc b/chromium/services/network/network_service_proxy_delegate.cc new file mode 100644 index 00000000000..105a070a6b2 --- /dev/null +++ b/chromium/services/network/network_service_proxy_delegate.cc @@ -0,0 +1,157 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/network/network_service_proxy_delegate.h" +#include "net/base/url_util.h" +#include "net/http/http_request_headers.h" +#include "net/http/http_util.h" +#include "net/proxy_resolution/proxy_info.h" +#include "services/network/url_loader.h" + +namespace network { +namespace { + +void GetAlternativeProxy(const GURL& url, + const net::ProxyRetryInfoMap& proxy_retry_info, + net::ProxyInfo* result) { + net::ProxyServer resolved_proxy_server = result->proxy_server(); + DCHECK(resolved_proxy_server.is_valid()); + + // Right now, HTTPS proxies are assumed to support quic. If this needs to + // change, add a setting in CustomProxyConfig to control this behavior. + if (!resolved_proxy_server.is_https()) + return; + + net::ProxyInfo alternative_proxy_info; + alternative_proxy_info.UseProxyServer(net::ProxyServer( + net::ProxyServer::SCHEME_QUIC, resolved_proxy_server.host_port_pair())); + alternative_proxy_info.DeprioritizeBadProxies(proxy_retry_info); + + if (alternative_proxy_info.is_empty()) + return; + + result->SetAlternativeProxy(alternative_proxy_info.proxy_server()); +} + +bool ApplyProxyConfigToProxyInfo(const net::ProxyConfig::ProxyRules& rules, + const net::ProxyRetryInfoMap& proxy_retry_info, + const GURL& url, + net::ProxyInfo* proxy_info) { + DCHECK(proxy_info); + if (rules.empty()) + return false; + + rules.Apply(url, proxy_info); + proxy_info->DeprioritizeBadProxies(proxy_retry_info); + return !proxy_info->proxy_server().is_direct(); +} + +// Checks if |target_proxy| is in |proxy_list|. +bool CheckProxyList(const net::ProxyList& proxy_list, + const net::ProxyServer& target_proxy) { + for (const auto& proxy : proxy_list.GetAll()) { + if (proxy.host_port_pair().Equals(target_proxy.host_port_pair())) + return true; + } + return false; +} + +} // namespace + +NetworkServiceProxyDelegate::NetworkServiceProxyDelegate( + mojom::CustomProxyConfigClientRequest config_client_request) + : binding_(this, std::move(config_client_request)) {} + +void NetworkServiceProxyDelegate::OnBeforeStartTransaction( + net::URLRequest* request, + net::HttpRequestHeaders* headers) { + if (!MayProxyURL(request->url())) + return; + + headers->MergeFrom(proxy_config_->pre_cache_headers); + + auto* url_loader = URLLoader::ForRequest(*request); + if (url_loader) { + headers->MergeFrom(url_loader->custom_proxy_pre_cache_headers()); + } +} + +void NetworkServiceProxyDelegate::OnBeforeSendHeaders( + net::URLRequest* request, + const net::ProxyInfo& proxy_info, + net::HttpRequestHeaders* headers) { + auto* url_loader = URLLoader::ForRequest(*request); + if (IsInProxyConfig(proxy_info.proxy_server())) { + headers->MergeFrom(proxy_config_->post_cache_headers); + + if (url_loader) { + headers->MergeFrom(url_loader->custom_proxy_post_cache_headers()); + } + // TODO(crbug.com/721403): This check may be incorrect if a new proxy config + // is set between OnBeforeStartTransaction and here. + } else if (MayProxyURL(request->url())) { + for (const auto& kv : proxy_config_->pre_cache_headers.GetHeaderVector()) { + headers->RemoveHeader(kv.key); + } + + if (url_loader) { + for (const auto& kv : + url_loader->custom_proxy_pre_cache_headers().GetHeaderVector()) { + headers->RemoveHeader(kv.key); + } + } + } +} + +NetworkServiceProxyDelegate::~NetworkServiceProxyDelegate() {} + +void NetworkServiceProxyDelegate::OnResolveProxy( + const GURL& url, + const std::string& method, + const net::ProxyRetryInfoMap& proxy_retry_info, + net::ProxyInfo* result) { + if (!EligibleForProxy(*result, url, method)) + return; + + net::ProxyInfo proxy_info; + if (ApplyProxyConfigToProxyInfo(proxy_config_->rules, proxy_retry_info, url, + &proxy_info)) { + DCHECK(!proxy_info.is_empty() && !proxy_info.is_direct()); + result->OverrideProxyList(proxy_info.proxy_list()); + GetAlternativeProxy(url, proxy_retry_info, result); + } +} + +void NetworkServiceProxyDelegate::OnFallback(const net::ProxyServer& bad_proxy, + int net_error) {} + +void NetworkServiceProxyDelegate::OnCustomProxyConfigUpdated( + mojom::CustomProxyConfigPtr proxy_config) { + DCHECK(proxy_config->rules.empty() || + !proxy_config->rules.proxies_for_http.IsEmpty()); + proxy_config_ = std::move(proxy_config); +} + +bool NetworkServiceProxyDelegate::IsInProxyConfig( + const net::ProxyServer& proxy_server) const { + if (!proxy_server.is_valid() || proxy_server.is_direct()) + return false; + + return CheckProxyList(proxy_config_->rules.proxies_for_http, proxy_server); +} + +bool NetworkServiceProxyDelegate::MayProxyURL(const GURL& url) const { + return url.SchemeIs(url::kHttpScheme) && !proxy_config_->rules.empty() && + !net::IsLocalhost(url); +} + +bool NetworkServiceProxyDelegate::EligibleForProxy( + const net::ProxyInfo& proxy_info, + const GURL& url, + const std::string& method) const { + return proxy_info.is_direct() && proxy_info.proxy_list().size() == 1 && + MayProxyURL(url) && net::HttpUtil::IsMethodIdempotent(method); +} + +} // namespace network diff --git a/chromium/services/network/network_service_proxy_delegate.h b/chromium/services/network/network_service_proxy_delegate.h new file mode 100644 index 00000000000..1d8fc1c7f65 --- /dev/null +++ b/chromium/services/network/network_service_proxy_delegate.h @@ -0,0 +1,70 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_NETWORK_NETWORK_SERVICE_PROXY_DELEGATE_H_ +#define SERVICES_NETWORK_NETWORK_SERVICE_PROXY_DELEGATE_H_ + +#include "base/component_export.h" +#include "base/macros.h" +#include "mojo/public/cpp/bindings/binding.h" +#include "net/base/proxy_delegate.h" +#include "services/network/public/mojom/network_context.mojom.h" + +namespace net { +class HttpRequestHeaders; +class URLRequest; +} // namespace net + +namespace network { + +// NetworkServiceProxyDelegate is used to support the custom proxy +// configuration, which can be set in +// NetworkContextParams.custom_proxy_config_client_request. +class COMPONENT_EXPORT(NETWORK_SERVICE) NetworkServiceProxyDelegate + : public net::ProxyDelegate, + public mojom::CustomProxyConfigClient { + public: + explicit NetworkServiceProxyDelegate( + mojom::CustomProxyConfigClientRequest config_client_request); + ~NetworkServiceProxyDelegate() override; + + // These methods are forwarded from the NetworkDelegate. + void OnBeforeStartTransaction(net::URLRequest* request, + net::HttpRequestHeaders* headers); + void OnBeforeSendHeaders(net::URLRequest* request, + const net::ProxyInfo& proxy_info, + net::HttpRequestHeaders* headers); + + // net::ProxyDelegate implementation: + void OnResolveProxy(const GURL& url, + const std::string& method, + const net::ProxyRetryInfoMap& proxy_retry_info, + net::ProxyInfo* result) override; + void OnFallback(const net::ProxyServer& bad_proxy, int net_error) override; + + private: + // Checks whether |proxy_server| is present in the current proxy config. + bool IsInProxyConfig(const net::ProxyServer& proxy_server) const; + + // Whether the current config may proxy |url|. + bool MayProxyURL(const GURL& url) const; + + // Whether the |url| with current |proxy_info| is eligible to be proxied. + bool EligibleForProxy(const net::ProxyInfo& proxy_info, + const GURL& url, + const std::string& method) const; + + // mojom::CustomProxyConfigClient implementation: + void OnCustomProxyConfigUpdated( + mojom::CustomProxyConfigPtr proxy_config) override; + + mojom::CustomProxyConfigPtr proxy_config_; + mojo::Binding<mojom::CustomProxyConfigClient> binding_; + + DISALLOW_COPY_AND_ASSIGN(NetworkServiceProxyDelegate); +}; + +} // namespace network + +#endif // SERVICES_NETWORK_NETWORK_SERVICE_PROXY_DELEGATE_H_ diff --git a/chromium/services/network/network_service_proxy_delegate_unittest.cc b/chromium/services/network/network_service_proxy_delegate_unittest.cc new file mode 100644 index 00000000000..ddce4b2152a --- /dev/null +++ b/chromium/services/network/network_service_proxy_delegate_unittest.cc @@ -0,0 +1,319 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/network/network_service_proxy_delegate.h" +#include "base/test/scoped_task_environment.h" +#include "net/url_request/url_request_test_util.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace network { +namespace { + +constexpr char kHttpUrl[] = "http://example.com"; +constexpr char kLocalhost[] = "http://localhost"; +constexpr char kHttpsUrl[] = "https://example.com"; +constexpr char kWebsocketUrl[] = "ws://example.com"; + +} // namespace + +class NetworkServiceProxyDelegateTest : public testing::Test { + public: + NetworkServiceProxyDelegateTest() {} + + void SetUp() override { + context_ = std::make_unique<net::TestURLRequestContext>(true); + context_->Init(); + } + + protected: + std::unique_ptr<NetworkServiceProxyDelegate> CreateDelegate( + mojom::CustomProxyConfigPtr config) { + mojom::CustomProxyConfigClientPtr client; + auto delegate = std::make_unique<NetworkServiceProxyDelegate>( + mojo::MakeRequest(&client)); + client->OnCustomProxyConfigUpdated(std::move(config)); + scoped_task_environment_.RunUntilIdle(); + return delegate; + } + + std::unique_ptr<net::URLRequest> CreateRequest(const GURL& url) { + return context_->CreateRequest(url, net::DEFAULT_PRIORITY, nullptr); + } + + private: + std::unique_ptr<net::TestURLRequestContext> context_; + base::test::ScopedTaskEnvironment scoped_task_environment_; +}; + +TEST_F(NetworkServiceProxyDelegateTest, AddsHeadersBeforeCache) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=proxy"); + config->pre_cache_headers.SetHeader("foo", "bar"); + auto delegate = CreateDelegate(std::move(config)); + + net::HttpRequestHeaders headers; + auto request = CreateRequest(GURL(kHttpUrl)); + delegate->OnBeforeStartTransaction(request.get(), &headers); + + std::string value; + EXPECT_TRUE(headers.GetHeader("foo", &value)); + EXPECT_EQ(value, "bar"); +} + +TEST_F(NetworkServiceProxyDelegateTest, + DoesNotAddHeadersBeforeCacheForLocalhost) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=proxy"); + config->pre_cache_headers.SetHeader("foo", "bar"); + auto delegate = CreateDelegate(std::move(config)); + + net::HttpRequestHeaders headers; + auto request = CreateRequest(GURL(kLocalhost)); + delegate->OnBeforeStartTransaction(request.get(), &headers); + + EXPECT_TRUE(headers.IsEmpty()); +} + +TEST_F(NetworkServiceProxyDelegateTest, DoesNotAddHeadersBeforeCacheForHttps) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=proxy"); + config->pre_cache_headers.SetHeader("foo", "bar"); + auto delegate = CreateDelegate(std::move(config)); + + net::HttpRequestHeaders headers; + auto request = CreateRequest(GURL(kHttpsUrl)); + delegate->OnBeforeStartTransaction(request.get(), &headers); + + EXPECT_TRUE(headers.IsEmpty()); +} + +TEST_F(NetworkServiceProxyDelegateTest, AddsHeadersAfterCache) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=proxy"); + config->post_cache_headers.SetHeader("foo", "bar"); + auto delegate = CreateDelegate(std::move(config)); + + net::HttpRequestHeaders headers; + auto request = CreateRequest(GURL(kHttpUrl)); + net::ProxyInfo info; + info.UsePacString("PROXY proxy"); + delegate->OnBeforeSendHeaders(request.get(), info, &headers); + + std::string value; + EXPECT_TRUE(headers.GetHeader("foo", &value)); + EXPECT_EQ(value, "bar"); +} + +TEST_F(NetworkServiceProxyDelegateTest, + DoesNotAddHeadersAfterCacheForProxyNotInConfig) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=proxy"); + config->post_cache_headers.SetHeader("foo", "bar"); + auto delegate = CreateDelegate(std::move(config)); + + net::HttpRequestHeaders headers; + auto request = CreateRequest(GURL(kHttpUrl)); + net::ProxyInfo info; + info.UsePacString("PROXY other"); + delegate->OnBeforeSendHeaders(request.get(), info, &headers); + + EXPECT_TRUE(headers.IsEmpty()); +} + +TEST_F(NetworkServiceProxyDelegateTest, DoesNotAddHeadersAfterCacheForDirect) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=proxy"); + config->post_cache_headers.SetHeader("foo", "bar"); + auto delegate = CreateDelegate(std::move(config)); + + net::HttpRequestHeaders headers; + auto request = CreateRequest(GURL(kHttpUrl)); + net::ProxyInfo info; + info.UseDirect(); + delegate->OnBeforeSendHeaders(request.get(), info, &headers); + + EXPECT_TRUE(headers.IsEmpty()); +} + +TEST_F(NetworkServiceProxyDelegateTest, + RemovesPreCacheHeadersWhenProxyNotInConfig) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=proxy"); + config->pre_cache_headers.SetHeader("foo", "bar"); + auto delegate = CreateDelegate(std::move(config)); + + net::HttpRequestHeaders headers; + headers.SetHeader("foo", "bar"); + auto request = CreateRequest(GURL(kHttpUrl)); + net::ProxyInfo info; + info.UseDirect(); + delegate->OnBeforeSendHeaders(request.get(), info, &headers); + + EXPECT_TRUE(headers.IsEmpty()); +} + +TEST_F(NetworkServiceProxyDelegateTest, + DoesNotRemoveHeaderForHttpsIfAlreadyExists) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=proxy"); + config->pre_cache_headers.SetHeader("foo", "bad"); + auto delegate = CreateDelegate(std::move(config)); + + net::HttpRequestHeaders headers; + headers.SetHeader("foo", "value"); + auto request = CreateRequest(GURL(kHttpsUrl)); + net::ProxyInfo info; + info.UseDirect(); + delegate->OnBeforeSendHeaders(request.get(), info, &headers); + + std::string value; + EXPECT_TRUE(headers.GetHeader("foo", &value)); + EXPECT_EQ(value, "value"); +} + +TEST_F(NetworkServiceProxyDelegateTest, KeepsPreCacheHeadersWhenProxyInConfig) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=proxy"); + config->pre_cache_headers.SetHeader("foo", "bar"); + auto delegate = CreateDelegate(std::move(config)); + + net::HttpRequestHeaders headers; + headers.SetHeader("foo", "bar"); + auto request = CreateRequest(GURL(kHttpUrl)); + net::ProxyInfo info; + info.UsePacString("PROXY proxy"); + delegate->OnBeforeSendHeaders(request.get(), info, &headers); + + std::string value; + EXPECT_TRUE(headers.GetHeader("foo", &value)); + EXPECT_EQ(value, "bar"); +} + +TEST_F(NetworkServiceProxyDelegateTest, OnResolveProxySuccessHttpProxy) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=foo"); + auto delegate = CreateDelegate(std::move(config)); + + net::ProxyInfo result; + result.UseDirect(); + delegate->OnResolveProxy(GURL(kHttpUrl), "GET", net::ProxyRetryInfoMap(), + &result); + + net::ProxyList expected_proxy_list; + expected_proxy_list.AddProxyServer( + net::ProxyServer::FromPacString("PROXY foo")); + EXPECT_TRUE(result.proxy_list().Equals(expected_proxy_list)); + // HTTP proxies are nto used as alternative QUIC proxies. + EXPECT_FALSE(result.alternative_proxy().is_valid()); +} + +TEST_F(NetworkServiceProxyDelegateTest, OnResolveProxySuccessHttpsProxy) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=https://foo"); + auto delegate = CreateDelegate(std::move(config)); + + net::ProxyInfo result; + result.UseDirect(); + delegate->OnResolveProxy(GURL(kHttpUrl), "GET", net::ProxyRetryInfoMap(), + &result); + + net::ProxyList expected_proxy_list; + expected_proxy_list.AddProxyServer( + net::ProxyServer::FromPacString("HTTPS foo")); + EXPECT_TRUE(result.proxy_list().Equals(expected_proxy_list)); + EXPECT_EQ(result.alternative_proxy(), + net::ProxyServer::FromPacString("QUIC foo")); +} + +TEST_F(NetworkServiceProxyDelegateTest, OnResolveProxyLocalhost) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=foo"); + auto delegate = CreateDelegate(std::move(config)); + + net::ProxyInfo result; + result.UseDirect(); + delegate->OnResolveProxy(GURL(kLocalhost), "GET", net::ProxyRetryInfoMap(), + &result); + + EXPECT_TRUE(result.is_direct()); + EXPECT_FALSE(result.alternative_proxy().is_valid()); +} + +TEST_F(NetworkServiceProxyDelegateTest, OnResolveProxyEmptyConfig) { + auto delegate = CreateDelegate(mojom::CustomProxyConfig::New()); + + net::ProxyInfo result; + result.UseDirect(); + delegate->OnResolveProxy(GURL(kHttpUrl), "GET", net::ProxyRetryInfoMap(), + &result); + + EXPECT_TRUE(result.is_direct()); + EXPECT_FALSE(result.alternative_proxy().is_valid()); +} + +TEST_F(NetworkServiceProxyDelegateTest, OnResolveProxyNonIdempotentMethod) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=foo"); + auto delegate = CreateDelegate(std::move(config)); + + net::ProxyInfo result; + result.UseDirect(); + delegate->OnResolveProxy(GURL(kHttpUrl), "POST", net::ProxyRetryInfoMap(), + &result); + + EXPECT_TRUE(result.is_direct()); + EXPECT_FALSE(result.alternative_proxy().is_valid()); +} + +TEST_F(NetworkServiceProxyDelegateTest, OnResolveProxyWebsocketScheme) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=foo"); + auto delegate = CreateDelegate(std::move(config)); + + net::ProxyInfo result; + result.UseDirect(); + delegate->OnResolveProxy(GURL(kWebsocketUrl), "GET", net::ProxyRetryInfoMap(), + &result); + + EXPECT_TRUE(result.is_direct()); + EXPECT_FALSE(result.alternative_proxy().is_valid()); +} + +TEST_F(NetworkServiceProxyDelegateTest, OnResolveProxyDoesNotOverrideExisting) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=foo"); + auto delegate = CreateDelegate(std::move(config)); + + net::ProxyInfo result; + result.UsePacString("PROXY bar"); + delegate->OnResolveProxy(GURL(kHttpUrl), "GET", net::ProxyRetryInfoMap(), + &result); + + net::ProxyList expected_proxy_list; + expected_proxy_list.AddProxyServer( + net::ProxyServer::FromPacString("PROXY bar")); + EXPECT_TRUE(result.proxy_list().Equals(expected_proxy_list)); + EXPECT_FALSE(result.alternative_proxy().is_valid()); +} + +TEST_F(NetworkServiceProxyDelegateTest, OnResolveProxyDeprioritizesBadProxies) { + auto config = mojom::CustomProxyConfig::New(); + config->rules.ParseFromString("http=foo,bar"); + auto delegate = CreateDelegate(std::move(config)); + + net::ProxyInfo result; + result.UseDirect(); + net::ProxyRetryInfoMap retry_map; + net::ProxyRetryInfo& info = retry_map["foo:80"]; + info.try_while_bad = false; + info.bad_until = base::TimeTicks::Now() + base::TimeDelta::FromDays(2); + delegate->OnResolveProxy(GURL(kHttpUrl), "GET", retry_map, &result); + + net::ProxyList expected_proxy_list; + expected_proxy_list.AddProxyServer( + net::ProxyServer::FromPacString("PROXY bar")); + EXPECT_TRUE(result.proxy_list().Equals(expected_proxy_list)); +} + +} // namespace network diff --git a/chromium/services/network/network_service_unittest.cc b/chromium/services/network/network_service_unittest.cc index 525fe283b6e..7e6c4861c20 100644 --- a/chromium/services/network/network_service_unittest.cc +++ b/chromium/services/network/network_service_unittest.cc @@ -36,6 +36,7 @@ #include "services/network/network_context.h" #include "services/network/network_service.h" #include "services/network/public/cpp/features.h" +#include "services/network/public/mojom/net_log.mojom.h" #include "services/network/public/mojom/network_change_manager.mojom.h" #include "services/network/public/mojom/network_service.mojom.h" #include "services/network/test/test_network_service_client.h" @@ -763,7 +764,9 @@ TEST_F(NetworkServiceTestWithService, StartsNetLog) { base::File log_file(log_path, base::File::FLAG_CREATE_ALWAYS | base::File::FLAG_WRITE); - network_service_->StartNetLog(std::move(log_file), std::move(dict)); + network_service_->StartNetLog(std::move(log_file), + network::mojom::NetLogCaptureMode::DEFAULT, + std::move(dict)); CreateNetworkContext(); LoadURL(test_server()->GetURL("/echo")); EXPECT_EQ(net::OK, client()->completion_status().error_code); @@ -1111,106 +1114,12 @@ TEST_F(NetworkServiceTestWithService, CRLSetDoesNotDowngrade) { // The SpawnedTestServer does not work on iOS. #if !defined(OS_IOS) -class AllowBadCertsNetworkServiceClient : public mojom::NetworkServiceClient { - public: - explicit AllowBadCertsNetworkServiceClient( - mojom::NetworkServiceClientRequest network_service_client_request) - : binding_(this, std::move(network_service_client_request)) {} - ~AllowBadCertsNetworkServiceClient() override {} - - // mojom::NetworkServiceClient implementation: - void OnAuthRequired( - uint32_t process_id, - uint32_t routing_id, - uint32_t request_id, - const GURL& url, - const GURL& site_for_cookies, - bool first_auth_attempt, - const scoped_refptr<net::AuthChallengeInfo>& auth_info, - int32_t resource_type, - const base::Optional<ResourceResponseHead>& head, - mojom::AuthChallengeResponderPtr auth_challenge_responder) override { - NOTREACHED(); - } - - void OnCertificateRequested( - uint32_t process_id, - uint32_t routing_id, - uint32_t request_id, - const scoped_refptr<net::SSLCertRequestInfo>& cert_info, - mojom::NetworkServiceClient::OnCertificateRequestedCallback callback) - override { - NOTREACHED(); - } - - void OnSSLCertificateError(uint32_t process_id, - uint32_t routing_id, - uint32_t request_id, - int32_t resource_type, - const GURL& url, - const net::SSLInfo& ssl_info, - bool fatal, - OnSSLCertificateErrorCallback response) override { - std::move(response).Run(net::OK); - } - - void OnFileUploadRequested(uint32_t process_id, - bool async, - const std::vector<base::FilePath>& file_paths, - OnFileUploadRequestedCallback callback) override { - NOTREACHED(); - } - - void OnCookiesRead(int process_id, - int routing_id, - const GURL& url, - const GURL& first_party_url, - const net::CookieList& cookie_list, - bool blocked_by_policy) override { - NOTREACHED(); - } - - void OnCookieChange(int process_id, - int routing_id, - const GURL& url, - const GURL& first_party_url, - const net::CanonicalCookie& cookie, - bool blocked_by_policy) override { - NOTREACHED(); - } - - void OnLoadingStateUpdate(std::vector<mojom::LoadInfoPtr> infos, - OnLoadingStateUpdateCallback callback) override { - NOTREACHED(); - } - - void OnClearSiteData(int process_id, - int routing_id, - const GURL& url, - const std::string& header_value, - int load_flags, - OnClearSiteDataCallback callback) override { - NOTREACHED(); - } - - private: - mojo::Binding<mojom::NetworkServiceClient> binding_; - - DISALLOW_COPY_AND_ASSIGN(AllowBadCertsNetworkServiceClient); -}; - // Test |primary_network_context|, which is required by AIA fetching, among // other things. TEST_F(NetworkServiceTestWithService, AIAFetching) { mojom::NetworkContextParamsPtr context_params = CreateContextParams(); - mojom::NetworkServiceClientPtr network_service_client; context_params->primary_network_context = true; - // Have to allow bad certs when using - // SpawnedTestServer::SSLOptions::CERT_AUTO_AIA_INTERMEDIATE. - AllowBadCertsNetworkServiceClient allow_bad_certs_client( - mojo::MakeRequest(&network_service_client)); - network_service_->CreateNetworkContext(mojo::MakeRequest(&network_context_), std::move(context_params)); @@ -1405,15 +1314,15 @@ TEST_F(NetworkServiceNetworkChangeTest, MAYBE_NetworkChangeManagerRequest) { manager_client.WaitForNotification(mojom::ConnectionType::CONNECTION_3G); } -class NetworkServiceClearSiteDataTest : public NetworkServiceTest { +class NetworkServiceNetworkDelegateTest : public NetworkServiceTest { public: - NetworkServiceClearSiteDataTest() { + NetworkServiceNetworkDelegateTest() { // |NetworkServiceNetworkDelegate::HandleClearSiteDataHeader| requires // Network Service. scoped_feature_list_.InitAndEnableFeature( network::features::kNetworkService); } - ~NetworkServiceClearSiteDataTest() override = default; + ~NetworkServiceNetworkDelegateTest() override = default; void CreateNetworkContext() { mojom::NetworkContextParamsPtr context_params = @@ -1459,7 +1368,7 @@ class NetworkServiceClearSiteDataTest : public NetworkServiceTest { net::test_server::EmbeddedTestServer::TYPE_HTTPS)); https_server_->SetSSLConfig(net::EmbeddedTestServer::CERT_OK); https_server_->RegisterRequestHandler(base::BindRepeating( - &NetworkServiceClearSiteDataTest::HandleHTTPSRequest, + &NetworkServiceNetworkDelegateTest::HandleHTTPSRequest, base::Unretained(this))); ASSERT_TRUE(https_server_->Start()); } @@ -1494,7 +1403,7 @@ class NetworkServiceClearSiteDataTest : public NetworkServiceTest { mojom::URLLoaderPtr loader_; base::test::ScopedFeatureList scoped_feature_list_; - DISALLOW_COPY_AND_ASSIGN(NetworkServiceClearSiteDataTest); + DISALLOW_COPY_AND_ASSIGN(NetworkServiceNetworkDelegateTest); }; class ClearSiteDataNetworkServiceClient : public TestNetworkServiceClient { @@ -1547,7 +1456,7 @@ class ClearSiteDataNetworkServiceClient : public TestNetworkServiceClient { // Check that |NetworkServiceNetworkDelegate| handles Clear-Site-Data header // w/ and w/o |NetworkServiceCient|. -TEST_F(NetworkServiceClearSiteDataTest, ClearSiteDataNetworkServiceCient) { +TEST_F(NetworkServiceNetworkDelegateTest, ClearSiteDataNetworkServiceCient) { const char kClearCookiesHeader[] = "Clear-Site-Data: \"cookies\""; CreateNetworkContext(); @@ -1574,7 +1483,7 @@ TEST_F(NetworkServiceClearSiteDataTest, ClearSiteDataNetworkServiceCient) { } // Check that headers are handled and passed to the client correctly. -TEST_F(NetworkServiceClearSiteDataTest, HandleClearSiteDataHeaders) { +TEST_F(NetworkServiceNetworkDelegateTest, HandleClearSiteDataHeaders) { const char kClearCookiesHeaderValue[] = "\"cookies\""; const char kClearCookiesHeader[] = "Clear-Site-Data: \"cookies\""; CreateNetworkContext(); diff --git a/chromium/services/network/p2p/socket.cc b/chromium/services/network/p2p/socket.cc index 8c65a4d1ab6..8595623fb9f 100644 --- a/chromium/services/network/p2p/socket.cc +++ b/chromium/services/network/p2p/socket.cc @@ -67,12 +67,7 @@ P2PSocket::P2PSocket(Delegate* delegate, : delegate_(delegate), client_(std::move(client)), binding_(this, std::move(socket)), - state_(STATE_UNINITIALIZED), protocol_type_(protocol_type), - send_packets_delayed_total_(0), - send_packets_total_(0), - send_bytes_delayed_max_(0), - send_bytes_delayed_cur_(0), weak_ptr_factory_(this) { binding_.set_connection_error_handler( base::BindOnce(&P2PSocket::OnError, base::Unretained(this))); diff --git a/chromium/services/network/p2p/socket.h b/chromium/services/network/p2p/socket.h index f6c426813bd..099d8b6b94b 100644 --- a/chromium/services/network/p2p/socket.h +++ b/chromium/services/network/p2p/socket.h @@ -150,20 +150,19 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) P2PSocket : public mojom::P2PSocket { Delegate* delegate_; mojom::P2PSocketClientPtr client_; mojo::Binding<mojom::P2PSocket> binding_; - State state_; ProtocolType protocol_type_; private: // Track total delayed packets for calculating how many packets are // delayed by system at the end of call. - uint32_t send_packets_delayed_total_; - uint32_t send_packets_total_; + uint32_t send_packets_delayed_total_ = 0; + uint32_t send_packets_total_ = 0; // Track the maximum of consecutive delayed bytes caused by system's // EWOULDBLOCK. - int32_t send_bytes_delayed_max_; - int32_t send_bytes_delayed_cur_; + int32_t send_bytes_delayed_max_ = 0; + int32_t send_bytes_delayed_cur_ = 0; base::WeakPtrFactory<P2PSocket> weak_ptr_factory_; diff --git a/chromium/services/network/p2p/socket_manager.cc b/chromium/services/network/p2p/socket_manager.cc index 9c16446ebf2..27d85b6a4f5 100644 --- a/chromium/services/network/p2p/socket_manager.cc +++ b/chromium/services/network/p2p/socket_manager.cc @@ -89,12 +89,14 @@ class P2PSocketManager::DnsRequest { if (host_name_.back() != '.') host_name_ += '.'; - net::HostResolver::RequestInfo info(net::HostPortPair(host_name_, 0)); - int result = - resolver_->Resolve(info, net::DEFAULT_PRIORITY, &addresses_, - base::BindOnce(&P2PSocketManager::DnsRequest::OnDone, - base::Unretained(this)), - &request_, net::NetLogWithSource()); + net::HostPortPair host(host_name_, 0); + // TODO(crbug.com/879746): Pass in a + // net::HostResolver::ResolveHostParameters with source set to MDNS if we + // have a ".local." TLD (once MDNS is supported). + request_ = + resolver_->CreateRequest(host, net::NetLogWithSource(), base::nullopt); + int result = request_->Start(base::BindOnce( + &P2PSocketManager::DnsRequest::OnDone, base::Unretained(this))); if (result != net::ERR_IO_PENDING) OnDone(result); } @@ -102,26 +104,24 @@ class P2PSocketManager::DnsRequest { private: void OnDone(int result) { net::IPAddressList list; - if (result != net::OK) { + const base::Optional<net::AddressList>& addresses = + request_->GetAddressResults(); + if (result != net::OK || !addresses) { LOG(ERROR) << "Failed to resolve address for " << host_name_ << ", errorcode: " << result; done_callback_.Run(list); return; } - DCHECK(!addresses_.empty()); - for (net::AddressList::iterator iter = addresses_.begin(); - iter != addresses_.end(); ++iter) { - list.push_back(iter->address()); + for (const auto& endpoint : *addresses) { + list.push_back(endpoint.address()); } done_callback_.Run(list); } - net::AddressList addresses_; - std::string host_name_; net::HostResolver* resolver_; - std::unique_ptr<net::HostResolver::Request> request_; + std::unique_ptr<net::HostResolver::ResolveHostRequest> request_; DoneCallback done_callback_; }; diff --git a/chromium/services/network/p2p/socket_tcp.cc b/chromium/services/network/p2p/socket_tcp.cc index 84172180c0c..ba0919b9b37 100644 --- a/chromium/services/network/p2p/socket_tcp.cc +++ b/chromium/services/network/p2p/socket_tcp.cc @@ -53,7 +53,7 @@ P2PSocketTcp::SendBuffer::SendBuffer( buffer(buffer), traffic_annotation(traffic_annotation) {} P2PSocketTcp::SendBuffer::SendBuffer(const SendBuffer& rhs) = default; -P2PSocketTcp::SendBuffer::~SendBuffer() {} +P2PSocketTcp::SendBuffer::~SendBuffer() = default; P2PSocketTcpBase::P2PSocketTcpBase( Delegate* delegate, @@ -62,27 +62,17 @@ P2PSocketTcpBase::P2PSocketTcpBase( P2PSocketType type, ProxyResolvingClientSocketFactory* proxy_resolving_socket_factory) : P2PSocket(delegate, std::move(client), std::move(socket), P2PSocket::TCP), - write_pending_(false), - connected_(false), type_(type), proxy_resolving_socket_factory_(proxy_resolving_socket_factory) {} -P2PSocketTcpBase::~P2PSocketTcpBase() { - if (state_ == STATE_OPEN) { - DCHECK(socket_.get()); - socket_.reset(); - } -} +P2PSocketTcpBase::~P2PSocketTcpBase() = default; void P2PSocketTcpBase::InitAccepted(const net::IPEndPoint& remote_address, std::unique_ptr<net::StreamSocket> socket) { DCHECK(socket); - DCHECK_EQ(state_, STATE_UNINITIALIZED); - remote_address_.ip_address = remote_address; // TODO(ronghuawu): Add FakeSSLServerSocket. socket_ = std::move(socket); - state_ = STATE_OPEN; DoRead(); } @@ -90,10 +80,9 @@ void P2PSocketTcpBase::Init(const net::IPEndPoint& local_address, uint16_t min_port, uint16_t max_port, const P2PHostAndIPEndPoint& remote_address) { - DCHECK_EQ(state_, STATE_UNINITIALIZED); + DCHECK(!socket_); remote_address_ = remote_address; - state_ = STATE_CONNECTING; net::HostPortPair dest_host_port_pair; // If there is a domain name, let's try it first, it's required by some proxy @@ -129,7 +118,6 @@ void P2PSocketTcpBase::Init(const net::IPEndPoint& local_address, } void P2PSocketTcpBase::OnConnected(int result) { - DCHECK_EQ(state_, STATE_CONNECTING); DCHECK_NE(result, net::ERR_IO_PENDING); if (result != net::OK) { @@ -142,7 +130,6 @@ void P2PSocketTcpBase::OnConnected(int result) { } void P2PSocketTcpBase::OnOpen() { - state_ = STATE_OPEN; // Setting socket send and receive buffer size. if (net::OK != socket_->SetReceiveBufferSize(kTcpRecvSocketBufferSize)) { LOG(WARNING) << "Failed to set socket receive buffer size to " @@ -157,7 +144,6 @@ void P2PSocketTcpBase::OnOpen() { if (!DoSendSocketCreateMsg()) return; - DCHECK_EQ(state_, STATE_OPEN); DoRead(); } @@ -208,7 +194,7 @@ bool P2PSocketTcpBase::DoSendSocketCreateMsg() { void P2PSocketTcpBase::DoRead() { while (true) { if (!read_buffer_.get()) { - read_buffer_ = new net::GrowableIOBuffer(); + read_buffer_ = base::MakeRefCounted<net::GrowableIOBuffer>(); read_buffer_->SetCapacity(kTcpReadBufferSize); } else if (read_buffer_->RemainingCapacity() < kTcpReadBufferSize) { // Make sure that we always have at least kTcpReadBufferSize of @@ -326,8 +312,6 @@ bool P2PSocketTcpBase::HandleWriteResult(int result) { } bool P2PSocketTcpBase::HandleReadResult(int result) { - DCHECK_EQ(state_, STATE_OPEN); - if (result < 0) { LOG(ERROR) << "Error when reading from TCP socket: " << result; OnError(); diff --git a/chromium/services/network/p2p/socket_tcp.h b/chromium/services/network/p2p/socket_tcp.h index bac29f8afb5..168326433df 100644 --- a/chromium/services/network/p2p/socket_tcp.h +++ b/chromium/services/network/p2p/socket_tcp.h @@ -110,10 +110,10 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) P2PSocketTcpBase : public P2PSocket { base::queue<SendBuffer> write_queue_; SendBuffer write_buffer_; - bool write_pending_; + bool write_pending_ = false; - bool connected_; - P2PSocketType type_; + bool connected_ = false; + const P2PSocketType type_; ProxyResolvingClientSocketFactory* proxy_resolving_socket_factory_; DISALLOW_COPY_AND_ASSIGN(P2PSocketTcpBase); diff --git a/chromium/services/network/p2p/socket_tcp_server.cc b/chromium/services/network/p2p/socket_tcp_server.cc index 3f9ec1358e2..f1cd3848ef4 100644 --- a/chromium/services/network/p2p/socket_tcp_server.cc +++ b/chromium/services/network/p2p/socket_tcp_server.cc @@ -32,20 +32,13 @@ P2PSocketTcpServer::P2PSocketTcpServer(Delegate* delegate, accept_callback_(base::BindRepeating(&P2PSocketTcpServer::OnAccepted, base::Unretained(this))) {} -P2PSocketTcpServer::~P2PSocketTcpServer() { - if (state_ == STATE_OPEN) { - DCHECK(socket_.get()); - socket_.reset(); - } -} +P2PSocketTcpServer::~P2PSocketTcpServer() = default; // TODO(guidou): Add support for port range. void P2PSocketTcpServer::Init(const net::IPEndPoint& local_address, uint16_t min_port, uint16_t max_port, const P2PHostAndIPEndPoint& remote_address) { - DCHECK_EQ(state_, STATE_UNINITIALIZED); - int result = socket_->Listen(local_address, kListenBacklog); if (result < 0) { LOG(ERROR) << "Listen() failed: " << result; @@ -62,7 +55,6 @@ void P2PSocketTcpServer::Init(const net::IPEndPoint& local_address, } VLOG(1) << "Local address: " << local_address_.ToString(); - state_ = STATE_OPEN; // NOTE: Remote address can be empty as socket is just listening // in this state. client_->SocketCreated(local_address_, remote_address.ip_address); diff --git a/chromium/services/network/p2p/socket_tcp_unittest.cc b/chromium/services/network/p2p/socket_tcp_unittest.cc index 6bb77432a8c..0935e1d279c 100644 --- a/chromium/services/network/p2p/socket_tcp_unittest.cc +++ b/chromium/services/network/p2p/socket_tcp_unittest.cc @@ -64,7 +64,6 @@ class P2PSocketTcpTestBase : public testing::Test { local_address_ = ParseAddress(kTestLocalIpAddress, kTestPort1); socket_impl_->remote_address_ = dest_; - socket_impl_->state_ = P2PSocket::STATE_CONNECTING; socket_impl_->OnConnected(net::OK); base::RunLoop().RunUntilIdle(); } diff --git a/chromium/services/network/p2p/socket_udp.cc b/chromium/services/network/p2p/socket_udp.cc index eb487babb26..e7609dd65d8 100644 --- a/chromium/services/network/p2p/socket_udp.cc +++ b/chromium/services/network/p2p/socket_udp.cc @@ -66,6 +66,17 @@ const char* GetTransientErrorName(int error) { return ""; } +std::unique_ptr<net::DatagramServerSocket> DefaultSocketFactory( + net::NetLog* net_log) { + net::UDPServerSocket* socket = + new net::UDPServerSocket(net_log, net::NetLogSource()); +#if defined(OS_WIN) + socket->UseNonBlockingIO(); +#endif + + return base::WrapUnique(socket); +} + } // namespace namespace network { @@ -77,7 +88,7 @@ P2PSocketUdp::PendingPacket::PendingPacket( uint64_t id, const net::NetworkTrafficAnnotationTag traffic_annotation) : to(to), - data(new net::IOBuffer(content.size())), + data(base::MakeRefCounted<net::IOBuffer>(content.size())), size(content.size()), packet_options(options), id(id), @@ -87,8 +98,7 @@ P2PSocketUdp::PendingPacket::PendingPacket( P2PSocketUdp::PendingPacket::PendingPacket(const PendingPacket& other) = default; - -P2PSocketUdp::PendingPacket::~PendingPacket() {} +P2PSocketUdp::PendingPacket::~PendingPacket() = default; P2PSocketUdp::P2PSocketUdp(Delegate* Delegate, mojom::P2PSocketClientPtr client, @@ -97,9 +107,6 @@ P2PSocketUdp::P2PSocketUdp(Delegate* Delegate, net::NetLog* net_log, const DatagramServerSocketFactory& socket_factory) : P2PSocket(Delegate, std::move(client), std::move(socket), P2PSocket::UDP), - socket_(socket_factory.Run(net_log)), - send_pending_(false), - last_dscp_(net::DSCP_CS0), throttler_(throttler), net_log_(net_log), socket_factory_(socket_factory) {} @@ -114,23 +121,20 @@ P2PSocketUdp::P2PSocketUdp(Delegate* Delegate, std::move(socket), throttler, net_log, - base::Bind(&P2PSocketUdp::DefaultSocketFactory)) {} + base::BindRepeating(&DefaultSocketFactory)) {} -P2PSocketUdp::~P2PSocketUdp() { - if (state_ == STATE_OPEN) { - DCHECK(socket_.get()); - socket_.reset(); - } -} +P2PSocketUdp::~P2PSocketUdp() = default; void P2PSocketUdp::Init(const net::IPEndPoint& local_address, uint16_t min_port, uint16_t max_port, const P2PHostAndIPEndPoint& remote_address) { - DCHECK_EQ(state_, STATE_UNINITIALIZED); + DCHECK(!socket_); DCHECK((min_port == 0 && max_port == 0) || min_port > 0); DCHECK_LE(min_port, max_port); + socket_ = socket_factory_.Run(net_log_); + int result = -1; if (min_port == 0) { result = socket_->Listen(local_address); @@ -177,12 +181,10 @@ void P2PSocketUdp::Init(const net::IPEndPoint& local_address, } VLOG(1) << "Local address: " << address.ToString(); - state_ = STATE_OPEN; - // NOTE: Remote address will be same as what renderer provided. client_->SocketCreated(address, remote_address.ip_address); - recv_buffer_ = new net::IOBuffer(kUdpReadBufferSize); + recv_buffer_ = base::MakeRefCounted<net::IOBuffer>(kUdpReadBufferSize); DoRead(); } @@ -202,8 +204,6 @@ void P2PSocketUdp::OnRecv(int result) { } bool P2PSocketUdp::HandleReadResult(int result) { - DCHECK_EQ(STATE_OPEN, state_); - if (result > 0) { std::vector<int8_t> data(recv_buffer_->data(), recv_buffer_->data() + result); @@ -342,7 +342,7 @@ void P2PSocketUdp::OnSend(uint64_t packet_id, } // Send next packets if we have them waiting in the buffer. - while (state_ == STATE_OPEN && !send_queue_.empty() && !send_pending_) { + while (!send_queue_.empty() && !send_pending_) { PendingPacket packet = send_queue_.front(); send_queue_.pop_front(); if (!DoSend(packet)) @@ -437,16 +437,4 @@ int P2PSocketUdp::SetSocketDiffServCodePointInternal( #endif } -// static -std::unique_ptr<net::DatagramServerSocket> P2PSocketUdp::DefaultSocketFactory( - net::NetLog* net_log) { - net::UDPServerSocket* socket = - new net::UDPServerSocket(net_log, net::NetLogSource()); -#if defined(OS_WIN) - socket->UseNonBlockingIO(); -#endif - - return base::WrapUnique(socket); -} - } // namespace network diff --git a/chromium/services/network/p2p/socket_udp.h b/chromium/services/network/p2p/socket_udp.h index 5b65792e257..39c3fb0ebde 100644 --- a/chromium/services/network/p2p/socket_udp.h +++ b/chromium/services/network/p2p/socket_udp.h @@ -104,16 +104,14 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) P2PSocketUdp : public P2PSocket { int result); int SetSocketDiffServCodePointInternal(net::DiffServCodePoint dscp); - static std::unique_ptr<net::DatagramServerSocket> DefaultSocketFactory( - net::NetLog* net_log); std::unique_ptr<net::DatagramServerSocket> socket_; scoped_refptr<net::IOBuffer> recv_buffer_; net::IPEndPoint recv_address_; base::circular_deque<PendingPacket> send_queue_; - bool send_pending_; - net::DiffServCodePoint last_dscp_; + bool send_pending_ = false; + net::DiffServCodePoint last_dscp_ = net::DSCP_CS0; // Set of peer for which we have received STUN binding request or // response or relay allocation request or response. diff --git a/chromium/services/network/proxy_lookup_request.cc b/chromium/services/network/proxy_lookup_request.cc index fd0dde81c79..61301c97cc4 100644 --- a/chromium/services/network/proxy_lookup_request.cc +++ b/chromium/services/network/proxy_lookup_request.cc @@ -37,11 +37,6 @@ ProxyLookupRequest::~ProxyLookupRequest() { void ProxyLookupRequest::Start(const GURL& url) { proxy_lookup_client_.set_connection_error_handler( base::BindOnce(&ProxyLookupRequest::DestroySelf, base::Unretained(this))); - net::ProxyDelegate* proxy_delegate = network_context_->url_request_context() - ->http_transaction_factory() - ->GetSession() - ->context() - .proxy_delegate; // TODO(mmenke): The NetLogWithSource() means nothing is logged. Fix that. int result = network_context_->url_request_context() @@ -49,7 +44,7 @@ void ProxyLookupRequest::Start(const GURL& url) { ->ResolveProxy(url, std::string(), &proxy_info_, base::BindOnce(&ProxyLookupRequest::OnResolveComplete, base::Unretained(this)), - &request_, proxy_delegate, net::NetLogWithSource()); + &request_, net::NetLogWithSource()); if (result != net::ERR_IO_PENDING) OnResolveComplete(result); } diff --git a/chromium/services/network/proxy_resolving_client_socket.cc b/chromium/services/network/proxy_resolving_client_socket.cc index 58c3f9a3d0f..509ad1456f3 100644 --- a/chromium/services/network/proxy_resolving_client_socket.cc +++ b/chromium/services/network/proxy_resolving_client_socket.cc @@ -248,16 +248,13 @@ int ProxyResolvingClientSocket::DoLoop(int result) { int ProxyResolvingClientSocket::DoProxyResolve() { next_state_ = STATE_PROXY_RESOLVE_COMPLETE; - // TODO(xunjieli): Having a null ProxyDelegate is bad. Figure out how to - // interact with the new interface for proxy delegate. - // https://crbug.com/793071. // base::Unretained(this) is safe because resolution request is canceled when // |proxy_resolve_request_| is destroyed. return network_session_->proxy_resolution_service()->ResolveProxy( url_, "POST", &proxy_info_, base::BindRepeating(&ProxyResolvingClientSocket::OnIOComplete, base::Unretained(this)), - &proxy_resolve_request_, nullptr /*proxy_delegate*/, net_log_); + &proxy_resolve_request_, net_log_); } int ProxyResolvingClientSocket::DoProxyResolveComplete(int result) { @@ -334,8 +331,7 @@ int ProxyResolvingClientSocket::DoInitConnectionComplete(int result) { return ReconsiderProxyAfterError(result); } - network_session_->proxy_resolution_service()->ReportSuccess(proxy_info_, - nullptr); + network_session_->proxy_resolution_service()->ReportSuccess(proxy_info_); return net::OK; } diff --git a/chromium/services/network/proxy_resolving_client_socket_factory.h b/chromium/services/network/proxy_resolving_client_socket_factory.h index 21bf179e6bd..653ecf9b8c3 100644 --- a/chromium/services/network/proxy_resolving_client_socket_factory.h +++ b/chromium/services/network/proxy_resolving_client_socket_factory.h @@ -27,7 +27,8 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) ProxyResolvingClientSocketFactory { // Constructs a ProxyResolvingClientSocketFactory. This factory shares // network session params with |request_context|, but keeps separate socket // pools by instantiating and owning a separate |network_session_|. - ProxyResolvingClientSocketFactory(net::URLRequestContext* request_context); + explicit ProxyResolvingClientSocketFactory( + net::URLRequestContext* request_context); ~ProxyResolvingClientSocketFactory(); // Creates a socket. |url|'s host and port specify where a connection will be @@ -41,6 +42,10 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) ProxyResolvingClientSocketFactory { std::unique_ptr<ProxyResolvingClientSocket> CreateSocket(const GURL& url, bool use_tls); + const net::HttpNetworkSession* network_session() const { + return network_session_.get(); + } + private: std::unique_ptr<net::HttpNetworkSession> network_session_; net::URLRequestContext* request_context_; diff --git a/chromium/services/network/proxy_resolving_client_socket_unittest.cc b/chromium/services/network/proxy_resolving_client_socket_unittest.cc index 9ca22048902..5f3be868965 100644 --- a/chromium/services/network/proxy_resolving_client_socket_unittest.cc +++ b/chromium/services/network/proxy_resolving_client_socket_unittest.cc @@ -326,9 +326,9 @@ TEST_P(ProxyResolvingClientSocketTest, ReadWriteErrors) { net::TestCompletionCallback read_write_callback; int read_write_result; std::string test_data_string("test data"); - scoped_refptr<net::IOBuffer> read_buffer(new net::IOBufferWithSize(10)); - scoped_refptr<net::IOBuffer> write_buffer( - new net::StringIOBuffer(test_data_string)); + auto read_buffer = base::MakeRefCounted<net::IOBufferWithSize>(10); + auto write_buffer = + base::MakeRefCounted<net::StringIOBuffer>(test_data_string); if (test.is_read_error) { read_write_result = socket->Read(read_buffer.get(), 10, read_write_callback.callback()); diff --git a/chromium/services/network/proxy_resolving_socket_factory_mojo.cc b/chromium/services/network/proxy_resolving_socket_factory_mojo.cc index 114c8c4c46e..53dd76394a5 100644 --- a/chromium/services/network/proxy_resolving_socket_factory_mojo.cc +++ b/chromium/services/network/proxy_resolving_socket_factory_mojo.cc @@ -16,8 +16,9 @@ namespace network { ProxyResolvingSocketFactoryMojo::ProxyResolvingSocketFactoryMojo( net::URLRequestContext* request_context) - : factory_impl_(std::make_unique<ProxyResolvingClientSocketFactory>( - request_context)) {} + : factory_impl_(request_context), + tls_socket_factory_(request_context, + &factory_impl_.network_session()->context()) {} ProxyResolvingSocketFactoryMojo::~ProxyResolvingSocketFactoryMojo() {} @@ -26,10 +27,12 @@ void ProxyResolvingSocketFactoryMojo::CreateProxyResolvingSocket( bool use_tls, const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, mojom::ProxyResolvingSocketRequest request, + mojom::SocketObserverPtr observer, CreateProxyResolvingSocketCallback callback) { auto socket = std::make_unique<ProxyResolvingSocketMojo>( - factory_impl_->CreateSocket(url, use_tls), - static_cast<net::NetworkTrafficAnnotationTag>(traffic_annotation)); + factory_impl_.CreateSocket(url, use_tls), + static_cast<net::NetworkTrafficAnnotationTag>(traffic_annotation), + std::move(observer), &tls_socket_factory_); ProxyResolvingSocketMojo* socket_raw = socket.get(); proxy_resolving_socket_bindings_.AddBinding(std::move(socket), std::move(request)); diff --git a/chromium/services/network/proxy_resolving_socket_factory_mojo.h b/chromium/services/network/proxy_resolving_socket_factory_mojo.h index fdd1f83ecc3..59d33bdb0e4 100644 --- a/chromium/services/network/proxy_resolving_socket_factory_mojo.h +++ b/chromium/services/network/proxy_resolving_socket_factory_mojo.h @@ -12,7 +12,9 @@ #include "base/memory/ref_counted.h" #include "mojo/public/cpp/bindings/strong_binding_set.h" #include "net/traffic_annotation/network_traffic_annotation.h" +#include "services/network/proxy_resolving_client_socket_factory.h" #include "services/network/public/mojom/proxy_resolving_socket.mojom.h" +#include "services/network/tls_socket_factory.h" namespace net { class URLRequestContext; @@ -20,8 +22,6 @@ class URLRequestContext; namespace network { -class ProxyResolvingClientSocketFactory; - class COMPONENT_EXPORT(NETWORK_SERVICE) ProxyResolvingSocketFactoryMojo : public mojom::ProxyResolvingSocketFactory { public: @@ -34,10 +34,12 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) ProxyResolvingSocketFactoryMojo bool use_tls, const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, mojom::ProxyResolvingSocketRequest request, + mojom::SocketObserverPtr observer, CreateProxyResolvingSocketCallback callback) override; private: - std::unique_ptr<ProxyResolvingClientSocketFactory> factory_impl_; + ProxyResolvingClientSocketFactory factory_impl_; + TLSSocketFactory tls_socket_factory_; mojo::StrongBindingSet<mojom::ProxyResolvingSocket> proxy_resolving_socket_bindings_; diff --git a/chromium/services/network/proxy_resolving_socket_mojo.cc b/chromium/services/network/proxy_resolving_socket_mojo.cc index 040dc35e77f..86a574991e3 100644 --- a/chromium/services/network/proxy_resolving_socket_mojo.cc +++ b/chromium/services/network/proxy_resolving_socket_mojo.cc @@ -15,8 +15,13 @@ namespace network { ProxyResolvingSocketMojo::ProxyResolvingSocketMojo( std::unique_ptr<ProxyResolvingClientSocket> socket, - const net::NetworkTrafficAnnotationTag& traffic_annotation) - : socket_(std::move(socket)), traffic_annotation_(traffic_annotation) {} + const net::NetworkTrafficAnnotationTag& traffic_annotation, + mojom::SocketObserverPtr observer, + TLSSocketFactory* tls_socket_factory) + : observer_(std::move(observer)), + tls_socket_factory_(tls_socket_factory), + socket_(std::move(socket)), + traffic_annotation_(traffic_annotation) {} ProxyResolvingSocketMojo::~ProxyResolvingSocketMojo() { if (connect_callback_) { @@ -44,6 +49,36 @@ void ProxyResolvingSocketMojo::Connect( OnConnectCompleted(result); } +void ProxyResolvingSocketMojo::UpgradeToTLS( + const net::HostPortPair& host_port_pair, + const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, + mojom::TLSClientSocketRequest request, + mojom::SocketObserverPtr observer, + mojom::ProxyResolvingSocket::UpgradeToTLSCallback callback) { + // Wait for data pipes to be closed by the client before doing the upgrade. + if (socket_data_pump_) { + pending_upgrade_to_tls_callback_ = base::BindOnce( + &ProxyResolvingSocketMojo::UpgradeToTLS, base::Unretained(this), + host_port_pair, traffic_annotation, std::move(request), + std::move(observer), std::move(callback)); + return; + } + tls_socket_factory_->UpgradeToTLS( + this, host_port_pair, nullptr /* sockt_options */, traffic_annotation, + std::move(request), std::move(observer), + base::BindOnce( + [](mojom::ProxyResolvingSocket::UpgradeToTLSCallback callback, + int32_t net_error, + mojo::ScopedDataPipeConsumerHandle receive_stream, + mojo::ScopedDataPipeProducerHandle send_stream, + const base::Optional<net::SSLInfo>& ssl_info) { + DCHECK(!ssl_info); + std::move(callback).Run(net_error, std::move(receive_stream), + std::move(send_stream)); + }, + std::move(callback))); +} + void ProxyResolvingSocketMojo::OnConnectCompleted(int result) { DCHECK(!connect_callback_.is_null()); DCHECK(!socket_data_pump_); @@ -68,8 +103,7 @@ void ProxyResolvingSocketMojo::OnConnectCompleted(int result) { mojo::DataPipe send_pipe; mojo::DataPipe receive_pipe; socket_data_pump_ = std::make_unique<SocketDataPump>( - socket_.get(), nullptr /*delegate*/, - std::move(receive_pipe.producer_handle), + socket_.get(), this /*delegate*/, std::move(receive_pipe.producer_handle), std::move(send_pipe.consumer_handle), traffic_annotation_); std::move(connect_callback_) .Run(net::OK, local_addr, @@ -80,4 +114,28 @@ void ProxyResolvingSocketMojo::OnConnectCompleted(int result) { std::move(send_pipe.producer_handle)); } +void ProxyResolvingSocketMojo::OnNetworkReadError(int net_error) { + if (observer_) + observer_->OnReadError(net_error); +} + +void ProxyResolvingSocketMojo::OnNetworkWriteError(int net_error) { + if (observer_) + observer_->OnWriteError(net_error); +} + +void ProxyResolvingSocketMojo::OnShutdown() { + socket_data_pump_ = nullptr; + if (!pending_upgrade_to_tls_callback_.is_null()) + std::move(pending_upgrade_to_tls_callback_).Run(); +} + +const net::StreamSocket* ProxyResolvingSocketMojo::BorrowSocket() { + return socket_.get(); +} + +std::unique_ptr<net::StreamSocket> ProxyResolvingSocketMojo::TakeSocket() { + return std::move(socket_); +} + } // namespace network diff --git a/chromium/services/network/proxy_resolving_socket_mojo.h b/chromium/services/network/proxy_resolving_socket_mojo.h index d90a56932d2..866c1de82ec 100644 --- a/chromium/services/network/proxy_resolving_socket_mojo.h +++ b/chromium/services/network/proxy_resolving_socket_mojo.h @@ -13,29 +13,55 @@ #include "net/traffic_annotation/network_traffic_annotation.h" #include "services/network/proxy_resolving_client_socket.h" #include "services/network/public/mojom/proxy_resolving_socket.mojom.h" +#include "services/network/socket_data_pump.h" +#include "services/network/tls_socket_factory.h" namespace network { class SocketDataPump; class COMPONENT_EXPORT(NETWORK_SERVICE) ProxyResolvingSocketMojo - : public mojom::ProxyResolvingSocket { + : public mojom::ProxyResolvingSocket, + public SocketDataPump::Delegate, + public TLSSocketFactory::Delegate { public: ProxyResolvingSocketMojo( std::unique_ptr<ProxyResolvingClientSocket> socket, - const net::NetworkTrafficAnnotationTag& traffic_annotation); + const net::NetworkTrafficAnnotationTag& traffic_annotation, + mojom::SocketObserverPtr observer, + TLSSocketFactory* tls_socket_factory); ~ProxyResolvingSocketMojo() override; void Connect( mojom::ProxyResolvingSocketFactory::CreateProxyResolvingSocketCallback callback); + // mojom::ProxyResolvingSocket implementation. + void UpgradeToTLS( + const net::HostPortPair& host_port_pair, + const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, + mojom::TLSClientSocketRequest request, + mojom::SocketObserverPtr observer, + mojom::ProxyResolvingSocket::UpgradeToTLSCallback callback) override; + private: void OnConnectCompleted(int net_result); + // SocketDataPump::Delegate implementation. + void OnNetworkReadError(int net_error) override; + void OnNetworkWriteError(int net_error) override; + void OnShutdown() override; + + // TLSSocketFactory::Delegate implementation. + const net::StreamSocket* BorrowSocket() override; + std::unique_ptr<net::StreamSocket> TakeSocket() override; + + mojom::SocketObserverPtr observer_; + TLSSocketFactory* tls_socket_factory_; std::unique_ptr<ProxyResolvingClientSocket> socket_; const net::NetworkTrafficAnnotationTag traffic_annotation_; mojom::ProxyResolvingSocketFactory::CreateProxyResolvingSocketCallback connect_callback_; + base::OnceClosure pending_upgrade_to_tls_callback_; std::unique_ptr<SocketDataPump> socket_data_pump_; DISALLOW_COPY_AND_ASSIGN(ProxyResolvingSocketMojo); diff --git a/chromium/services/network/proxy_resolving_socket_mojo_unittest.cc b/chromium/services/network/proxy_resolving_socket_mojo_unittest.cc index d9cfd0e57e7..ac180df6210 100644 --- a/chromium/services/network/proxy_resolving_socket_mojo_unittest.cc +++ b/chromium/services/network/proxy_resolving_socket_mojo_unittest.cc @@ -13,6 +13,7 @@ #include "base/test/bind_test_util.h" #include "base/test/scoped_task_environment.h" #include "mojo/public/cpp/bindings/strong_binding.h" +#include "mojo/public/cpp/system/data_pipe_utils.h" #include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" #include "net/dns/mock_host_resolver.h" @@ -20,6 +21,7 @@ #include "net/socket/socket_test_util.h" #include "net/traffic_annotation/network_traffic_annotation_test_helper.h" #include "net/url_request/url_request_test_util.h" +#include "services/network/mojo_socket_test_util.h" #include "services/network/proxy_resolving_socket_factory_mojo.h" #include "services/network/proxy_resolving_socket_mojo.h" #include "services/network/socket_factory.h" @@ -105,6 +107,7 @@ class ProxyResolvingSocketTestBase { int CreateSocketSync( mojom::ProxyResolvingSocketRequest request, + mojom::SocketObserverPtr socket_observer, net::IPEndPoint* peer_addr_out, const GURL& url, mojo::ScopedDataPipeConsumerHandle* receive_pipe_handle_out, @@ -114,7 +117,7 @@ class ProxyResolvingSocketTestBase { factory_ptr_->CreateProxyResolvingSocket( url, use_tls_, net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS), - std::move(request), + std::move(request), std::move(socket_observer), base::BindLambdaForTesting( [&](int result, const base::Optional<net::IPEndPoint>& local_addr, const base::Optional<net::IPEndPoint>& peer_addr, @@ -203,10 +206,11 @@ TEST_P(ProxyResolvingSocketTest, ConnectToProxy) { mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; mojo::ScopedDataPipeProducerHandle client_socket_send_handle; net::IPEndPoint actual_remote_addr; - EXPECT_EQ(net::OK, - CreateSocketSync(mojo::MakeRequest(&socket), &actual_remote_addr, - kDestination, &client_socket_receive_handle, - &client_socket_send_handle)); + EXPECT_EQ(net::OK, CreateSocketSync(mojo::MakeRequest(&socket), + nullptr /* socket_observer*/, + &actual_remote_addr, kDestination, + &client_socket_receive_handle, + &client_socket_send_handle)); // Consume all read data. base::RunLoop().RunUntilIdle(); if (!is_direct) { @@ -248,9 +252,10 @@ TEST_P(ProxyResolvingSocketTest, ConnectError) { mojom::ProxyResolvingSocketPtr socket; mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; mojo::ScopedDataPipeProducerHandle client_socket_send_handle; - int status = CreateSocketSync(mojo::MakeRequest(&socket), nullptr, - kDestination, &client_socket_receive_handle, - &client_socket_send_handle); + int status = CreateSocketSync( + mojo::MakeRequest(&socket), nullptr /* socket_observer*/, + nullptr /* peer_addr_out */, kDestination, + &client_socket_receive_handle, &client_socket_send_handle); if (test.is_direct) { EXPECT_EQ(net::ERR_FAILED, status); } else { @@ -292,10 +297,11 @@ TEST_P(ProxyResolvingSocketTest, BasicReadWrite) { mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; mojo::ScopedDataPipeProducerHandle client_socket_send_handle; const GURL kDestination("http://example.com"); - EXPECT_EQ(net::OK, - CreateSocketSync(mojo::MakeRequest(&socket), nullptr, kDestination, - &client_socket_receive_handle, - &client_socket_send_handle)); + EXPECT_EQ(net::OK, CreateSocketSync(mojo::MakeRequest(&socket), + nullptr /* socket_observer */, + nullptr /* peer_addr_out */, kDestination, + &client_socket_receive_handle, + &client_socket_send_handle)); // Loop kNumIterations times to test that writes can follow reads, and reads // can follow writes. for (int j = 0; j < kNumIterations; ++j) { @@ -346,7 +352,7 @@ TEST_F(ProxyResolvingSocketMojoTest, SocketDestroyedBeforeConnectCompletes) { factory()->CreateProxyResolvingSocket( kDestination, false, net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS), - mojo::MakeRequest(&socket), + mojo::MakeRequest(&socket), nullptr /* observer */, base::BindLambdaForTesting( [&](int result, const base::Optional<net::IPEndPoint>& local_addr, const base::Optional<net::IPEndPoint>& peer_addr, @@ -359,4 +365,43 @@ TEST_F(ProxyResolvingSocketMojoTest, SocketDestroyedBeforeConnectCompletes) { EXPECT_EQ(net::ERR_ABORTED, net_error); } +TEST_F(ProxyResolvingSocketMojoTest, SocketObserver) { + Init("DIRECT"); + + const char kMsg[] = "message!"; + const char kMsgLen = strlen(kMsg); + + std::vector<net::MockRead> reads = { + net::MockRead(kMsg), + net::MockRead(net::ASYNC, net::ERR_CONNECTION_ABORTED)}; + std::vector<net::MockWrite> writes = { + net::MockWrite(net::ASYNC, net::ERR_TIMED_OUT)}; + + net::StaticSocketDataProvider data_provider(reads, writes); + data_provider.set_connect_data(net::MockConnect(net::ASYNC, net::OK)); + mock_client_socket_factory()->AddSocketDataProvider(&data_provider); + + const GURL kDestination("http://example.com"); + + mojom::ProxyResolvingSocketPtr socket; + mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; + mojo::ScopedDataPipeProducerHandle client_socket_send_handle; + TestSocketObserver test_observer; + + int status = CreateSocketSync( + mojo::MakeRequest(&socket), test_observer.GetObserverPtr(), + nullptr /* peer_addr_out */, kDestination, &client_socket_receive_handle, + &client_socket_send_handle); + EXPECT_EQ(net::OK, status); + + EXPECT_EQ(kMsg, Read(&client_socket_receive_handle, kMsgLen)); + EXPECT_EQ(net::ERR_CONNECTION_ABORTED, test_observer.WaitForReadError()); + + EXPECT_TRUE(mojo::BlockingCopyFromString(kMsg, client_socket_send_handle)); + EXPECT_EQ(net::ERR_TIMED_OUT, test_observer.WaitForWriteError()); + + EXPECT_TRUE(data_provider.AllReadDataConsumed()); + EXPECT_TRUE(data_provider.AllWriteDataConsumed()); +} + } // namespace network diff --git a/chromium/services/network/public/cpp/BUILD.gn b/chromium/services/network/public/cpp/BUILD.gn index 614bef9a71b..5086f0b5798 100644 --- a/chromium/services/network/public/cpp/BUILD.gn +++ b/chromium/services/network/public/cpp/BUILD.gn @@ -2,9 +2,10 @@ # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. +import("//build/config/jumbo.gni") import("//mojo/public/tools/bindings/mojom.gni") -component("cpp") { +jumbo_component("cpp") { output_name = "network_cpp" sources = [ @@ -58,6 +59,8 @@ component("cpp") { ] } + configs += [ "//build/config/compiler:wexit_time_destructors" ] + public_deps = [ ":cpp_base", "//net", @@ -75,12 +78,14 @@ component("cpp") { defines = [ "IS_NETWORK_CPP_IMPL" ] } -component("cpp_base") { +jumbo_component("cpp_base") { output_name = "network_cpp_base" sources = [ "cors/cors_error_status.cc", "cors/cors_error_status.h", + "cors/preflight_timing_info.cc", + "cors/preflight_timing_info.h", "data_element.cc", "data_element.h", "http_raw_request_response_info.cc", @@ -111,6 +116,17 @@ component("cpp_base") { "url_request_mojom_traits.cc", "url_request_mojom_traits.h", ] + jumbo_excluded_sources = [ + # IPC/Params code generators are based on macros and multiple + # inclusion of headers using those macros. That is not + # compatible with jumbo compiling all source, generators and + # users, together, so exclude those files from jumbo compilation. + "network_ipc_param_traits.cc", + "p2p_param_traits.cc", + ] + + configs += [ "//build/config/compiler:wexit_time_destructors" ] + public_deps = [ "//services/network/public/mojom:data_pipe_interfaces", "//services/network/public/mojom:mutable_network_traffic_annotation_interface", diff --git a/chromium/services/network/public/cpp/cors/cors.cc b/chromium/services/network/public/cpp/cors/cors.cc index 53348a6ebd0..9e99fa3fb4d 100644 --- a/chromium/services/network/public/cpp/cors/cors.cc +++ b/chromium/services/network/public/cpp/cors/cors.cc @@ -9,6 +9,7 @@ #include <set> #include <vector> +#include "base/no_destructor.h" #include "base/strings/string_util.h" #include "net/base/mime_util.h" #include "net/http/http_request_headers.h" @@ -43,18 +44,6 @@ std::string ExtractMIMETypeFromMediaType(const std::string& media_type) { return std::string(); } -// url::Origin::Serialize() serializes all Origins with a 'file' scheme to -// 'file://', but it isn't desirable for CORS check. Returns 'null' instead to -// be aligned with HTTP Origin header calculation in Blink SecurityOrigin. -// |allow_file_origin| is used to realize a behavior change that -// the --allow-file-access-from-files command-line flag needs. -// TODO(mkwst): Generalize and move to url/Origin. -std::string Serialize(const url::Origin& origin, bool allow_file_origin) { - if (!allow_file_origin && origin.scheme() == url::kFileScheme) - return "null"; - return origin.Serialize(); -} - // Returns true only if |header_value| satisfies ABNF: 1*DIGIT [ "." 1*DIGIT ] bool IsSimilarToDoubleABNF(const std::string& header_value) { if (header_value.empty()) @@ -101,10 +90,9 @@ bool IsSimilarToIntABNF(const std::string& header_value) { bool IsCORSSafelistedLowerCaseContentType( const std::string& lower_case_media_type) { DCHECK_EQ(lower_case_media_type, base::ToLowerASCII(lower_case_media_type)); - static const std::set<std::string> safe_types = { - "application/x-www-form-urlencoded", "multipart/form-data", "text/plain"}; std::string mime_type = ExtractMIMETypeFromMediaType(lower_case_media_type); - return safe_types.find(mime_type) != safe_types.end(); + return mime_type == "application/x-www-form-urlencoded" || + mime_type == "multipart/form-data" || mime_type == "text/plain"; } } // namespace @@ -135,8 +123,7 @@ base::Optional<CORSErrorStatus> CheckAccess( const base::Optional<std::string>& allow_origin_header, const base::Optional<std::string>& allow_credentials_header, mojom::FetchCredentialsMode credentials_mode, - const url::Origin& origin, - bool allow_file_origin) { + const url::Origin& origin) { // TODO(toyoshim): This response status code check should not be needed. We // have another status code check after a CheckAccess() call if it is needed. if (!response_status_code) @@ -159,7 +146,7 @@ base::Optional<CORSErrorStatus> CheckAccess( return CORSErrorStatus(mojom::CORSError::kWildcardOriginNotAllowed); } else if (!allow_origin_header) { return CORSErrorStatus(mojom::CORSError::kMissingAllowOriginHeader); - } else if (*allow_origin_header != Serialize(origin, allow_file_origin)) { + } else if (*allow_origin_header != origin.Serialize()) { // We do not use url::Origin::IsSameOriginWith() here for two reasons below. // 1. Allow "null" to match here. The latest spec does not have a clear // information about this (https://fetch.spec.whatwg.org/#cors-check), @@ -217,12 +204,10 @@ base::Optional<CORSErrorStatus> CheckPreflightAccess( const base::Optional<std::string>& allow_origin_header, const base::Optional<std::string>& allow_credentials_header, mojom::FetchCredentialsMode actual_credentials_mode, - const url::Origin& origin, - bool allow_file_origin) { + const url::Origin& origin) { const auto error_status = CheckAccess(response_url, response_status_code, allow_origin_header, - allow_credentials_header, actual_credentials_mode, origin, - allow_file_origin); + allow_credentials_header, actual_credentials_mode, origin); if (!error_status) return base::nullopt; @@ -315,12 +300,38 @@ bool IsCORSEnabledRequestMode(mojom::FetchRequestMode mode) { mode == mojom::FetchRequestMode::kCORSWithForcedPreflight; } +mojom::FetchResponseType CalculateResponseTainting( + const GURL& url, + mojom::FetchRequestMode request_mode, + const base::Optional<url::Origin>& origin, + bool cors_flag) { + if (url.SchemeIs(url::kDataScheme)) + return mojom::FetchResponseType::kBasic; + + if (cors_flag) { + DCHECK(IsCORSEnabledRequestMode(request_mode)); + return mojom::FetchResponseType::kCORS; + } + + if (!origin) { + // This is actually not defined in the fetch spec, but in this case CORS + // is disabled so no one should care this value. + return mojom::FetchResponseType::kBasic; + } + + if (request_mode == mojom::FetchRequestMode::kNoCORS && + !origin->IsSameOriginWith(url::Origin::Create(url))) { + return mojom::FetchResponseType::kOpaque; + } + return mojom::FetchResponseType::kBasic; +} + bool IsCORSSafelistedMethod(const std::string& method) { // https://fetch.spec.whatwg.org/#cors-safelisted-method // "A CORS-safelisted method is a method that is `GET`, `HEAD`, or `POST`." - static const std::set<std::string> safe_methods = { - net::HttpRequestHeaders::kGetMethod, kHeadMethod, kPostMethod}; - return safe_methods.find(base::ToUpperASCII(method)) != safe_methods.end(); + std::string method_upper = base::ToUpperASCII(method); + return method_upper == net::HttpRequestHeaders::kGetMethod || + method_upper == kHeadMethod || method_upper == kPostMethod; } bool IsCORSSafelistedContentType(const std::string& media_type) { @@ -328,6 +339,10 @@ bool IsCORSSafelistedContentType(const std::string& media_type) { } bool IsCORSSafelistedHeader(const std::string& name, const std::string& value) { + // If |value|’s length is greater than 128, then return false. + if (value.size() > 128) + return false; + // https://fetch.spec.whatwg.org/#cors-safelisted-request-header // "A CORS-safelisted header is a header whose name is either one of `Accept`, // `Accept-Language`, and `Content-Language`, or whose name is @@ -343,7 +358,7 @@ bool IsCORSSafelistedHeader(const std::string& name, const std::string& value) { // // Treat 'Intervention' as a CORS-safelisted header, since it is added by // Chrome when an intervention is (or may be) applied. - static const std::set<std::string> safe_names = { + static const char* const safe_names[] = { "accept", "accept-language", "content-language", "intervention", "content-type", "save-data", // The Device Memory header field is a number that indicates the client’s @@ -354,7 +369,8 @@ bool IsCORSSafelistedHeader(const std::string& name, const std::string& value) { // for more details. "device-memory", "dpr", "width", "viewport-width"}; const std::string lower_name = base::ToLowerASCII(name); - if (safe_names.find(lower_name) == safe_names.end()) + if (std::find(std::begin(safe_names), std::end(safe_names), lower_name) == + std::end(safe_names)) return false; // Client hints are device specific, and not origin specific. As such all @@ -369,18 +385,102 @@ bool IsCORSSafelistedHeader(const std::string& name, const std::string& value) { if (lower_name == "save-data") return lower_value == "on"; + if (lower_name == "accept") { + return (value.end() == std::find_if(value.begin(), value.end(), [](char c) { + return (c < 0x20 && c != 0x09) || c == 0x22 || c == 0x28 || + c == 0x29 || c == 0x3a || c == 0x3c || c == 0x3e || + c == 0x3f || c == 0x40 || c == 0x5b || c == 0x5c || + c == 0x5d || c == 0x7b || c == 0x7d || c >= 0x7f; + })); + } + + if (lower_name == "accept-language" || lower_name == "content-language") { + return (value.end() == std::find_if(value.begin(), value.end(), [](char c) { + return !isalnum(c) && c != 0x20 && c != 0x2a && c != 0x2c && + c != 0x2d && c != 0x2e && c != 0x3b && c != 0x3d; + })); + } + if (lower_name == "content-type") return IsCORSSafelistedLowerCaseContentType(lower_value); return true; } +bool IsNoCORSSafelistedHeader(const std::string& name, + const std::string& value) { + const std::string lower_name = base::ToLowerASCII(name); + + if (lower_name != "accept" && lower_name != "accept-language" && + lower_name != "content-language" && lower_name != "content-type") { + return false; + } + + return IsCORSSafelistedHeader(lower_name, value); +} + +std::vector<std::string> CORSUnsafeRequestHeaderNames( + const net::HttpRequestHeaders::HeaderVector& headers) { + std::vector<std::string> potentially_unsafe_names; + std::vector<std::string> header_names; + + constexpr size_t kSafeListValueSizeMax = 1024; + size_t safe_list_value_size = 0; + + for (const auto& header : headers) { + if (!IsCORSSafelistedHeader(header.key, header.value)) { + header_names.push_back(base::ToLowerASCII(header.key)); + } else { + potentially_unsafe_names.push_back(base::ToLowerASCII(header.key)); + safe_list_value_size += header.value.size(); + } + } + if (safe_list_value_size > kSafeListValueSizeMax) { + header_names.insert(header_names.end(), potentially_unsafe_names.begin(), + potentially_unsafe_names.end()); + } + return header_names; +} + +std::vector<std::string> CORSUnsafeNotForbiddenRequestHeaderNames( + const net::HttpRequestHeaders::HeaderVector& headers, + bool is_revalidating) { + std::vector<std::string> header_names; + std::vector<std::string> potentially_unsafe_names; + + constexpr size_t kSafeListValueSizeMax = 1024; + size_t safe_list_value_size = 0; + + for (const auto& header : headers) { + if (IsForbiddenHeader(header.key)) + continue; + + const std::string name = base::ToLowerASCII(header.key); + + if (is_revalidating) { + if (name == "if-modified-since" || name == "if-none-match" || + name == "cache-control") { + continue; + } + } + if (!IsCORSSafelistedHeader(name, header.value)) { + header_names.push_back(name); + } else { + potentially_unsafe_names.push_back(name); + safe_list_value_size += header.value.size(); + } + } + if (safe_list_value_size > kSafeListValueSizeMax) { + header_names.insert(header_names.end(), potentially_unsafe_names.begin(), + potentially_unsafe_names.end()); + } + return header_names; +} + bool IsForbiddenMethod(const std::string& method) { - static const std::vector<std::string> forbidden_methods = {"trace", "track", - "connect"}; const std::string lower_method = base::ToLowerASCII(method); - return std::find(forbidden_methods.begin(), forbidden_methods.end(), - lower_method) != forbidden_methods.end(); + return lower_method == "trace" || lower_method == "track" || + lower_method == "connect"; } bool IsForbiddenHeader(const std::string& name) { @@ -393,40 +493,83 @@ bool IsForbiddenHeader(const std::string& name) { // `User-Agent`, `Via` // or starts with `Proxy-` or `Sec-` (including when it is just `Proxy-` or // `Sec-`)." - static const std::set<std::string> forbidden_names = { - "accept-charset", - "accept-encoding", - "access-control-request-headers", - "access-control-request-method", - "connection", - "content-length", - "cookie", - "cookie2", - "date", - "dnt", - "expect", - "host", - "keep-alive", - "origin", - "referer", - "te", - "trailer", - "transfer-encoding", - "upgrade", - "user-agent", - "via"}; + static const base::NoDestructor<std::set<std::string>> forbidden_names( + std::set<std::string>{"accept-charset", + "accept-encoding", + "access-control-request-headers", + "access-control-request-method", + "connection", + "content-length", + "cookie", + "cookie2", + "date", + "dnt", + "expect", + "host", + "keep-alive", + "origin", + "referer", + "te", + "trailer", + "transfer-encoding", + "upgrade", + "user-agent", + "via"}); const std::string lower_name = base::ToLowerASCII(name); if (StartsWith(lower_name, "proxy-", base::CompareCase::SENSITIVE) || StartsWith(lower_name, "sec-", base::CompareCase::SENSITIVE)) { return true; } - return forbidden_names.find(lower_name) != forbidden_names.end(); + return forbidden_names->find(lower_name) != forbidden_names->end(); } bool IsOkStatus(int status) { return status >= 200 && status < 300; } +bool IsCORSSameOriginResponseType(mojom::FetchResponseType type) { + switch (type) { + case mojom::FetchResponseType::kBasic: + case mojom::FetchResponseType::kCORS: + case mojom::FetchResponseType::kDefault: + return true; + case mojom::FetchResponseType::kError: + case mojom::FetchResponseType::kOpaque: + case mojom::FetchResponseType::kOpaqueRedirect: + return false; + } +} + +bool IsCORSCrossOriginResponseType(mojom::FetchResponseType type) { + switch (type) { + case mojom::FetchResponseType::kBasic: + case mojom::FetchResponseType::kCORS: + case mojom::FetchResponseType::kDefault: + case mojom::FetchResponseType::kError: + return false; + case mojom::FetchResponseType::kOpaque: + case mojom::FetchResponseType::kOpaqueRedirect: + return true; + } +} + +bool CalculateCredentialsFlag(mojom::FetchCredentialsMode credentials_mode, + mojom::FetchResponseType response_tainting) { + // Let |credentials flag| be set if one of + // - |request|’s credentials mode is "include" + // - |request|’s credentials mode is "same-origin" and |request|’s + // response tainting is "basic" + // is true, and unset otherwise. + switch (credentials_mode) { + case network::mojom::FetchCredentialsMode::kOmit: + return false; + case network::mojom::FetchCredentialsMode::kSameOrigin: + return response_tainting == network::mojom::FetchResponseType::kBasic; + case network::mojom::FetchCredentialsMode::kInclude: + return true; + } +} + } // namespace cors } // namespace network diff --git a/chromium/services/network/public/cpp/cors/cors.h b/chromium/services/network/public/cpp/cors/cors.h index 47e5d81e168..c5db9b15141 100644 --- a/chromium/services/network/public/cpp/cors/cors.h +++ b/chromium/services/network/public/cpp/cors/cors.h @@ -6,9 +6,11 @@ #define SERVICES_NETWORK_PUBLIC_CPP_CORS_CORS_H_ #include <string> +#include <vector> #include "base/component_export.h" #include "base/optional.h" +#include "net/http/http_request_headers.h" #include "services/network/public/cpp/cors/cors_error_status.h" #include "services/network/public/mojom/cors.mojom-shared.h" #include "services/network/public/mojom/fetch_api.mojom-shared.h" @@ -54,8 +56,7 @@ base::Optional<CORSErrorStatus> CheckAccess( const base::Optional<std::string>& allow_origin_header, const base::Optional<std::string>& allow_credentials_header, mojom::FetchCredentialsMode credentials_mode, - const url::Origin& origin, - bool allow_file_origin = false); + const url::Origin& origin); // Performs a CORS access check on the CORS-preflight response parameters. // According to the note at https://fetch.spec.whatwg.org/#cors-preflight-fetch @@ -68,8 +69,7 @@ base::Optional<CORSErrorStatus> CheckPreflightAccess( const base::Optional<std::string>& allow_origin_header, const base::Optional<std::string>& allow_credentials_header, mojom::FetchCredentialsMode actual_credentials_mode, - const url::Origin& origin, - bool allow_file_origin = false); + const url::Origin& origin); // Given a redirected-to URL, checks if the location is allowed // according to CORS. That is: @@ -99,6 +99,17 @@ base::Optional<CORSErrorStatus> CheckExternalPreflight( COMPONENT_EXPORT(NETWORK_CPP) bool IsCORSEnabledRequestMode(mojom::FetchRequestMode mode); +// Returns the response tainting value +// (https://fetch.spec.whatwg.org/#concept-request-response-tainting) for a +// request and the CORS flag, as specified in +// https://fetch.spec.whatwg.org/#main-fetch. +COMPONENT_EXPORT(NETWORK_CPP) +mojom::FetchResponseType CalculateResponseTainting( + const GURL& url, + mojom::FetchRequestMode request_mode, + const base::Optional<url::Origin>& origin, + bool cors_flag); + // Checks safelisted request parameters. COMPONENT_EXPORT(NETWORK_CPP) bool IsCORSSafelistedMethod(const std::string& method); @@ -106,6 +117,29 @@ COMPONENT_EXPORT(NETWORK_CPP) bool IsCORSSafelistedContentType(const std::string& name); COMPONENT_EXPORT(NETWORK_CPP) bool IsCORSSafelistedHeader(const std::string& name, const std::string& value); +COMPONENT_EXPORT(NETWORK_CPP) +bool IsNoCORSSafelistedHeader(const std::string& name, + const std::string& value); + +// https://fetch.spec.whatwg.org/#cors-unsafe-request-header-names +// |headers| must not contain multiple headers for the same name. +// The returned list is NOT sorted. +// The returned list consists of lower-cased names. +COMPONENT_EXPORT(NETWORK_CPP) +std::vector<std::string> CORSUnsafeRequestHeaderNames( + const net::HttpRequestHeaders::HeaderVector& headers); + +// https://fetch.spec.whatwg.org/#cors-unsafe-request-header-names +// Returns header names which are not CORS-safelisted AND not forbidden. +// |headers| must not contain multiple headers for the same name. +// When |is_revalidating| is true, "if-modified-since", "if-none-match", and +// "cache-control" are also exempted. +// The returned list is NOT sorted. +// The returned list consists of lower-cased names. +COMPONENT_EXPORT(NETWORK_CPP) +std::vector<std::string> CORSUnsafeNotForbiddenRequestHeaderNames( + const net::HttpRequestHeaders::HeaderVector& headers, + bool is_revalidating); // Checks forbidden method in the fetch spec. // See https://fetch.spec.whatwg.org/#forbidden-method. @@ -122,6 +156,22 @@ COMPONENT_EXPORT(NETWORK_CPP) bool IsForbiddenHeader(const std::string& name); // term in naming the predicate. COMPONENT_EXPORT(NETWORK_CPP) bool IsOkStatus(int status); +// Returns true if |type| is a response type which makes a response +// CORS-same-origin. See https://html.spec.whatwg.org/#cors-same-origin. +COMPONENT_EXPORT(NETWORK_CPP) +bool IsCORSSameOriginResponseType(mojom::FetchResponseType type); + +// Returns true if |type| is a response type which makes a response +// CORS-cross-origin. See https://html.spec.whatwg.org/#cors-cross-origin. +COMPONENT_EXPORT(NETWORK_CPP) +bool IsCORSCrossOriginResponseType(mojom::FetchResponseType type); + +// Returns true if the credentials flag should be set for the given arguments +// as in https://fetch.spec.whatwg.org/#http-network-or-cache-fetch. +COMPONENT_EXPORT(NETWORK_CPP) +bool CalculateCredentialsFlag(mojom::FetchCredentialsMode credentials_mode, + mojom::FetchResponseType response_tainting); + } // namespace cors } // namespace network diff --git a/chromium/services/network/public/cpp/cors/cors_unittest.cc b/chromium/services/network/public/cpp/cors/cors_unittest.cc index ee9a42e3e80..a87ccc1cbb3 100644 --- a/chromium/services/network/public/cpp/cors/cors_unittest.cc +++ b/chromium/services/network/public/cpp/cors/cors_unittest.cc @@ -4,27 +4,29 @@ #include "services/network/public/cpp/cors/cors.h" +#include <limits.h> + #include "testing/gtest/include/gtest/gtest.h" #include "url/gurl.h" #include "url/origin.h" namespace network { - +namespace cors { namespace { using CORSTest = testing::Test; TEST_F(CORSTest, CheckAccessDetectsInvalidResponse) { - base::Optional<CORSErrorStatus> error_status = cors::CheckAccess( - GURL(), 0 /* response_status_code */, - base::nullopt /* allow_origin_header */, - base::nullopt /* allow_credentials_header */, - network::mojom::FetchCredentialsMode::kOmit, url::Origin()); + base::Optional<CORSErrorStatus> error_status = + CheckAccess(GURL(), 0 /* response_status_code */, + base::nullopt /* allow_origin_header */, + base::nullopt /* allow_credentials_header */, + network::mojom::FetchCredentialsMode::kOmit, url::Origin()); ASSERT_TRUE(error_status); EXPECT_EQ(mojom::CORSError::kInvalidResponse, error_status->cors_error); } -// Tests if cors::CheckAccess detects kWildcardOriginNotAllowed error correctly. +// Tests if CheckAccess detects kWildcardOriginNotAllowed error correctly. TEST_F(CORSTest, CheckAccessDetectsWildcardOriginNotAllowed) { const GURL response_url("http://example.com/data"); const url::Origin origin = url::Origin::Create(GURL("http://google.com")); @@ -33,24 +35,24 @@ TEST_F(CORSTest, CheckAccessDetectsWildcardOriginNotAllowed) { // Access-Control-Allow-Origin '*' works. base::Optional<CORSErrorStatus> error1 = - cors::CheckAccess(response_url, response_status_code, - allow_all_header /* allow_origin_header */, - base::nullopt /* allow_credentials_header */, - network::mojom::FetchCredentialsMode::kOmit, origin); + CheckAccess(response_url, response_status_code, + allow_all_header /* allow_origin_header */, + base::nullopt /* allow_credentials_header */, + network::mojom::FetchCredentialsMode::kOmit, origin); EXPECT_FALSE(error1); // Access-Control-Allow-Origin '*' should not be allowed if credentials mode // is kInclude. base::Optional<CORSErrorStatus> error2 = - cors::CheckAccess(response_url, response_status_code, - allow_all_header /* allow_origin_header */, - base::nullopt /* allow_credentials_header */, - network::mojom::FetchCredentialsMode::kInclude, origin); + CheckAccess(response_url, response_status_code, + allow_all_header /* allow_origin_header */, + base::nullopt /* allow_credentials_header */, + network::mojom::FetchCredentialsMode::kInclude, origin); ASSERT_TRUE(error2); EXPECT_EQ(mojom::CORSError::kWildcardOriginNotAllowed, error2->cors_error); } -// Tests if cors::CheckAccess detects kMissingAllowOriginHeader error correctly. +// Tests if CheckAccess detects kMissingAllowOriginHeader error correctly. TEST_F(CORSTest, CheckAccessDetectsMissingAllowOriginHeader) { const GURL response_url("http://example.com/data"); const url::Origin origin = url::Origin::Create(GURL("http://google.com")); @@ -58,15 +60,15 @@ TEST_F(CORSTest, CheckAccessDetectsMissingAllowOriginHeader) { // Access-Control-Allow-Origin is missed. base::Optional<CORSErrorStatus> error = - cors::CheckAccess(response_url, response_status_code, - base::nullopt /* allow_origin_header */, - base::nullopt /* allow_credentials_header */, - network::mojom::FetchCredentialsMode::kOmit, origin); + CheckAccess(response_url, response_status_code, + base::nullopt /* allow_origin_header */, + base::nullopt /* allow_credentials_header */, + network::mojom::FetchCredentialsMode::kOmit, origin); ASSERT_TRUE(error); EXPECT_EQ(mojom::CORSError::kMissingAllowOriginHeader, error->cors_error); } -// Tests if cors::CheckAccess detects kMultipleAllowOriginValues error +// Tests if CheckAccess detects kMultipleAllowOriginValues error // correctly. TEST_F(CORSTest, CheckAccessDetectsMultipleAllowOriginValues) { const GURL response_url("http://example.com/data"); @@ -75,55 +77,55 @@ TEST_F(CORSTest, CheckAccessDetectsMultipleAllowOriginValues) { const std::string space_separated_multiple_origins( "http://example.com http://another.example.com"); - base::Optional<CORSErrorStatus> error1 = cors::CheckAccess( - response_url, response_status_code, - space_separated_multiple_origins /* allow_origin_header */, - base::nullopt /* allow_credentials_header */, - network::mojom::FetchCredentialsMode::kOmit, origin); + base::Optional<CORSErrorStatus> error1 = + CheckAccess(response_url, response_status_code, + space_separated_multiple_origins /* allow_origin_header */, + base::nullopt /* allow_credentials_header */, + network::mojom::FetchCredentialsMode::kOmit, origin); ASSERT_TRUE(error1); EXPECT_EQ(mojom::CORSError::kMultipleAllowOriginValues, error1->cors_error); const std::string comma_separated_multiple_origins( "http://example.com,http://another.example.com"); - base::Optional<CORSErrorStatus> error2 = cors::CheckAccess( - response_url, response_status_code, - comma_separated_multiple_origins /* allow_origin_header */, - base::nullopt /* allow_credentials_header */, - network::mojom::FetchCredentialsMode::kOmit, origin); + base::Optional<CORSErrorStatus> error2 = + CheckAccess(response_url, response_status_code, + comma_separated_multiple_origins /* allow_origin_header */, + base::nullopt /* allow_credentials_header */, + network::mojom::FetchCredentialsMode::kOmit, origin); ASSERT_TRUE(error2); EXPECT_EQ(mojom::CORSError::kMultipleAllowOriginValues, error2->cors_error); } -// Tests if cors::CheckAccess detects kInvalidAllowOriginValue error correctly. +// Tests if CheckAccess detects kInvalidAllowOriginValue error correctly. TEST_F(CORSTest, CheckAccessDetectsInvalidAllowOriginValue) { const GURL response_url("http://example.com/data"); const url::Origin origin = url::Origin::Create(GURL("http://google.com")); const int response_status_code = 200; base::Optional<CORSErrorStatus> error = - cors::CheckAccess(response_url, response_status_code, - std::string("invalid.origin") /* allow_origin_header */, - base::nullopt /* allow_credentials_header */, - network::mojom::FetchCredentialsMode::kOmit, origin); + CheckAccess(response_url, response_status_code, + std::string("invalid.origin") /* allow_origin_header */, + base::nullopt /* allow_credentials_header */, + network::mojom::FetchCredentialsMode::kOmit, origin); ASSERT_TRUE(error); EXPECT_EQ(mojom::CORSError::kInvalidAllowOriginValue, error->cors_error); EXPECT_EQ("invalid.origin", error->failed_parameter); } -// Tests if cors::CheckAccess detects kAllowOriginMismatch error correctly. +// Tests if CheckAccess detects kAllowOriginMismatch error correctly. TEST_F(CORSTest, CheckAccessDetectsAllowOriginMismatch) { const GURL response_url("http://example.com/data"); const url::Origin origin = url::Origin::Create(GURL("http://google.com")); const int response_status_code = 200; base::Optional<CORSErrorStatus> error1 = - cors::CheckAccess(response_url, response_status_code, - origin.Serialize() /* allow_origin_header */, - base::nullopt /* allow_credentials_header */, - network::mojom::FetchCredentialsMode::kOmit, origin); + CheckAccess(response_url, response_status_code, + origin.Serialize() /* allow_origin_header */, + base::nullopt /* allow_credentials_header */, + network::mojom::FetchCredentialsMode::kOmit, origin); ASSERT_FALSE(error1); - base::Optional<CORSErrorStatus> error2 = cors::CheckAccess( + base::Optional<CORSErrorStatus> error2 = CheckAccess( response_url, response_status_code, std::string("http://not.google.com") /* allow_origin_header */, base::nullopt /* allow_credentials_header */, @@ -137,37 +139,37 @@ TEST_F(CORSTest, CheckAccessDetectsAllowOriginMismatch) { const url::Origin null_origin; EXPECT_EQ(null_string, null_origin.Serialize()); - base::Optional<CORSErrorStatus> error3 = cors::CheckAccess( + base::Optional<CORSErrorStatus> error3 = CheckAccess( response_url, response_status_code, null_string /* allow_origin_header */, base::nullopt /* allow_credentials_header */, network::mojom::FetchCredentialsMode::kOmit, null_origin); EXPECT_FALSE(error3); } -// Tests if cors::CheckAccess detects kInvalidAllowCredentials error correctly. +// Tests if CheckAccess detects kInvalidAllowCredentials error correctly. TEST_F(CORSTest, CheckAccessDetectsInvalidAllowCredential) { const GURL response_url("http://example.com/data"); const url::Origin origin = url::Origin::Create(GURL("http://google.com")); const int response_status_code = 200; base::Optional<CORSErrorStatus> error1 = - cors::CheckAccess(response_url, response_status_code, - origin.Serialize() /* allow_origin_header */, - std::string("true") /* allow_credentials_header */, - network::mojom::FetchCredentialsMode::kInclude, origin); + CheckAccess(response_url, response_status_code, + origin.Serialize() /* allow_origin_header */, + std::string("true") /* allow_credentials_header */, + network::mojom::FetchCredentialsMode::kInclude, origin); ASSERT_FALSE(error1); base::Optional<CORSErrorStatus> error2 = - cors::CheckAccess(response_url, response_status_code, - origin.Serialize() /* allow_origin_header */, - std::string("fuga") /* allow_credentials_header */, - network::mojom::FetchCredentialsMode::kInclude, origin); + CheckAccess(response_url, response_status_code, + origin.Serialize() /* allow_origin_header */, + std::string("fuga") /* allow_credentials_header */, + network::mojom::FetchCredentialsMode::kInclude, origin); ASSERT_TRUE(error2); EXPECT_EQ(mojom::CORSError::kInvalidAllowCredentials, error2->cors_error); EXPECT_EQ("fuga", error2->failed_parameter); } -// Tests if cors::CheckRedirectLocation detects kCORSDisabledScheme and +// Tests if CheckRedirectLocation detects kCORSDisabledScheme and // kRedirectContainsCredentials errors correctly. TEST_F(CORSTest, CheckRedirectLocation) { struct TestCase { @@ -277,119 +279,460 @@ TEST_F(CORSTest, CheckRedirectLocation) { << ", tainted: " << test.tainted); EXPECT_EQ(test.expectation, - cors::CheckRedirectLocation(test.url, test.request_mode, origin, - test.cors_flag, test.tainted)); + CheckRedirectLocation(test.url, test.request_mode, origin, + test.cors_flag, test.tainted)); } } TEST_F(CORSTest, CheckPreflightDetectsErrors) { - EXPECT_FALSE(cors::CheckPreflight(200)); - EXPECT_FALSE(cors::CheckPreflight(299)); + EXPECT_FALSE(CheckPreflight(200)); + EXPECT_FALSE(CheckPreflight(299)); - base::Optional<mojom::CORSError> error1 = cors::CheckPreflight(300); + base::Optional<mojom::CORSError> error1 = CheckPreflight(300); ASSERT_TRUE(error1); EXPECT_EQ(mojom::CORSError::kPreflightInvalidStatus, *error1); - EXPECT_FALSE(cors::CheckExternalPreflight(std::string("true"))); + EXPECT_FALSE(CheckExternalPreflight(std::string("true"))); base::Optional<CORSErrorStatus> error2 = - cors::CheckExternalPreflight(base::nullopt); + CheckExternalPreflight(base::nullopt); ASSERT_TRUE(error2); EXPECT_EQ(mojom::CORSError::kPreflightMissingAllowExternal, error2->cors_error); EXPECT_EQ("", error2->failed_parameter); base::Optional<CORSErrorStatus> error3 = - cors::CheckExternalPreflight(std::string("TRUE")); + CheckExternalPreflight(std::string("TRUE")); ASSERT_TRUE(error3); EXPECT_EQ(mojom::CORSError::kPreflightInvalidAllowExternal, error3->cors_error); EXPECT_EQ("TRUE", error3->failed_parameter); } -TEST_F(CORSTest, CheckCORSSafelist) { +TEST_F(CORSTest, CalculateResponseTainting) { + using mojom::FetchResponseType; + using mojom::FetchRequestMode; + + const GURL same_origin_url("https://example.com/"); + const GURL cross_origin_url("https://example2.com/"); + const url::Origin origin = url::Origin::Create(GURL("https://example.com")); + const base::Optional<url::Origin> no_origin; + + // CORS flag is false, same-origin request + EXPECT_EQ(FetchResponseType::kBasic, + CalculateResponseTainting( + same_origin_url, FetchRequestMode::kSameOrigin, origin, false)); + EXPECT_EQ(FetchResponseType::kBasic, + CalculateResponseTainting( + same_origin_url, FetchRequestMode::kNoCORS, origin, false)); + EXPECT_EQ(FetchResponseType::kBasic, + CalculateResponseTainting(same_origin_url, FetchRequestMode::kCORS, + origin, false)); + EXPECT_EQ(FetchResponseType::kBasic, + CalculateResponseTainting( + same_origin_url, FetchRequestMode::kCORSWithForcedPreflight, + origin, false)); + EXPECT_EQ(FetchResponseType::kBasic, + CalculateResponseTainting( + same_origin_url, FetchRequestMode::kNavigate, origin, false)); + + // CORS flag is false, cross-origin request + EXPECT_EQ(FetchResponseType::kOpaque, + CalculateResponseTainting( + cross_origin_url, FetchRequestMode::kNoCORS, origin, false)); + EXPECT_EQ(FetchResponseType::kBasic, + CalculateResponseTainting( + cross_origin_url, FetchRequestMode::kNavigate, origin, false)); + + // CORS flag is true, same-origin request + EXPECT_EQ(FetchResponseType::kCORS, + CalculateResponseTainting(same_origin_url, FetchRequestMode::kCORS, + origin, true)); + EXPECT_EQ(FetchResponseType::kCORS, + CalculateResponseTainting( + same_origin_url, FetchRequestMode::kCORSWithForcedPreflight, + origin, true)); + + // CORS flag is true, cross-origin request + EXPECT_EQ(FetchResponseType::kCORS, + CalculateResponseTainting(cross_origin_url, FetchRequestMode::kCORS, + origin, true)); + EXPECT_EQ(FetchResponseType::kCORS, + CalculateResponseTainting( + cross_origin_url, FetchRequestMode::kCORSWithForcedPreflight, + origin, true)); + + // Origin is not provided. + EXPECT_EQ(FetchResponseType::kBasic, + CalculateResponseTainting( + same_origin_url, FetchRequestMode::kNoCORS, no_origin, false)); + EXPECT_EQ( + FetchResponseType::kBasic, + CalculateResponseTainting(same_origin_url, FetchRequestMode::kNavigate, + no_origin, false)); + EXPECT_EQ(FetchResponseType::kBasic, + CalculateResponseTainting( + cross_origin_url, FetchRequestMode::kNoCORS, no_origin, false)); + EXPECT_EQ( + FetchResponseType::kBasic, + CalculateResponseTainting(cross_origin_url, FetchRequestMode::kNavigate, + no_origin, false)); +} + +TEST_F(CORSTest, SafelistedMethod) { // Method check should be case-insensitive. - EXPECT_TRUE(cors::IsCORSSafelistedMethod("get")); - EXPECT_TRUE(cors::IsCORSSafelistedMethod("Get")); - EXPECT_TRUE(cors::IsCORSSafelistedMethod("GET")); - EXPECT_TRUE(cors::IsCORSSafelistedMethod("HEAD")); - EXPECT_TRUE(cors::IsCORSSafelistedMethod("POST")); - EXPECT_FALSE(cors::IsCORSSafelistedMethod("OPTIONS")); - - // Content-Type check should be case-insensitive, and should ignore spaces and - // parameters such as charset after a semicolon. + EXPECT_TRUE(IsCORSSafelistedMethod("get")); + EXPECT_TRUE(IsCORSSafelistedMethod("Get")); + EXPECT_TRUE(IsCORSSafelistedMethod("GET")); + EXPECT_TRUE(IsCORSSafelistedMethod("HEAD")); + EXPECT_TRUE(IsCORSSafelistedMethod("POST")); + EXPECT_FALSE(IsCORSSafelistedMethod("OPTIONS")); +} + +TEST_F(CORSTest, SafelistedHeader) { + // See SafelistedAccept/AcceptLanguage/ContentLanguage/ContentType also. + + EXPECT_TRUE(IsCORSSafelistedHeader("accept", "foo")); + EXPECT_FALSE(IsCORSSafelistedHeader("foo", "bar")); + EXPECT_FALSE(IsCORSSafelistedHeader("user-agent", "foo")); +} + +TEST_F(CORSTest, SafelistedAccept) { + EXPECT_TRUE(IsCORSSafelistedHeader("accept", "text/html")); + EXPECT_TRUE(IsCORSSafelistedHeader("AccepT", "text/html")); + + constexpr char kAllowed[] = + "\t !#$%&'*+,-./0123456789;=" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~"; + for (int i = CHAR_MIN; i <= CHAR_MAX; ++i) { + SCOPED_TRACE(testing::Message() << "c = static_cast<char>(" << i << ")"); + char c = static_cast<char>(i); + // 1 for the trailing null character. + auto* end = kAllowed + base::size(kAllowed) - 1; + EXPECT_EQ(std::find(kAllowed, end, c) != end, + IsCORSSafelistedHeader("accept", std::string(1, c))); + EXPECT_EQ(std::find(kAllowed, end, c) != end, + IsCORSSafelistedHeader("AccepT", std::string(1, c))); + } + + EXPECT_TRUE(IsCORSSafelistedHeader("accept", std::string(128, 'a'))); + EXPECT_FALSE(IsCORSSafelistedHeader("accept", std::string(129, 'a'))); + EXPECT_TRUE(IsCORSSafelistedHeader("AccepT", std::string(128, 'a'))); + EXPECT_FALSE(IsCORSSafelistedHeader("AccepT", std::string(129, 'a'))); +} + +TEST_F(CORSTest, SafelistedAcceptLanguage) { + EXPECT_TRUE(IsCORSSafelistedHeader("accept-language", "en,ja")); + EXPECT_TRUE(IsCORSSafelistedHeader("aCcEPT-lAngUAge", "en,ja")); + + constexpr char kAllowed[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz *,-.;="; + for (int i = CHAR_MIN; i <= CHAR_MAX; ++i) { + SCOPED_TRACE(testing::Message() << "c = static_cast<char>(" << i << ")"); + char c = static_cast<char>(i); + // 1 for the trailing null character. + auto* end = kAllowed + base::size(kAllowed) - 1; + EXPECT_EQ(std::find(kAllowed, end, c) != end, + IsCORSSafelistedHeader("aCcEPT-lAngUAge", std::string(1, c))); + } + EXPECT_TRUE(IsCORSSafelistedHeader("accept-language", std::string(128, 'a'))); + EXPECT_FALSE( + IsCORSSafelistedHeader("accept-language", std::string(129, 'a'))); + EXPECT_TRUE(IsCORSSafelistedHeader("aCcEPT-lAngUAge", std::string(128, 'a'))); + EXPECT_FALSE( + IsCORSSafelistedHeader("aCcEPT-lAngUAge", std::string(129, 'a'))); +} + +TEST_F(CORSTest, SafelistedContentLanguage) { + EXPECT_TRUE(IsCORSSafelistedHeader("content-language", "en,ja")); + EXPECT_TRUE(IsCORSSafelistedHeader("cONTent-LANguaGe", "en,ja")); + + constexpr char kAllowed[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz *,-.;="; + for (int i = CHAR_MIN; i <= CHAR_MAX; ++i) { + SCOPED_TRACE(testing::Message() << "c = static_cast<char>(" << i << ")"); + char c = static_cast<char>(i); + // 1 for the trailing null character. + auto* end = kAllowed + base::size(kAllowed) - 1; + EXPECT_EQ(std::find(kAllowed, end, c) != end, + IsCORSSafelistedHeader("content-language", std::string(1, c))); + EXPECT_EQ(std::find(kAllowed, end, c) != end, + IsCORSSafelistedHeader("cONTent-LANguaGe", std::string(1, c))); + } + EXPECT_TRUE( + IsCORSSafelistedHeader("content-language", std::string(128, 'a'))); + EXPECT_FALSE( + IsCORSSafelistedHeader("content-language", std::string(129, 'a'))); EXPECT_TRUE( - cors::IsCORSSafelistedContentType("application/x-www-form-urlencoded")); - EXPECT_TRUE(cors::IsCORSSafelistedContentType("multipart/form-data")); - EXPECT_TRUE(cors::IsCORSSafelistedContentType("text/plain")); - EXPECT_TRUE(cors::IsCORSSafelistedContentType("TEXT/PLAIN")); - EXPECT_TRUE(cors::IsCORSSafelistedContentType("text/plain;charset=utf-8")); - EXPECT_TRUE(cors::IsCORSSafelistedContentType(" text/plain ;charset=utf-8")); - EXPECT_FALSE(cors::IsCORSSafelistedContentType("text/html")); - - // Header check should be case-insensitive. Value must be considered only for - // Content-Type. - EXPECT_TRUE(cors::IsCORSSafelistedHeader("accept", "text/html")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("Accept-Language", "en")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("Content-Language", "ja")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("SAVE-DATA", "on")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("Intervention", "")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("Cache-Control", "")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("Content-Type", "text/plain")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("Content-Type", "image/png")); + IsCORSSafelistedHeader("cONTent-LANguaGe", std::string(128, 'a'))); + EXPECT_FALSE( + IsCORSSafelistedHeader("cONTent-LANguaGe", std::string(129, 'a'))); +} + +TEST_F(CORSTest, SafelistedContentType) { + EXPECT_TRUE(IsCORSSafelistedHeader("content-type", "text/plain")); + EXPECT_TRUE(IsCORSSafelistedHeader("CoNtEnt-TyPE", "text/plain")); + EXPECT_TRUE( + IsCORSSafelistedHeader("content-type", "text/plain; charset=utf-8")); + EXPECT_TRUE( + IsCORSSafelistedHeader("content-type", " text/plain ; charset=UTF-8")); + EXPECT_TRUE( + IsCORSSafelistedHeader("content-type", "text/plain; param=BOGUS")); + EXPECT_TRUE(IsCORSSafelistedHeader("content-type", + "application/x-www-form-urlencoded")); + EXPECT_TRUE(IsCORSSafelistedHeader("content-type", "multipart/form-data")); + + EXPECT_TRUE(IsCORSSafelistedHeader("content-type", "Text/plain")); + EXPECT_TRUE(IsCORSSafelistedHeader("content-type", "tEXT/PLAIN")); + EXPECT_FALSE(IsCORSSafelistedHeader("content-type", "text/html")); + EXPECT_FALSE(IsCORSSafelistedHeader("CoNtEnt-TyPE", "text/html")); + + EXPECT_FALSE(IsCORSSafelistedHeader("content-type", "image/png")); + EXPECT_FALSE(IsCORSSafelistedHeader("CoNtEnt-TyPE", "image/png")); + EXPECT_TRUE(IsCORSSafelistedHeader( + "content-type", "text/plain; charset=" + std::string(108, 'a'))); + EXPECT_TRUE(IsCORSSafelistedHeader( + "cONTent-tYPE", "text/plain; charset=" + std::string(108, 'a'))); + EXPECT_FALSE(IsCORSSafelistedHeader( + "content-type", "text/plain; charset=" + std::string(109, 'a'))); + EXPECT_FALSE(IsCORSSafelistedHeader( + "cONTent-tYPE", "text/plain; charset=" + std::string(109, 'a'))); } TEST_F(CORSTest, CheckCORSClientHintsSafelist) { - EXPECT_FALSE(cors::IsCORSSafelistedHeader("device-memory", "")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("device-memory", "abc")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("device-memory", "1.25")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("DEVICE-memory", "1.25")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("device-memory", "1.25-2.5")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("device-memory", "-1.25")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("device-memory", "1e2")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("device-memory", "inf")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("device-memory", "-2.3")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("device-memory", "NaN")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("DEVICE-memory", "1.25.3")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("DEVICE-memory", "1.")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("DEVICE-memory", ".1")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("DEVICE-memory", ".")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("DEVICE-memory", "1")); - - EXPECT_FALSE(cors::IsCORSSafelistedHeader("dpr", "")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("dpr", "abc")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("dpr", "1.25")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("Dpr", "1.25")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("dpr", "1.25-2.5")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("dpr", "-1.25")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("dpr", "1e2")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("dpr", "inf")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("dpr", "-2.3")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("dpr", "NaN")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("dpr", "1.25.3")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("dpr", "1.")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("dpr", ".1")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("dpr", ".")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("dpr", "1")); - - EXPECT_FALSE(cors::IsCORSSafelistedHeader("width", "")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("width", "abc")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("width", "125")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("width", "1")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("WIDTH", "125")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("width", "125.2")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("width", "-125")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("width", "2147483648")); - - EXPECT_FALSE(cors::IsCORSSafelistedHeader("viewport-width", "")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("viewport-width", "abc")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("viewport-width", "125")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("viewport-width", "1")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("viewport-Width", "125")); - EXPECT_FALSE(cors::IsCORSSafelistedHeader("viewport-width", "125.2")); - EXPECT_TRUE(cors::IsCORSSafelistedHeader("viewport-width", "2147483648")); + EXPECT_FALSE(IsCORSSafelistedHeader("device-memory", "")); + EXPECT_FALSE(IsCORSSafelistedHeader("device-memory", "abc")); + EXPECT_TRUE(IsCORSSafelistedHeader("device-memory", "1.25")); + EXPECT_TRUE(IsCORSSafelistedHeader("DEVICE-memory", "1.25")); + EXPECT_FALSE(IsCORSSafelistedHeader("device-memory", "1.25-2.5")); + EXPECT_FALSE(IsCORSSafelistedHeader("device-memory", "-1.25")); + EXPECT_FALSE(IsCORSSafelistedHeader("device-memory", "1e2")); + EXPECT_FALSE(IsCORSSafelistedHeader("device-memory", "inf")); + EXPECT_FALSE(IsCORSSafelistedHeader("device-memory", "-2.3")); + EXPECT_FALSE(IsCORSSafelistedHeader("device-memory", "NaN")); + EXPECT_FALSE(IsCORSSafelistedHeader("DEVICE-memory", "1.25.3")); + EXPECT_FALSE(IsCORSSafelistedHeader("DEVICE-memory", "1.")); + EXPECT_FALSE(IsCORSSafelistedHeader("DEVICE-memory", ".1")); + EXPECT_FALSE(IsCORSSafelistedHeader("DEVICE-memory", ".")); + EXPECT_TRUE(IsCORSSafelistedHeader("DEVICE-memory", "1")); + + EXPECT_FALSE(IsCORSSafelistedHeader("dpr", "")); + EXPECT_FALSE(IsCORSSafelistedHeader("dpr", "abc")); + EXPECT_TRUE(IsCORSSafelistedHeader("dpr", "1.25")); + EXPECT_TRUE(IsCORSSafelistedHeader("Dpr", "1.25")); + EXPECT_FALSE(IsCORSSafelistedHeader("dpr", "1.25-2.5")); + EXPECT_FALSE(IsCORSSafelistedHeader("dpr", "-1.25")); + EXPECT_FALSE(IsCORSSafelistedHeader("dpr", "1e2")); + EXPECT_FALSE(IsCORSSafelistedHeader("dpr", "inf")); + EXPECT_FALSE(IsCORSSafelistedHeader("dpr", "-2.3")); + EXPECT_FALSE(IsCORSSafelistedHeader("dpr", "NaN")); + EXPECT_FALSE(IsCORSSafelistedHeader("dpr", "1.25.3")); + EXPECT_FALSE(IsCORSSafelistedHeader("dpr", "1.")); + EXPECT_FALSE(IsCORSSafelistedHeader("dpr", ".1")); + EXPECT_FALSE(IsCORSSafelistedHeader("dpr", ".")); + EXPECT_TRUE(IsCORSSafelistedHeader("dpr", "1")); + + EXPECT_FALSE(IsCORSSafelistedHeader("width", "")); + EXPECT_FALSE(IsCORSSafelistedHeader("width", "abc")); + EXPECT_TRUE(IsCORSSafelistedHeader("width", "125")); + EXPECT_TRUE(IsCORSSafelistedHeader("width", "1")); + EXPECT_TRUE(IsCORSSafelistedHeader("WIDTH", "125")); + EXPECT_FALSE(IsCORSSafelistedHeader("width", "125.2")); + EXPECT_FALSE(IsCORSSafelistedHeader("width", "-125")); + EXPECT_TRUE(IsCORSSafelistedHeader("width", "2147483648")); + + EXPECT_FALSE(IsCORSSafelistedHeader("viewport-width", "")); + EXPECT_FALSE(IsCORSSafelistedHeader("viewport-width", "abc")); + EXPECT_TRUE(IsCORSSafelistedHeader("viewport-width", "125")); + EXPECT_TRUE(IsCORSSafelistedHeader("viewport-width", "1")); + EXPECT_TRUE(IsCORSSafelistedHeader("viewport-Width", "125")); + EXPECT_FALSE(IsCORSSafelistedHeader("viewport-width", "125.2")); + EXPECT_TRUE(IsCORSSafelistedHeader("viewport-width", "2147483648")); } -} // namespace +TEST_F(CORSTest, CORSUnsafeRequestHeaderNames) { + // Needed because initializer list is not allowed for a macro argument. + using List = std::vector<std::string>; + + // Empty => Empty + EXPECT_EQ(CORSUnsafeRequestHeaderNames({}), List({})); + + // Some headers are safelisted. + EXPECT_EQ(CORSUnsafeRequestHeaderNames({{"content-type", "text/plain"}, + {"dpr", "12345"}, + {"aCCept", "en,ja"}, + {"accept-charset", "utf-8"}, + {"uSer-Agent", "foo"}, + {"hogE", "fuga"}}), + List({"accept-charset", "user-agent", "hoge"})); + + // All headers are not safelisted. + EXPECT_EQ( + CORSUnsafeRequestHeaderNames({{"content-type", "text/html"}, + {"dpr", "123-45"}, + {"aCCept", "en,ja"}, + {"accept-charset", "utf-8"}, + {"uSer-Agent", "foo"}, + {"hogE", "fuga"}}), + List({"content-type", "dpr", "accept-charset", "user-agent", "hoge"})); + + // |safelistValueSize| is 1024. + EXPECT_EQ( + CORSUnsafeRequestHeaderNames( + {{"content-type", "text/plain; charset=" + std::string(108, '1')}, + {"accept", std::string(128, '1')}, + {"accept-language", std::string(128, '1')}, + {"content-language", std::string(128, '1')}, + {"dpr", std::string(128, '1')}, + {"device-memory", std::string(128, '1')}, + {"save-data", "on"}, + {"viewport-width", std::string(128, '1')}, + {"width", std::string(126, '1')}, + {"hogE", "fuga"}}), + List({"hoge"})); + + // |safelistValueSize| is 1025. + EXPECT_EQ( + CORSUnsafeRequestHeaderNames( + {{"content-type", "text/plain; charset=" + std::string(108, '1')}, + {"accept", std::string(128, '1')}, + {"accept-language", std::string(128, '1')}, + {"content-language", std::string(128, '1')}, + {"dpr", std::string(128, '1')}, + {"device-memory", std::string(128, '1')}, + {"save-data", "on"}, + {"viewport-width", std::string(128, '1')}, + {"width", std::string(127, '1')}, + {"hogE", "fuga"}}), + List({"hoge", "content-type", "accept", "accept-language", + "content-language", "dpr", "device-memory", "save-data", + "viewport-width", "width"})); + + // |safelistValueSize| is 897 because "content-type" is not safelisted. + EXPECT_EQ( + CORSUnsafeRequestHeaderNames( + {{"content-type", "text/plain; charset=" + std::string(128, '1')}, + {"accept", std::string(128, '1')}, + {"accept-language", std::string(128, '1')}, + {"content-language", std::string(128, '1')}, + {"dpr", std::string(128, '1')}, + {"device-memory", std::string(128, '1')}, + {"save-data", "on"}, + {"viewport-width", std::string(128, '1')}, + {"width", std::string(127, '1')}, + {"hogE", "fuga"}}), + List({"content-type", "hoge"})); +} +TEST_F(CORSTest, CORSUnsafeNotForbiddenRequestHeaderNames) { + // Needed because initializer list is not allowed for a macro argument. + using List = std::vector<std::string>; + + // Empty => Empty + EXPECT_EQ( + CORSUnsafeNotForbiddenRequestHeaderNames({}, false /* is_revalidating */), + List({})); + + // "user-agent" is NOT forbidden per spec, but forbidden in Chromium. + EXPECT_EQ( + CORSUnsafeNotForbiddenRequestHeaderNames({{"content-type", "text/plain"}, + {"dpr", "12345"}, + {"aCCept", "en,ja"}, + {"accept-charset", "utf-8"}, + {"uSer-Agent", "foo"}, + {"hogE", "fuga"}}, + false /* is_revalidating */), + List({"hoge"})); + + EXPECT_EQ( + CORSUnsafeNotForbiddenRequestHeaderNames({{"content-type", "text/html"}, + {"dpr", "123-45"}, + {"aCCept", "en,ja"}, + {"accept-charset", "utf-8"}, + {"hogE", "fuga"}}, + false /* is_revalidating */), + List({"content-type", "dpr", "hoge"})); + + // |safelistValueSize| is 1024. + EXPECT_EQ( + CORSUnsafeNotForbiddenRequestHeaderNames( + {{"content-type", "text/plain; charset=" + std::string(108, '1')}, + {"accept", std::string(128, '1')}, + {"accept-language", std::string(128, '1')}, + {"content-language", std::string(128, '1')}, + {"dpr", std::string(128, '1')}, + {"device-memory", std::string(128, '1')}, + {"save-data", "on"}, + {"viewport-width", std::string(128, '1')}, + {"width", std::string(126, '1')}, + {"accept-charset", "utf-8"}, + {"hogE", "fuga"}}, + false /* is_revalidating */), + List({"hoge"})); + + // |safelistValueSize| is 1025. + EXPECT_EQ( + CORSUnsafeNotForbiddenRequestHeaderNames( + {{"content-type", "text/plain; charset=" + std::string(108, '1')}, + {"accept", std::string(128, '1')}, + {"accept-language", std::string(128, '1')}, + {"content-language", std::string(128, '1')}, + {"dpr", std::string(128, '1')}, + {"device-memory", std::string(128, '1')}, + {"save-data", "on"}, + {"viewport-width", std::string(128, '1')}, + {"width", std::string(127, '1')}, + {"accept-charset", "utf-8"}, + {"hogE", "fuga"}}, + false /* is_revalidating */), + List({"hoge", "content-type", "accept", "accept-language", + "content-language", "dpr", "device-memory", "save-data", + "viewport-width", "width"})); + + // |safelistValueSize| is 897 because "content-type" is not safelisted. + EXPECT_EQ( + CORSUnsafeNotForbiddenRequestHeaderNames( + {{"content-type", "text/plain; charset=" + std::string(128, '1')}, + {"accept", std::string(128, '1')}, + {"accept-language", std::string(128, '1')}, + {"content-language", std::string(128, '1')}, + {"dpr", std::string(128, '1')}, + {"device-memory", std::string(128, '1')}, + {"save-data", "on"}, + {"viewport-width", std::string(128, '1')}, + {"width", std::string(127, '1')}, + {"accept-charset", "utf-8"}, + {"hogE", "fuga"}}, + false /* is_revalidating */), + List({"content-type", "hoge"})); +} + +TEST_F(CORSTest, CORSUnsafeNotForbiddenRequestHeaderNamesWithRevalidating) { + // Needed because initializer list is not allowed for a macro argument. + using List = std::vector<std::string>; + + // Empty => Empty + EXPECT_EQ( + CORSUnsafeNotForbiddenRequestHeaderNames({}, true /* is_revalidating */), + List({})); + + // These three headers will be ignored. + EXPECT_EQ( + CORSUnsafeNotForbiddenRequestHeaderNames({{"If-MODifIED-since", "x"}, + {"iF-nONE-MATCh", "y"}, + {"CACHE-ContrOl", "z"}}, + true /* is_revalidating */), + List({})); + + // Without is_revalidating set, these three headers will not be safelisted. + EXPECT_EQ( + CORSUnsafeNotForbiddenRequestHeaderNames({{"If-MODifIED-since", "x"}, + {"iF-nONE-MATCh", "y"}, + {"CACHE-ContrOl", "z"}}, + false /* is_revalidating */), + List({"if-modified-since", "if-none-match", "cache-control"})); +} + +} // namespace +} // namespace cors } // namespace network diff --git a/chromium/services/network/public/cpp/cors/origin_access_entry.cc b/chromium/services/network/public/cpp/cors/origin_access_entry.cc index 6b7400d2cb6..26361e43d12 100644 --- a/chromium/services/network/public/cpp/cors/origin_access_entry.cc +++ b/chromium/services/network/public/cpp/cors/origin_access_entry.cc @@ -30,12 +30,15 @@ bool IsSubdomainOfHost(const std::string& subdomain, const std::string& host) { } // namespace -OriginAccessEntry::OriginAccessEntry(const std::string& protocol, - const std::string& host, - MatchMode match_mode) +OriginAccessEntry::OriginAccessEntry( + const std::string& protocol, + const std::string& host, + MatchMode match_mode, + const network::mojom::CORSOriginAccessMatchPriority priority) : protocol_(protocol), host_(host), match_mode_(match_mode), + priority_(priority), host_is_ip_address_(url::HostIsIPAddress(host)), host_is_public_suffix_(false) { if (host_is_ip_address_) diff --git a/chromium/services/network/public/cpp/cors/origin_access_entry.h b/chromium/services/network/public/cpp/cors/origin_access_entry.h index 46a35c272fd..5052f1aec4e 100644 --- a/chromium/services/network/public/cpp/cors/origin_access_entry.h +++ b/chromium/services/network/public/cpp/cors/origin_access_entry.h @@ -9,6 +9,7 @@ #include "base/component_export.h" #include "base/macros.h" +#include "services/network/public/mojom/cors.mojom-shared.h" namespace url { class Origin; @@ -46,9 +47,14 @@ class COMPONENT_EXPORT(NETWORK_CPP) OriginAccessEntry final { // will match all domains in the specified protocol. // IPv6 addresses must include brackets (e.g. // '[2001:db8:85a3::8a2e:370:7334]', not '2001:db8:85a3::8a2e:370:7334'). - OriginAccessEntry(const std::string& protocol, - const std::string& host, - MatchMode match_mode); + // The priority argument is used to break ties when multiple entries + // match. + OriginAccessEntry( + const std::string& protocol, + const std::string& host, + MatchMode match_mode, + const network::mojom::CORSOriginAccessMatchPriority priority = + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); OriginAccessEntry(OriginAccessEntry&& from); // 'matchesOrigin' requires a protocol match (e.g. 'http' != 'https'). @@ -57,6 +63,9 @@ class COMPONENT_EXPORT(NETWORK_CPP) OriginAccessEntry final { MatchResult MatchesDomain(const url::Origin& domain) const; bool host_is_ip_address() const { return host_is_ip_address_; } + network::mojom::CORSOriginAccessMatchPriority priority() const { + return priority_; + } const std::string& registerable_domain() const { return registerable_domain_; } @@ -65,6 +74,7 @@ class COMPONENT_EXPORT(NETWORK_CPP) OriginAccessEntry final { const std::string protocol_; const std::string host_; const MatchMode match_mode_; + network::mojom::CORSOriginAccessMatchPriority priority_; const bool host_is_ip_address_; std::string registerable_domain_; diff --git a/chromium/services/network/public/cpp/cors/origin_access_entry_unittest.cc b/chromium/services/network/public/cpp/cors/origin_access_entry_unittest.cc index d814028c340..a68e978baf5 100644 --- a/chromium/services/network/public/cpp/cors/origin_access_entry_unittest.cc +++ b/chromium/services/network/public/cpp/cors/origin_access_entry_unittest.cc @@ -3,6 +3,7 @@ // found in the LICENSE file. #include "services/network/public/cpp/cors/origin_access_entry.h" +#include "services/network/public/mojom/cors.mojom.h" #include "testing/gtest/include/gtest/gtest.h" #include "url/gurl.h" @@ -16,11 +17,15 @@ namespace { TEST(OriginAccessEntryTest, PublicSuffixListTest) { url::Origin origin = url::Origin::Create(GURL("http://www.google.com")); - OriginAccessEntry entry1("http", "google.com", - OriginAccessEntry::kAllowSubdomains); - OriginAccessEntry entry2("http", "hamster.com", - OriginAccessEntry::kAllowSubdomains); - OriginAccessEntry entry3("http", "com", OriginAccessEntry::kAllowSubdomains); + OriginAccessEntry entry1( + "http", "google.com", OriginAccessEntry::kAllowSubdomains, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); + OriginAccessEntry entry2( + "http", "hamster.com", OriginAccessEntry::kAllowSubdomains, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); + OriginAccessEntry entry3( + "http", "com", OriginAccessEntry::kAllowSubdomains, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); EXPECT_EQ(OriginAccessEntry::kMatchesOrigin, entry1.MatchesOrigin(origin)); EXPECT_EQ(OriginAccessEntry::kDoesNotMatchOrigin, entry2.MatchesOrigin(origin)); @@ -86,8 +91,9 @@ TEST(OriginAccessEntryTest, AllowSubdomainsTest) { SCOPED_TRACE(testing::Message() << "Host: " << test.host << ", Origin: " << test.origin); url::Origin origin_to_test = url::Origin::Create(GURL(test.origin)); - OriginAccessEntry entry1(test.protocol, test.host, - OriginAccessEntry::kAllowSubdomains); + OriginAccessEntry entry1( + test.protocol, test.host, OriginAccessEntry::kAllowSubdomains, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); EXPECT_EQ(test.expected_origin, entry1.MatchesOrigin(origin_to_test)); EXPECT_EQ(test.expected_domain, entry1.MatchesDomain(origin_to_test)); } @@ -134,8 +140,9 @@ TEST(OriginAccessEntryTest, AllowRegisterableDomainsTest) { for (const auto& test : inputs) { url::Origin origin_to_test = url::Origin::Create(GURL(test.origin)); - OriginAccessEntry entry1(test.protocol, test.host, - OriginAccessEntry::kAllowRegisterableDomains); + OriginAccessEntry entry1( + test.protocol, test.host, OriginAccessEntry::kAllowRegisterableDomains, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); SCOPED_TRACE(testing::Message() << "Host: " << test.host << ", Origin: " << test.origin @@ -186,8 +193,9 @@ TEST(OriginAccessEntryTest, AllowRegisterableDomainsTestWithDottedSuffix) { for (const auto& test : inputs) { url::Origin origin_to_test = url::Origin::Create(GURL(test.origin)); - OriginAccessEntry entry1(test.protocol, test.host, - OriginAccessEntry::kAllowRegisterableDomains); + OriginAccessEntry entry1( + test.protocol, test.host, OriginAccessEntry::kAllowRegisterableDomains, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); SCOPED_TRACE(testing::Message() << "Host: " << test.host << ", Origin: " << test.origin @@ -235,8 +243,9 @@ TEST(OriginAccessEntryTest, DisallowSubdomainsTest) { SCOPED_TRACE(testing::Message() << "Host: " << test.host << ", Origin: " << test.origin); url::Origin origin_to_test = url::Origin::Create(GURL(test.origin)); - OriginAccessEntry entry1(test.protocol, test.host, - OriginAccessEntry::kDisallowSubdomains); + OriginAccessEntry entry1( + test.protocol, test.host, OriginAccessEntry::kDisallowSubdomains, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); EXPECT_EQ(test.expected, entry1.MatchesOrigin(origin_to_test)); } } @@ -260,8 +269,9 @@ TEST(OriginAccessEntryTest, IPAddressTest) { for (const auto& test : inputs) { SCOPED_TRACE(testing::Message() << "Host: " << test.host); - OriginAccessEntry entry(test.protocol, test.host, - OriginAccessEntry::kDisallowSubdomains); + OriginAccessEntry entry( + test.protocol, test.host, OriginAccessEntry::kDisallowSubdomains, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); EXPECT_EQ(test.is_ip_address, entry.host_is_ip_address()) << test.host; } } @@ -287,12 +297,14 @@ TEST(OriginAccessEntryTest, IPAddressMatchingTest) { SCOPED_TRACE(testing::Message() << "Host: " << test.host << ", Origin: " << test.origin); url::Origin origin_to_test = url::Origin::Create(GURL(test.origin)); - OriginAccessEntry entry1(test.protocol, test.host, - OriginAccessEntry::kAllowSubdomains); + OriginAccessEntry entry1( + test.protocol, test.host, OriginAccessEntry::kAllowSubdomains, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); EXPECT_EQ(test.expected, entry1.MatchesOrigin(origin_to_test)); - OriginAccessEntry entry2(test.protocol, test.host, - OriginAccessEntry::kDisallowSubdomains); + OriginAccessEntry entry2( + test.protocol, test.host, OriginAccessEntry::kDisallowSubdomains, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); EXPECT_EQ(test.expected, entry2.MatchesOrigin(origin_to_test)); } } diff --git a/chromium/services/network/public/cpp/cors/origin_access_list.cc b/chromium/services/network/public/cpp/cors/origin_access_list.cc index 96420a396bd..150d9a0aafb 100644 --- a/chromium/services/network/public/cpp/cors/origin_access_list.cc +++ b/chromium/services/network/public/cpp/cors/origin_access_list.cc @@ -14,15 +14,18 @@ OriginAccessList::~OriginAccessList() = default; void OriginAccessList::SetAllowListForOrigin( const url::Origin& source_origin, const std::vector<mojom::CorsOriginPatternPtr>& patterns) { - SetForOrigin(source_origin, patterns, &allow_list_); + SetForOrigin(source_origin, patterns, &allow_list_, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); } void OriginAccessList::AddAllowListEntryForOrigin( const url::Origin& source_origin, const std::string& protocol, const std::string& domain, - bool allow_subdomains) { - AddForOrigin(source_origin, protocol, domain, allow_subdomains, &allow_list_); + bool allow_subdomains, + const network::mojom::CORSOriginAccessMatchPriority priority) { + AddForOrigin(source_origin, protocol, domain, allow_subdomains, &allow_list_, + priority); } void OriginAccessList::ClearAllowList() { @@ -32,15 +35,18 @@ void OriginAccessList::ClearAllowList() { void OriginAccessList::SetBlockListForOrigin( const url::Origin& source_origin, const std::vector<mojom::CorsOriginPatternPtr>& patterns) { - SetForOrigin(source_origin, patterns, &block_list_); + SetForOrigin(source_origin, patterns, &block_list_, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); } void OriginAccessList::AddBlockListEntryForOrigin( const url::Origin& source_origin, const std::string& protocol, const std::string& domain, - bool allow_subdomains) { - AddForOrigin(source_origin, protocol, domain, allow_subdomains, &block_list_); + bool allow_subdomains, + const network::mojom::CORSOriginAccessMatchPriority priority) { + AddForOrigin(source_origin, protocol, domain, allow_subdomains, &block_list_, + priority); } void OriginAccessList::ClearBlockList() { @@ -49,21 +55,33 @@ void OriginAccessList::ClearBlockList() { bool OriginAccessList::IsAllowed(const url::Origin& source_origin, const GURL& destination) const { - if (source_origin.unique()) + if (source_origin.opaque()) return false; std::string source = source_origin.Serialize(); url::Origin destination_origin = url::Origin::Create(destination); - return IsInMapForOrigin(source, destination_origin, allow_list_) && - !IsInMapForOrigin(source, destination_origin, block_list_); + network::mojom::CORSOriginAccessMatchPriority allow_list_priority = + GetHighestPriorityOfRuleForOrigin(source, destination_origin, + allow_list_); + if (allow_list_priority == + network::mojom::CORSOriginAccessMatchPriority::kNoMatchingOrigin) + return false; + network::mojom::CORSOriginAccessMatchPriority block_list_priority = + GetHighestPriorityOfRuleForOrigin(source, destination_origin, + block_list_); + if (block_list_priority == + network::mojom::CORSOriginAccessMatchPriority::kNoMatchingOrigin) + return true; + return allow_list_priority > block_list_priority; } // static void OriginAccessList::SetForOrigin( const url::Origin& source_origin, const std::vector<mojom::CorsOriginPatternPtr>& patterns, - PatternMap* map) { + PatternMap* map, + const network::mojom::CORSOriginAccessMatchPriority priority) { DCHECK(map); - DCHECK(!source_origin.unique()); + DCHECK(!source_origin.opaque()); std::string source = source_origin.Serialize(); map->erase(source); @@ -75,40 +93,50 @@ void OriginAccessList::SetForOrigin( native_patterns.push_back(OriginAccessEntry( pattern->protocol, pattern->domain, pattern->allow_subdomains ? OriginAccessEntry::kAllowSubdomains - : OriginAccessEntry::kDisallowSubdomains)); + : OriginAccessEntry::kDisallowSubdomains, + priority)); } } // static -void OriginAccessList::AddForOrigin(const url::Origin& source_origin, - const std::string& protocol, - const std::string& domain, - bool allow_subdomains, - PatternMap* map) { +void OriginAccessList::AddForOrigin( + const url::Origin& source_origin, + const std::string& protocol, + const std::string& domain, + bool allow_subdomains, + PatternMap* map, + const network::mojom::CORSOriginAccessMatchPriority priority) { DCHECK(map); - DCHECK(!source_origin.unique()); + DCHECK(!source_origin.opaque()); std::string source = source_origin.Serialize(); (*map)[source].push_back(OriginAccessEntry( protocol, domain, allow_subdomains ? OriginAccessEntry::kAllowSubdomains - : OriginAccessEntry::kDisallowSubdomains)); + : OriginAccessEntry::kDisallowSubdomains, + priority)); } // static -bool OriginAccessList::IsInMapForOrigin(const std::string& source, - const url::Origin& destination_origin, - const PatternMap& map) { +// TODO(nrpeter): Sort OriginAccessEntry entries on edit then we can return the +// first match which will be the top priority. +network::mojom::CORSOriginAccessMatchPriority +OriginAccessList::GetHighestPriorityOfRuleForOrigin( + const std::string& source, + const url::Origin& destination_origin, + const PatternMap& map) { + network::mojom::CORSOriginAccessMatchPriority highest_priority = + network::mojom::CORSOriginAccessMatchPriority::kNoMatchingOrigin; auto patterns_for_origin_it = map.find(source); if (patterns_for_origin_it == map.end()) - return false; + return highest_priority; for (const auto& entry : patterns_for_origin_it->second) { if (entry.MatchesOrigin(destination_origin) != OriginAccessEntry::kDoesNotMatchOrigin) { - return true; + highest_priority = std::max(highest_priority, entry.priority()); } } - return false; + return highest_priority; } } // namespace cors diff --git a/chromium/services/network/public/cpp/cors/origin_access_list.h b/chromium/services/network/public/cpp/cors/origin_access_list.h index 518f5ef07ab..0a8cffdc97d 100644 --- a/chromium/services/network/public/cpp/cors/origin_access_list.h +++ b/chromium/services/network/public/cpp/cors/origin_access_list.h @@ -19,9 +19,8 @@ namespace network { namespace cors { -// A class to manage origin access whitelisting. It manages two lists for -// whitelisting and blacklisting. If these lists conflict, blacklisting will be -// respected. These lists are managed per source-origin basis. +// A class to manage origin access allow / block lists. If these lists conflict, +// blacklisting is respected. These lists are managed per source-origin basis. class COMPONENT_EXPORT(NETWORK_CPP) OriginAccessList { public: OriginAccessList(); @@ -34,11 +33,14 @@ class COMPONENT_EXPORT(NETWORK_CPP) OriginAccessList { const std::vector<mojom::CorsOriginPatternPtr>& patterns); // Adds a matching pattern for |protocol|, |domain|, and |allow_subdomains| - // to the allow list. - void AddAllowListEntryForOrigin(const url::Origin& source_origin, - const std::string& protocol, - const std::string& domain, - bool allow_subdomains); + // to the allow list. When two or more entries in a list match the entry + // with the higher |priority| takes precedence. + void AddAllowListEntryForOrigin( + const url::Origin& source_origin, + const std::string& protocol, + const std::string& domain, + bool allow_subdomains, + const network::mojom::CORSOriginAccessMatchPriority priority); // Clears the old allow list. void ClearAllowList(); @@ -50,11 +52,14 @@ class COMPONENT_EXPORT(NETWORK_CPP) OriginAccessList { const std::vector<mojom::CorsOriginPatternPtr>& patterns); // Adds a matching pattern for |protocol|, |domain|, and |allow_subdomains| - // to the block list. - void AddBlockListEntryForOrigin(const url::Origin& source_origin, - const std::string& protocol, - const std::string& domain, - bool allow_subdomains); + // to the block list. When two or more entries in a list match the entry + // with the higher |priority| takes precedence. + void AddBlockListEntryForOrigin( + const url::Origin& source_origin, + const std::string& protocol, + const std::string& domain, + bool allow_subdomains, + const network::mojom::CORSOriginAccessMatchPriority priority); // Clears the old block list. void ClearBlockList(); @@ -71,15 +76,19 @@ class COMPONENT_EXPORT(NETWORK_CPP) OriginAccessList { static void SetForOrigin( const url::Origin& source_origin, const std::vector<mojom::CorsOriginPatternPtr>& patterns, - PatternMap* map); - static void AddForOrigin(const url::Origin& source_origin, - const std::string& protocol, - const std::string& domain, - bool allow_subdomains, - PatternMap* map); - static bool IsInMapForOrigin(const std::string& source, - const url::Origin& destination_origin, - const PatternMap& map); + PatternMap* map, + const network::mojom::CORSOriginAccessMatchPriority priority); + static void AddForOrigin( + const url::Origin& source_origin, + const std::string& protocol, + const std::string& domain, + bool allow_subdomains, + PatternMap* map, + const network::mojom::CORSOriginAccessMatchPriority priority); + static network::mojom::CORSOriginAccessMatchPriority + GetHighestPriorityOfRuleForOrigin(const std::string& source, + const url::Origin& destination_origin, + const PatternMap& map); PatternMap allow_list_; PatternMap block_list_; diff --git a/chromium/services/network/public/cpp/cors/origin_access_list_unittest.cc b/chromium/services/network/public/cpp/cors/origin_access_list_unittest.cc index 8152944e37b..7bab76846e9 100644 --- a/chromium/services/network/public/cpp/cors/origin_access_list_unittest.cc +++ b/chromium/services/network/public/cpp/cors/origin_access_list_unittest.cc @@ -3,6 +3,7 @@ // found in the LICENSE file. #include "services/network/public/cpp/cors/origin_access_list.h" +#include "services/network/public/mojom/cors.mojom.h" #include <memory> @@ -50,29 +51,35 @@ class OriginAccessListTest : public testing::Test { const std::string& host, bool allow_subdomains) { std::vector<mojom::CorsOriginPatternPtr> patterns; - patterns.push_back( - mojom::CorsOriginPattern::New(protocol, host, allow_subdomains)); + patterns.push_back(mojom::CorsOriginPattern::New( + protocol, host, allow_subdomains, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority)); list_.SetAllowListForOrigin(source_origin_, patterns); } - void AddAllowListEntry(const std::string& protocol, - const std::string& host, - bool allow_subdomains) { + void AddAllowListEntry( + const std::string& protocol, + const std::string& host, + bool allow_subdomains, + const network::mojom::CORSOriginAccessMatchPriority priority) { list_.AddAllowListEntryForOrigin(source_origin_, protocol, host, - allow_subdomains); + allow_subdomains, priority); } void SetBlockListEntry(const std::string& protocol, const std::string& host, bool allow_subdomains) { std::vector<mojom::CorsOriginPatternPtr> patterns; - patterns.push_back( - mojom::CorsOriginPattern::New(protocol, host, allow_subdomains)); + patterns.push_back(mojom::CorsOriginPattern::New( + protocol, host, allow_subdomains, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority)); list_.SetBlockListForOrigin(source_origin_, patterns); } - void AddBlockListEntry(const std::string& protocol, - const std::string& host, - bool allow_subdomains) { + void AddBlockListEntry( + const std::string& protocol, + const std::string& host, + bool allow_subdomains, + const network::mojom::CORSOriginAccessMatchPriority priority) { list_.AddBlockListEntryForOrigin(source_origin_, protocol, host, - allow_subdomains); + allow_subdomains, priority); } void ResetLists() { std::vector<mojom::CorsOriginPatternPtr> patterns; @@ -114,7 +121,9 @@ TEST_F(OriginAccessListTest, IsAccessAllowed) { // Adding an entry that matches subdomains should grant access to any // subdomains. - AddAllowListEntry("https", "example.com", true); + AddAllowListEntry( + "https", "example.com", true, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); EXPECT_TRUE(IsAllowed(https_example_origin())); EXPECT_TRUE(IsAllowed(https_sub_example_origin())); EXPECT_FALSE(IsAllowed(http_example_origin())); @@ -139,12 +148,54 @@ TEST_F(OriginAccessListTest, IsAccessAllowedWithBlockListEntry) { TEST_F(OriginAccessListTest, IsAccessAllowedWildcardWithBlockListEntry) { SetAllowListEntry("https", "", true); - AddBlockListEntry("https", "google.com", false); + AddBlockListEntry( + "https", "google.com", false, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); EXPECT_TRUE(IsAllowed(https_example_origin())); EXPECT_FALSE(IsAllowed(https_google_origin())); } +TEST_F(OriginAccessListTest, IsPriorityRespected) { + SetAllowListEntry("https", "example.com", true); + EXPECT_TRUE(IsAllowed(https_example_origin())); + EXPECT_TRUE(IsAllowed(https_sub_example_origin())); + + // Higher priority blocklist overrides lower priority allowlist. + AddBlockListEntry( + "https", "example.com", true, + network::mojom::CORSOriginAccessMatchPriority::kLowPriority); + EXPECT_FALSE(IsAllowed(https_example_origin())); + EXPECT_FALSE(IsAllowed(https_sub_example_origin())); + + // Higher priority allowlist overrides lower priority blocklist. + AddAllowListEntry( + "https", "example.com", false, + network::mojom::CORSOriginAccessMatchPriority::kMediumPriority); + EXPECT_TRUE(IsAllowed(https_example_origin())); + EXPECT_FALSE(IsAllowed(https_sub_example_origin())); +} + +TEST_F(OriginAccessListTest, IsPriorityRespectedReverse) { + AddAllowListEntry( + "https", "example.com", false, + network::mojom::CORSOriginAccessMatchPriority::kMediumPriority); + EXPECT_TRUE(IsAllowed(https_example_origin())); + EXPECT_FALSE(IsAllowed(https_sub_example_origin())); + + AddBlockListEntry( + "https", "example.com", true, + network::mojom::CORSOriginAccessMatchPriority::kLowPriority); + EXPECT_TRUE(IsAllowed(https_example_origin())); + EXPECT_FALSE(IsAllowed(https_sub_example_origin())); + + AddAllowListEntry( + "https", "example.com", true, + network::mojom::CORSOriginAccessMatchPriority::kDefaultPriority); + EXPECT_TRUE(IsAllowed(https_example_origin())); + EXPECT_FALSE(IsAllowed(https_sub_example_origin())); +} + } // namespace } // namespace cors diff --git a/chromium/services/network/public/cpp/cors/preflight_cache.cc b/chromium/services/network/public/cpp/cors/preflight_cache.cc index 7fe92ba5934..bf30102093f 100644 --- a/chromium/services/network/public/cpp/cors/preflight_cache.cc +++ b/chromium/services/network/public/cpp/cors/preflight_cache.cc @@ -26,7 +26,8 @@ bool PreflightCache::CheckIfRequestCanSkipPreflight( const GURL& url, mojom::FetchCredentialsMode credentials_mode, const std::string& method, - const net::HttpRequestHeaders& request_headers) { + const net::HttpRequestHeaders& request_headers, + bool is_revalidating) { // Either |origin| or |url| are not in cache. auto cache_per_origin = cache_.find(origin); if (cache_per_origin == cache_.end()) @@ -38,8 +39,8 @@ bool PreflightCache::CheckIfRequestCanSkipPreflight( // Both |origin| and |url| are in cache. Check if the entry is still valid and // sufficient to skip CORS-preflight. - if (cache_entry->second->EnsureAllowedRequest(credentials_mode, method, - request_headers)) { + if (cache_entry->second->EnsureAllowedRequest( + credentials_mode, method, request_headers, is_revalidating)) { return true; } diff --git a/chromium/services/network/public/cpp/cors/preflight_cache.h b/chromium/services/network/public/cpp/cors/preflight_cache.h index 0a1d6a78795..58337b9b1df 100644 --- a/chromium/services/network/public/cpp/cors/preflight_cache.h +++ b/chromium/services/network/public/cpp/cors/preflight_cache.h @@ -44,7 +44,8 @@ class COMPONENT_EXPORT(NETWORK_CPP) PreflightCache final { const GURL& url, mojom::FetchCredentialsMode credentials_mode, const std::string& method, - const net::HttpRequestHeaders& headers); + const net::HttpRequestHeaders& headers, + bool is_revalidating); // Counts cached origins for testing. size_t CountOriginsForTesting() const; diff --git a/chromium/services/network/public/cpp/cors/preflight_cache_unittest.cc b/chromium/services/network/public/cpp/cors/preflight_cache_unittest.cc index 2ad111bbeaa..e1a3f335729 100644 --- a/chromium/services/network/public/cpp/cors/preflight_cache_unittest.cc +++ b/chromium/services/network/public/cpp/cors/preflight_cache_unittest.cc @@ -35,7 +35,7 @@ class PreflightCacheTest : public testing::Test { bool CheckEntryAndRefreshCache(const std::string& origin, const GURL& url) { return cache_.CheckIfRequestCanSkipPreflight( origin, url, network::mojom::FetchCredentialsMode::kInclude, "POST", - net::HttpRequestHeaders()); + net::HttpRequestHeaders(), false); } void Advance(int seconds) { diff --git a/chromium/services/network/public/cpp/cors/preflight_result.cc b/chromium/services/network/public/cpp/cors/preflight_result.cc index 3cf0f254142..a26a8b7e884 100644 --- a/chromium/services/network/public/cpp/cors/preflight_result.cc +++ b/chromium/services/network/public/cpp/cors/preflight_result.cc @@ -132,24 +132,22 @@ base::Optional<CORSErrorStatus> PreflightResult::EnsureAllowedCrossOriginMethod( base::Optional<CORSErrorStatus> PreflightResult::EnsureAllowedCrossOriginHeaders( - const net::HttpRequestHeaders& headers) const { + const net::HttpRequestHeaders& headers, + bool is_revalidating) const { if (!credentials_ && headers_.find("*") != headers_.end()) return base::nullopt; - for (const auto& header : headers.GetHeaderVector()) { + // Forbidden headers are forbidden to be used by JavaScript, and checked + // beforehand. But user-agents may add these headers internally, and it's + // fine. + for (const auto& name : CORSUnsafeNotForbiddenRequestHeaderNames( + headers.GetHeaderVector(), is_revalidating)) { // Header list check is performed in case-insensitive way. Here, we have a // parsed header list set in lower case, and search each header in lower // case. - const std::string key = base::ToLowerASCII(header.key); - if (headers_.find(key) == headers_.end() && - !IsCORSSafelistedHeader(key, header.value)) { - // Forbidden headers are forbidden to be used by JavaScript, and checked - // beforehand. But user-agents may add these headers internally, and it's - // fine. - if (IsForbiddenHeader(key)) - continue; + if (headers_.find(name) == headers_.end()) { return CORSErrorStatus( - mojom::CORSError::kHeaderDisallowedByPreflightResponse, header.key); + mojom::CORSError::kHeaderDisallowedByPreflightResponse, name); } } return base::nullopt; @@ -158,7 +156,8 @@ PreflightResult::EnsureAllowedCrossOriginHeaders( bool PreflightResult::EnsureAllowedRequest( mojom::FetchCredentialsMode credentials_mode, const std::string& method, - const net::HttpRequestHeaders& headers) const { + const net::HttpRequestHeaders& headers, + bool is_revalidating) const { if (absolute_expiry_time_ <= Now()) return false; @@ -170,7 +169,7 @@ bool PreflightResult::EnsureAllowedRequest( if (EnsureAllowedCrossOriginMethod(method)) return false; - if (EnsureAllowedCrossOriginHeaders(headers)) + if (EnsureAllowedCrossOriginHeaders(headers, is_revalidating)) return false; return true; diff --git a/chromium/services/network/public/cpp/cors/preflight_result.h b/chromium/services/network/public/cpp/cors/preflight_result.h index bd49e120142..0df888a86b7 100644 --- a/chromium/services/network/public/cpp/cors/preflight_result.h +++ b/chromium/services/network/public/cpp/cors/preflight_result.h @@ -56,7 +56,8 @@ class COMPONENT_EXPORT(NETWORK_CPP) PreflightResult final { // added by the user agent. They must be checked separately and rejected for // JavaScript-initiated requests. base::Optional<CORSErrorStatus> EnsureAllowedCrossOriginHeaders( - const net::HttpRequestHeaders& headers) const; + const net::HttpRequestHeaders& headers, + bool is_revalidating) const; // Checks if the given combination of |credentials_mode|, |method|, and // |headers| is allowed by the CORS-preflight response. @@ -64,7 +65,8 @@ class COMPONENT_EXPORT(NETWORK_CPP) PreflightResult final { // EnsureAllowCrossOriginHeaders does not. bool EnsureAllowedRequest(mojom::FetchCredentialsMode credentials_mode, const std::string& method, - const net::HttpRequestHeaders& headers) const; + const net::HttpRequestHeaders& headers, + bool is_revalidating) const; // Refers the cache expiry time. base::TimeTicks absolute_expiry_time() const { return absolute_expiry_time_; } diff --git a/chromium/services/network/public/cpp/cors/preflight_result_unittest.cc b/chromium/services/network/public/cpp/cors/preflight_result_unittest.cc index 2a83e34afec..d1a709cb2d7 100644 --- a/chromium/services/network/public/cpp/cors/preflight_result_unittest.cc +++ b/chromium/services/network/public/cpp/cors/preflight_result_unittest.cc @@ -135,15 +135,15 @@ const TestCase header_cases[] = { {"GET", "", mojom::FetchCredentialsMode::kOmit, "GET", "X-MY-HEADER:t", mojom::FetchCredentialsMode::kOmit, CORSErrorStatus(mojom::CORSError::kHeaderDisallowedByPreflightResponse, - "X-MY-HEADER")}, + "x-my-header")}, {"GET", "X-SOME-OTHER-HEADER", mojom::FetchCredentialsMode::kOmit, "GET", "X-MY-HEADER:t", mojom::FetchCredentialsMode::kOmit, CORSErrorStatus(mojom::CORSError::kHeaderDisallowedByPreflightResponse, - "X-MY-HEADER")}, + "x-my-header")}, {"GET", "X-MY-HEADER", mojom::FetchCredentialsMode::kOmit, "GET", "X-MY-HEADER:t\r\nY-MY-HEADER:t", mojom::FetchCredentialsMode::kOmit, CORSErrorStatus(mojom::CORSError::kHeaderDisallowedByPreflightResponse, - "Y-MY-HEADER")}, + "y-my-header")}, }; TEST_F(PreflightResultTest, MaxAge) { @@ -188,7 +188,7 @@ TEST_F(PreflightResultTest, EnsureHeaders) { net::HttpRequestHeaders headers; headers.AddHeadersFromString(test.request_headers); EXPECT_EQ(test.expected_result, - result->EnsureAllowedCrossOriginHeaders(headers)); + result->EnsureAllowedCrossOriginHeaders(headers, false)); } } @@ -201,9 +201,10 @@ TEST_F(PreflightResultTest, EnsureRequest) { net::HttpRequestHeaders headers; if (!test.request_headers.empty()) headers.AddHeadersFromString(test.request_headers); - EXPECT_EQ(test.expected_result == base::nullopt, - result->EnsureAllowedRequest(test.request_credentials_mode, - test.request_method, headers)); + EXPECT_EQ( + test.expected_result == base::nullopt, + result->EnsureAllowedRequest(test.request_credentials_mode, + test.request_method, headers, false)); } for (const auto& test : header_cases) { @@ -214,9 +215,10 @@ TEST_F(PreflightResultTest, EnsureRequest) { net::HttpRequestHeaders headers; if (!test.request_headers.empty()) headers.AddHeadersFromString(test.request_headers); - EXPECT_EQ(test.expected_result == base::nullopt, - result->EnsureAllowedRequest(test.request_credentials_mode, - test.request_method, headers)); + EXPECT_EQ( + test.expected_result == base::nullopt, + result->EnsureAllowedRequest(test.request_credentials_mode, + test.request_method, headers, false)); } struct { @@ -245,7 +247,7 @@ TEST_F(PreflightResultTest, EnsureRequest) { net::HttpRequestHeaders headers; EXPECT_EQ(test.expected_result, result->EnsureAllowedRequest(test.request_credentials_mode, "GET", - headers)); + headers, false)); } } diff --git a/chromium/services/network/public/cpp/cors/preflight_timing_info.cc b/chromium/services/network/public/cpp/cors/preflight_timing_info.cc new file mode 100644 index 00000000000..2691e8c9ab4 --- /dev/null +++ b/chromium/services/network/public/cpp/cors/preflight_timing_info.cc @@ -0,0 +1,26 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/network/public/cpp/cors/preflight_timing_info.h" + +namespace network { + +namespace cors { + +PreflightTimingInfo::PreflightTimingInfo() = default; +PreflightTimingInfo::PreflightTimingInfo(const PreflightTimingInfo& info) = + default; +PreflightTimingInfo::~PreflightTimingInfo() = default; + +bool PreflightTimingInfo::operator==(const PreflightTimingInfo& rhs) const { + return start_time == rhs.start_time && finish_time == rhs.finish_time && + alpn_negotiated_protocol == rhs.alpn_negotiated_protocol && + connection_info == rhs.connection_info && + timing_allow_origin == rhs.timing_allow_origin && + transfer_size == rhs.transfer_size; +} + +} // namespace cors + +} // namespace network diff --git a/chromium/services/network/public/cpp/cors/preflight_timing_info.h b/chromium/services/network/public/cpp/cors/preflight_timing_info.h new file mode 100644 index 00000000000..a59c56f9d18 --- /dev/null +++ b/chromium/services/network/public/cpp/cors/preflight_timing_info.h @@ -0,0 +1,43 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_NETWORK_PUBLIC_CPP_CORS_PREFLIGHT_TIMING_INFO_H_ +#define SERVICES_NETWORK_PUBLIC_CPP_CORS_PREFLIGHT_TIMING_INFO_H_ + +#include <string> + +#include "base/component_export.h" +#include "base/memory/scoped_refptr.h" +#include "net/http/http_response_headers.h" +#include "net/http/http_response_info.h" +#include "services/network/public/mojom/cors.mojom-shared.h" + +namespace network { + +namespace cors { + +// Stores performance monitoring information for CORS preflight requests that +// are made in the NetworkService. Will be used to carry information from the +// NetworkService to call sites via URLLoaderCompletionStatus. +struct COMPONENT_EXPORT(NETWORK_CPP_BASE) PreflightTimingInfo { + PreflightTimingInfo(); + PreflightTimingInfo(const PreflightTimingInfo& info); + ~PreflightTimingInfo(); + + base::TimeTicks start_time; + base::TimeTicks finish_time; + std::string alpn_negotiated_protocol; + net::HttpResponseInfo::ConnectionInfo connection_info = + net::HttpResponseInfo::CONNECTION_INFO_UNKNOWN; + std::string timing_allow_origin; + uint64_t transfer_size = 0; + + bool operator==(const PreflightTimingInfo& rhs) const; +}; + +} // namespace cors + +} // namespace network + +#endif // SERVICES_NETWORK_PUBLIC_CPP_CORS_PREFLIGHT_TIMING_INFO_H_ diff --git a/chromium/services/network/public/cpp/features.cc b/chromium/services/network/public/cpp/features.cc index 0f22e31b4af..bd224a18f6d 100644 --- a/chromium/services/network/public/cpp/features.cc +++ b/chromium/services/network/public/cpp/features.cc @@ -13,7 +13,7 @@ const base::Feature kExpectCTReporting{"ExpectCTReporting", base::FEATURE_ENABLED_BY_DEFAULT}; const base::Feature kNetworkErrorLogging{"NetworkErrorLogging", - base::FEATURE_DISABLED_BY_DEFAULT}; + base::FEATURE_ENABLED_BY_DEFAULT}; // Enables the network service. const base::Feature kNetworkService{"NetworkService", base::FEATURE_DISABLED_BY_DEFAULT}; @@ -22,7 +22,7 @@ const base::Feature kNetworkService{"NetworkService", const base::Feature kOutOfBlinkCORS{"OutOfBlinkCORS", base::FEATURE_DISABLED_BY_DEFAULT}; -const base::Feature kReporting{"Reporting", base::FEATURE_DISABLED_BY_DEFAULT}; +const base::Feature kReporting{"Reporting", base::FEATURE_ENABLED_BY_DEFAULT}; // Based on the field trial parameters, this feature will override the value of // the maximum number of delayable requests allowed in flight. The number of @@ -40,7 +40,7 @@ const base::Feature kThrottleDelayable{"ThrottleDelayable", // ResourceScheduler just as HTTP/1.1 resources are. However, requests from such // servers are not subject to kMaxNumDelayableRequestsPerHostPerClient limit. const base::Feature kDelayRequestsOnMultiplexedConnections{ - "DelayRequestsOnMultiplexedConnections", base::FEATURE_DISABLED_BY_DEFAULT}; + "DelayRequestsOnMultiplexedConnections", base::FEATURE_ENABLED_BY_DEFAULT}; } // namespace features } // namespace network diff --git a/chromium/services/network/public/cpp/host_resolver.typemap b/chromium/services/network/public/cpp/host_resolver.typemap index ea65cc5ffc1..b71768ade43 100644 --- a/chromium/services/network/public/cpp/host_resolver.typemap +++ b/chromium/services/network/public/cpp/host_resolver.typemap @@ -12,4 +12,7 @@ sources = [ public_deps = [ "//net", ] -type_mappings = [ "network.mojom.ResolveHostParameters.DnsQueryType=net::HostResolver::DnsQueryType" ] +type_mappings = [ + "network.mojom.ResolveHostParameters.DnsQueryType=net::HostResolver::DnsQueryType", + "network.mojom.ResolveHostParameters.Source=net::HostResolverSource", +] diff --git a/chromium/services/network/public/cpp/host_resolver_mojom_traits.cc b/chromium/services/network/public/cpp/host_resolver_mojom_traits.cc index d8085bc7ce8..0364c3e0d83 100644 --- a/chromium/services/network/public/cpp/host_resolver_mojom_traits.cc +++ b/chromium/services/network/public/cpp/host_resolver_mojom_traits.cc @@ -41,4 +41,40 @@ bool EnumTraits<ResolveHostParameters::DnsQueryType, } } +// static +ResolveHostParameters::Source +EnumTraits<ResolveHostParameters::Source, net::HostResolverSource>::ToMojom( + net::HostResolverSource input) { + switch (input) { + case net::HostResolverSource::ANY: + return ResolveHostParameters::Source::ANY; + case net::HostResolverSource::SYSTEM: + return ResolveHostParameters::Source::SYSTEM; + case net::HostResolverSource::DNS: + return ResolveHostParameters::Source::DNS; + case net::HostResolverSource::MULTICAST_DNS: + return ResolveHostParameters::Source::MULTICAST_DNS; + } +} + +// static +bool EnumTraits<ResolveHostParameters::Source, net::HostResolverSource>:: + FromMojom(ResolveHostParameters::Source input, + net::HostResolverSource* output) { + switch (input) { + case ResolveHostParameters::Source::ANY: + *output = net::HostResolverSource::ANY; + return true; + case ResolveHostParameters::Source::SYSTEM: + *output = net::HostResolverSource::SYSTEM; + return true; + case ResolveHostParameters::Source::DNS: + *output = net::HostResolverSource::DNS; + return true; + case ResolveHostParameters::Source::MULTICAST_DNS: + *output = net::HostResolverSource::MULTICAST_DNS; + return true; + } +} + } // namespace mojo diff --git a/chromium/services/network/public/cpp/host_resolver_mojom_traits.h b/chromium/services/network/public/cpp/host_resolver_mojom_traits.h index c0daabda476..fab0ed1eae0 100644 --- a/chromium/services/network/public/cpp/host_resolver_mojom_traits.h +++ b/chromium/services/network/public/cpp/host_resolver_mojom_traits.h @@ -21,6 +21,15 @@ struct EnumTraits<network::mojom::ResolveHostParameters::DnsQueryType, net::HostResolver::DnsQueryType* output); }; +template <> +struct EnumTraits<network::mojom::ResolveHostParameters::Source, + net::HostResolverSource> { + static network::mojom::ResolveHostParameters::Source ToMojom( + net::HostResolverSource input); + static bool FromMojom(network::mojom::ResolveHostParameters::Source input, + net::HostResolverSource* output); +}; + } // namespace mojo #endif // SERVICES_NETWORK_PUBLIC_CPP_HOST_RESOLVER_MOJOM_TRAITS_H_ diff --git a/chromium/services/network/public/cpp/net_ipc_param_traits.cc b/chromium/services/network/public/cpp/net_ipc_param_traits.cc index 447febe3bf7..445809f000a 100644 --- a/chromium/services/network/public/cpp/net_ipc_param_traits.cc +++ b/chromium/services/network/public/cpp/net_ipc_param_traits.cc @@ -222,6 +222,47 @@ void ParamTraits<scoped_refptr<net::HttpResponseHeaders>>::Log( l->append("<HttpResponseHeaders>"); } +void ParamTraits<net::ProxyServer>::Write(base::Pickle* m, + const param_type& p) { + net::ProxyServer::Scheme scheme = p.scheme(); + WriteParam(m, scheme); + // When scheme is either 'direct' or 'invalid' |host_port_pair| + // should not be called, as per the method implementation body. + if (scheme != net::ProxyServer::SCHEME_DIRECT && + scheme != net::ProxyServer::SCHEME_INVALID) { + WriteParam(m, p.host_port_pair()); + } + WriteParam(m, p.is_trusted_proxy()); +} + +bool ParamTraits<net::ProxyServer>::Read(const base::Pickle* m, + base::PickleIterator* iter, + param_type* r) { + net::ProxyServer::Scheme scheme; + bool is_trusted_proxy = false; + if (!ReadParam(m, iter, &scheme)) + return false; + + // When scheme is either 'direct' or 'invalid' |host_port_pair| + // should not be called, as per the method implementation body. + net::HostPortPair host_port_pair; + if (scheme != net::ProxyServer::SCHEME_DIRECT && + scheme != net::ProxyServer::SCHEME_INVALID && + !ReadParam(m, iter, &host_port_pair)) { + return false; + } + + if (!ReadParam(m, iter, &is_trusted_proxy)) + return false; + + *r = net::ProxyServer(scheme, host_port_pair, is_trusted_proxy); + return true; +} + +void ParamTraits<net::ProxyServer>::Log(const param_type& p, std::string* l) { + l->append("<ProxyServer>"); +} + void ParamTraits<net::OCSPVerifyResult>::Write(base::Pickle* m, const param_type& p) { WriteParam(m, p.response_status); @@ -290,8 +331,6 @@ void ParamTraits<net::SSLInfo>::Write(base::Pickle* m, const param_type& p) { WriteParam(m, p.pkp_bypassed); WriteParam(m, p.client_cert_sent); WriteParam(m, p.channel_id_sent); - WriteParam(m, p.token_binding_negotiated); - WriteParam(m, p.token_binding_key_param); WriteParam(m, p.handshake_type); WriteParam(m, p.public_key_hashes); WriteParam(m, p.pinning_failure_log); @@ -319,8 +358,6 @@ bool ParamTraits<net::SSLInfo>::Read(const base::Pickle* m, ReadParam(m, iter, &r->pkp_bypassed) && ReadParam(m, iter, &r->client_cert_sent) && ReadParam(m, iter, &r->channel_id_sent) && - ReadParam(m, iter, &r->token_binding_negotiated) && - ReadParam(m, iter, &r->token_binding_key_param) && ReadParam(m, iter, &r->handshake_type) && ReadParam(m, iter, &r->public_key_hashes) && ReadParam(m, iter, &r->pinning_failure_log) && @@ -485,35 +522,34 @@ void ParamTraits<net::LoadTimingInfo>::Log(const param_type& p, } void ParamTraits<url::Origin>::Write(base::Pickle* m, const url::Origin& p) { - WriteParam(m, p.unique()); - WriteParam(m, p.scheme()); - WriteParam(m, p.host()); - WriteParam(m, p.port()); + WriteParam(m, p.GetTupleOrPrecursorTupleIfOpaque().scheme()); + WriteParam(m, p.GetTupleOrPrecursorTupleIfOpaque().host()); + WriteParam(m, p.GetTupleOrPrecursorTupleIfOpaque().port()); + WriteParam(m, p.GetNonceForSerialization()); } bool ParamTraits<url::Origin>::Read(const base::Pickle* m, base::PickleIterator* iter, url::Origin* p) { - bool unique; std::string scheme; std::string host; uint16_t port; - if (!ReadParam(m, iter, &unique) || !ReadParam(m, iter, &scheme) || - !ReadParam(m, iter, &host) || !ReadParam(m, iter, &port)) { - *p = url::Origin(); + base::Optional<base::UnguessableToken> nonce_if_opaque; + if (!ReadParam(m, iter, &scheme) || !ReadParam(m, iter, &host) || + !ReadParam(m, iter, &port) || !ReadParam(m, iter, &nonce_if_opaque)) { return false; } - *p = unique ? url::Origin() - : url::Origin::UnsafelyCreateOriginWithoutNormalization( - scheme, host, port); - - // If a unique origin was created, but the unique flag wasn't set, then - // the values provided to 'UnsafelyCreateOriginWithoutNormalization' were - // invalid; kill the renderer. - if (!unique && p->unique()) + base::Optional<url::Origin> creation_result = + nonce_if_opaque + ? url::Origin::UnsafelyCreateOpaqueOriginWithoutNormalization( + scheme, host, port, url::Origin::Nonce(*nonce_if_opaque)) + : url::Origin::UnsafelyCreateTupleOriginWithoutNormalization( + scheme, host, port); + if (!creation_result) return false; + *p = std::move(creation_result.value()); return true; } diff --git a/chromium/services/network/public/cpp/net_ipc_param_traits.h b/chromium/services/network/public/cpp/net_ipc_param_traits.h index 3beeb69077d..9ae1082fd76 100644 --- a/chromium/services/network/public/cpp/net_ipc_param_traits.h +++ b/chromium/services/network/public/cpp/net_ipc_param_traits.h @@ -13,6 +13,7 @@ #include "ipc/param_traits_macros.h" #include "net/base/auth.h" #include "net/base/host_port_pair.h" +#include "net/base/proxy_server.h" #include "net/base/request_priority.h" #include "net/cert/cert_verify_result.h" #include "net/cert/ct_policy_status.h" @@ -113,6 +114,16 @@ struct COMPONENT_EXPORT(NETWORK_CPP_BASE) ParamTraits<net::HttpRequestHeaders> { }; template <> +struct COMPONENT_EXPORT(NETWORK_CPP_BASE) ParamTraits<net::ProxyServer> { + typedef net::ProxyServer param_type; + static void Write(base::Pickle* m, const param_type& p); + static bool Read(const base::Pickle* m, + base::PickleIterator* iter, + param_type* r); + static void Log(const param_type& p, std::string* l); +}; + +template <> struct COMPONENT_EXPORT(NETWORK_CPP_BASE) ParamTraits<net::OCSPVerifyResult> { typedef net::OCSPVerifyResult param_type; static void Write(base::Pickle* m, const param_type& p); @@ -203,6 +214,9 @@ struct COMPONENT_EXPORT(NETWORK_CPP_BASE) ParamTraits<url::Origin> { IPC_ENUM_TRAITS_MAX_VALUE( net::ct::CTPolicyCompliance, net::ct::CTPolicyCompliance::CT_POLICY_COMPLIANCE_DETAILS_NOT_AVAILABLE) + +IPC_ENUM_TRAITS(net::ProxyServer::Scheme); // BitMask. + IPC_ENUM_TRAITS_MAX_VALUE(net::OCSPVerifyResult::ResponseStatus, net::OCSPVerifyResult::PARSE_RESPONSE_DATA_ERROR) IPC_ENUM_TRAITS_MAX_VALUE(net::OCSPRevocationStatus, @@ -216,7 +230,6 @@ IPC_ENUM_TRAITS_MAX_VALUE(net::SSLClientCertType, IPC_ENUM_TRAITS_MAX_VALUE(net::SSLInfo::HandshakeType, net::SSLInfo::HANDSHAKE_FULL) -IPC_ENUM_TRAITS_MAX_VALUE(net::TokenBindingParam, net::TB_PARAM_ECDSAP256) IPC_ENUM_TRAITS_MAX_VALUE(net::URLRequest::ReferrerPolicy, net::URLRequest::MAX_REFERRER_POLICY - 1) @@ -243,7 +256,6 @@ IPC_STRUCT_TRAITS_BEGIN(net::RedirectInfo) IPC_STRUCT_TRAITS_MEMBER(new_referrer) IPC_STRUCT_TRAITS_MEMBER(insecure_scheme_was_upgraded) IPC_STRUCT_TRAITS_MEMBER(new_referrer_policy) - IPC_STRUCT_TRAITS_MEMBER(referred_token_binding_host) IPC_STRUCT_TRAITS_END() IPC_ENUM_TRAITS_MAX_VALUE(net::HttpResponseInfo::ConnectionInfo, diff --git a/chromium/services/network/public/cpp/network_connection_tracker.cc b/chromium/services/network/public/cpp/network_connection_tracker.cc index b80b7682739..8a3fce43fcc 100644 --- a/chromium/services/network/public/cpp/network_connection_tracker.cc +++ b/chromium/services/network/public/cpp/network_connection_tracker.cc @@ -84,6 +84,16 @@ bool NetworkConnectionTracker::GetConnectionType( return false; } +bool NetworkConnectionTracker::IsOffline() { + base::subtle::Atomic32 type_value = + base::subtle::NoBarrier_Load(&connection_type_); + if (type_value != kConnectionTypeInvalid) { + auto type = static_cast<network::mojom::ConnectionType>(type_value); + return type == network::mojom::ConnectionType::CONNECTION_NONE; + } + return true; +} + // static bool NetworkConnectionTracker::IsConnectionCellular( network::mojom::ConnectionType type) { diff --git a/chromium/services/network/public/cpp/network_connection_tracker.h b/chromium/services/network/public/cpp/network_connection_tracker.h index cfd66c0b59d..a370627fc63 100644 --- a/chromium/services/network/public/cpp/network_connection_tracker.h +++ b/chromium/services/network/public/cpp/network_connection_tracker.h @@ -21,6 +21,12 @@ namespace network { +// Defines the type of a callback that will return a NetworkConnectionTracker +// instance. +class NetworkConnectionTracker; +using NetworkConnectionTrackerGetter = + base::RepeatingCallback<NetworkConnectionTracker*()>; + // This class subscribes to network change events from // network::mojom::NetworkChangeManager and propogates these notifications to // its NetworkConnectionObservers registered through @@ -60,6 +66,9 @@ class COMPONENT_EXPORT(NETWORK_CPP) NetworkConnectionTracker virtual bool GetConnectionType(network::mojom::ConnectionType* type, ConnectionTypeCallback callback); + // Returns true if the network is currently in an offline or unknown state. + bool IsOffline(); + // Returns true if |type| is a cellular connection. // Returns false if |type| is CONNECTION_UNKNOWN, and thus, depending on the // implementation of GetConnectionType(), it is possible that diff --git a/chromium/services/network/public/cpp/network_ipc_param_traits.h b/chromium/services/network/public/cpp/network_ipc_param_traits.h index 044480644e9..b1f0d480e6c 100644 --- a/chromium/services/network/public/cpp/network_ipc_param_traits.h +++ b/chromium/services/network/public/cpp/network_ipc_param_traits.h @@ -13,6 +13,7 @@ #include "ipc/param_traits_macros.h" #include "net/base/auth.h" #include "net/base/host_port_pair.h" +#include "net/base/proxy_server.h" #include "net/base/request_priority.h" #include "net/cert/cert_verify_result.h" #include "net/cert/ct_policy_status.h" @@ -111,11 +112,21 @@ IPC_STRUCT_TRAITS_BEGIN(network::CORSErrorStatus) IPC_STRUCT_TRAITS_MEMBER(failed_parameter) IPC_STRUCT_TRAITS_END() +IPC_STRUCT_TRAITS_BEGIN(network::cors::PreflightTimingInfo) + IPC_STRUCT_TRAITS_MEMBER(start_time) + IPC_STRUCT_TRAITS_MEMBER(finish_time) + IPC_STRUCT_TRAITS_MEMBER(alpn_negotiated_protocol) + IPC_STRUCT_TRAITS_MEMBER(connection_info) + IPC_STRUCT_TRAITS_MEMBER(timing_allow_origin) + IPC_STRUCT_TRAITS_MEMBER(transfer_size) +IPC_STRUCT_TRAITS_END() + IPC_STRUCT_TRAITS_BEGIN(network::URLLoaderCompletionStatus) IPC_STRUCT_TRAITS_MEMBER(error_code) IPC_STRUCT_TRAITS_MEMBER(extended_error_code) IPC_STRUCT_TRAITS_MEMBER(exists_in_cache) IPC_STRUCT_TRAITS_MEMBER(completion_time) + IPC_STRUCT_TRAITS_MEMBER(cors_preflight_timing_info) IPC_STRUCT_TRAITS_MEMBER(encoded_data_length) IPC_STRUCT_TRAITS_MEMBER(encoded_body_length) IPC_STRUCT_TRAITS_MEMBER(decoded_body_length) @@ -135,6 +146,7 @@ IPC_STRUCT_TRAITS_BEGIN(network::ResourceRequest) IPC_STRUCT_TRAITS_MEMBER(referrer_policy) IPC_STRUCT_TRAITS_MEMBER(is_prerendering) IPC_STRUCT_TRAITS_MEMBER(headers) + IPC_STRUCT_TRAITS_MEMBER(requested_with) IPC_STRUCT_TRAITS_MEMBER(load_flags) IPC_STRUCT_TRAITS_MEMBER(allow_credentials) IPC_STRUCT_TRAITS_MEMBER(plugin_child_id) @@ -166,7 +178,11 @@ IPC_STRUCT_TRAITS_BEGIN(network::ResourceRequest) IPC_STRUCT_TRAITS_MEMBER(previews_state) IPC_STRUCT_TRAITS_MEMBER(initiated_in_secure_context) IPC_STRUCT_TRAITS_MEMBER(upgrade_if_insecure) + IPC_STRUCT_TRAITS_MEMBER(is_revalidating) IPC_STRUCT_TRAITS_MEMBER(throttling_profile_id) + IPC_STRUCT_TRAITS_MEMBER(custom_proxy_pre_cache_headers) + IPC_STRUCT_TRAITS_MEMBER(custom_proxy_post_cache_headers) + IPC_STRUCT_TRAITS_MEMBER(fetch_window_id) IPC_STRUCT_TRAITS_END() IPC_STRUCT_TRAITS_BEGIN(network::ResourceResponseInfo) @@ -192,6 +208,7 @@ IPC_STRUCT_TRAITS_BEGIN(network::ResourceResponseInfo) IPC_STRUCT_TRAITS_MEMBER(alpn_negotiated_protocol) IPC_STRUCT_TRAITS_MEMBER(socket_address) IPC_STRUCT_TRAITS_MEMBER(was_fetched_via_cache) + IPC_STRUCT_TRAITS_MEMBER(proxy_server) IPC_STRUCT_TRAITS_MEMBER(was_fetched_via_service_worker) IPC_STRUCT_TRAITS_MEMBER(was_fallback_required_by_service_worker) IPC_STRUCT_TRAITS_MEMBER(url_list_via_service_worker) @@ -207,6 +224,7 @@ IPC_STRUCT_TRAITS_BEGIN(network::ResourceResponseInfo) IPC_STRUCT_TRAITS_MEMBER(cors_exposed_header_names) IPC_STRUCT_TRAITS_MEMBER(async_revalidation_requested) IPC_STRUCT_TRAITS_MEMBER(did_mime_sniff) + IPC_STRUCT_TRAITS_MEMBER(is_signed_exchange_inner_response) IPC_STRUCT_TRAITS_END() IPC_ENUM_TRAITS_MAX_VALUE(network::mojom::FetchResponseType, diff --git a/chromium/services/network/public/cpp/network_param.typemap b/chromium/services/network/public/cpp/network_param.typemap index 5a4f5dbb473..e8a71b2f126 100644 --- a/chromium/services/network/public/cpp/network_param.typemap +++ b/chromium/services/network/public/cpp/network_param.typemap @@ -26,6 +26,7 @@ sources = [ deps = [ "//ipc", + "//services/network/public/cpp:cpp_base", ] public_deps = [ diff --git a/chromium/services/network/public/cpp/network_param_android.typemap b/chromium/services/network/public/cpp/network_param_android.typemap new file mode 100644 index 00000000000..6ccef5ba713 --- /dev/null +++ b/chromium/services/network/public/cpp/network_param_android.typemap @@ -0,0 +1,15 @@ +# Copyright 2018 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +mojom = "//services/network/public/mojom/network_param.mojom" +os_whitelist = [ "android" ] +public_headers = [ "//base/android/application_status_listener.h" ] +traits_headers = + [ "//services/network/public/cpp/network_param_mojom_traits.h" ] + +deps = [ + "//base", +] +type_mappings = + [ "network.mojom.ApplicationState=base::android::ApplicationState" ] diff --git a/chromium/services/network/public/cpp/network_param_mojom_traits.cc b/chromium/services/network/public/cpp/network_param_mojom_traits.cc index 3ac8d893279..7b51f550e09 100644 --- a/chromium/services/network/public/cpp/network_param_mojom_traits.cc +++ b/chromium/services/network/public/cpp/network_param_mojom_traits.cc @@ -13,4 +13,53 @@ bool StructTraits<network::mojom::HttpVersionDataView, net::HttpVersion>::Read( return true; } +#if defined(OS_ANDROID) +network::mojom::ApplicationState +EnumTraits<network::mojom::ApplicationState, base::android::ApplicationState>:: + ToMojom(base::android::ApplicationState input) { + switch (input) { + case base::android::APPLICATION_STATE_UNKNOWN: + return network::mojom::ApplicationState::UNKNOWN; + case base::android::APPLICATION_STATE_HAS_RUNNING_ACTIVITIES: + return network::mojom::ApplicationState::HAS_RUNNING_ACTIVITIES; + case base::android::APPLICATION_STATE_HAS_PAUSED_ACTIVITIES: + return network::mojom::ApplicationState::HAS_PAUSED_ACTIVITIES; + case base::android::APPLICATION_STATE_HAS_STOPPED_ACTIVITIES: + return network::mojom::ApplicationState::HAS_STOPPED_ACTIVITIES; + case base::android::APPLICATION_STATE_HAS_DESTROYED_ACTIVITIES: + return network::mojom::ApplicationState::HAS_DESTROYED_ACTIVITIES; + } + NOTREACHED(); + return static_cast<network::mojom::ApplicationState>(input); +} + +bool EnumTraits<network::mojom::ApplicationState, + base::android::ApplicationState>:: + FromMojom(network::mojom::ApplicationState input, + base::android::ApplicationState* output) { + switch (input) { + case network::mojom::ApplicationState::UNKNOWN: + *output = base::android::ApplicationState::APPLICATION_STATE_UNKNOWN; + return true; + case network::mojom::ApplicationState::HAS_RUNNING_ACTIVITIES: + *output = base::android::ApplicationState:: + APPLICATION_STATE_HAS_RUNNING_ACTIVITIES; + return true; + case network::mojom::ApplicationState::HAS_PAUSED_ACTIVITIES: + *output = base::android::ApplicationState:: + APPLICATION_STATE_HAS_PAUSED_ACTIVITIES; + return true; + case network::mojom::ApplicationState::HAS_STOPPED_ACTIVITIES: + *output = base::android::ApplicationState:: + APPLICATION_STATE_HAS_STOPPED_ACTIVITIES; + return true; + case network::mojom::ApplicationState::HAS_DESTROYED_ACTIVITIES: + *output = base::android::ApplicationState:: + APPLICATION_STATE_HAS_DESTROYED_ACTIVITIES; + return true; + } + return false; +} +#endif + } // namespace mojo diff --git a/chromium/services/network/public/cpp/network_param_mojom_traits.h b/chromium/services/network/public/cpp/network_param_mojom_traits.h index 81e3d8223a5..2840fe9743a 100644 --- a/chromium/services/network/public/cpp/network_param_mojom_traits.h +++ b/chromium/services/network/public/cpp/network_param_mojom_traits.h @@ -5,10 +5,15 @@ #ifndef SERVICES_NETWORK_PUBLIC_CPP_NETWORK_PARAM_MOJOM_TRAITS_H_ #define SERVICES_NETWORK_PUBLIC_CPP_NETWORK_PARAM_MOJOM_TRAITS_H_ +#include "build/build_config.h" #include "mojo/public/cpp/bindings/struct_traits.h" #include "net/http/http_version.h" #include "services/network/public/mojom/network_param.mojom.h" +#if defined(OS_ANDROID) +#include "base/android/application_status_listener.h" +#endif + namespace mojo { template <> @@ -25,6 +30,17 @@ class StructTraits<network::mojom::HttpVersionDataView, net::HttpVersion> { net::HttpVersion* out); }; +#if defined(OS_ANDROID) +template <> +struct EnumTraits<network::mojom::ApplicationState, + base::android::ApplicationState> { + static network::mojom::ApplicationState ToMojom( + base::android::ApplicationState input); + static bool FromMojom(network::mojom::ApplicationState input, + base::android::ApplicationState* output); +}; +#endif + } // namespace mojo #endif // SERVICES_NETWORK_PUBLIC_CPP_NETWORK_PARAM_MOJOM_TRAITS_H_ diff --git a/chromium/services/network/public/cpp/network_quality_tracker.cc b/chromium/services/network/public/cpp/network_quality_tracker.cc index ef6d27f479d..77a08c9a75a 100644 --- a/chromium/services/network/public/cpp/network_quality_tracker.cc +++ b/chromium/services/network/public/cpp/network_quality_tracker.cc @@ -16,6 +16,7 @@ NetworkQualityTracker::NetworkQualityTracker( : get_network_service_callback_(callback), effective_connection_type_(net::EFFECTIVE_CONNECTION_TYPE_UNKNOWN), downlink_bandwidth_kbps_(std::numeric_limits<int32_t>::max()), + network_quality_overridden_for_testing_(false), binding_(this) { InitializeMojoChannel(); DCHECK(binding_.is_bound()); @@ -76,9 +77,12 @@ void NetworkQualityTracker::ReportEffectiveConnectionTypeForTesting( net::EffectiveConnectionType effective_connection_type) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + network_quality_overridden_for_testing_ = true; + effective_connection_type_ = effective_connection_type; - for (auto& observer : effective_connection_type_observer_list_) + for (auto& observer : effective_connection_type_observer_list_) { observer.OnEffectiveConnectionTypeChanged(effective_connection_type); + } } void NetworkQualityTracker::ReportRTTsAndThroughputForTesting( @@ -86,6 +90,8 @@ void NetworkQualityTracker::ReportRTTsAndThroughputForTesting( int32_t downstream_throughput_kbps) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + network_quality_overridden_for_testing_ = true; + http_rtt_ = http_rtt; downlink_bandwidth_kbps_ = downstream_throughput_kbps; @@ -108,6 +114,9 @@ void NetworkQualityTracker::OnNetworkQualityChanged( int32_t bandwidth_kbps) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (network_quality_overridden_for_testing_) + return; + // If the RTT values are unavailable, set them to value 0. if (http_rtt < base::TimeDelta()) http_rtt = base::TimeDelta(); diff --git a/chromium/services/network/public/cpp/network_quality_tracker.h b/chromium/services/network/public/cpp/network_quality_tracker.h index 9102487f4d3..092064c7ec7 100644 --- a/chromium/services/network/public/cpp/network_quality_tracker.h +++ b/chromium/services/network/public/cpp/network_quality_tracker.h @@ -128,12 +128,14 @@ class COMPONENT_EXPORT(NETWORK_CPP) NetworkQualityTracker // Changes effective connection type estimate to the provided value, and // reports |effective_connection_type| to all - // EffectiveConnectionTypeObservers. + // EffectiveConnectionTypeObservers. Calling this also disables all organic + // notifications sent to observers. void ReportEffectiveConnectionTypeForTesting( net::EffectiveConnectionType effective_connection_type); // Changes RTT and throughput estimate to the provided estimates, and - // reports it to all RTTAndThroughputEstimatesObservers. + // reports it to all RTTAndThroughputEstimatesObservers. Calling this also + // disables all organic notifications sent to observers. void ReportRTTsAndThroughputForTesting(base::TimeDelta http_rtt, int32_t downstream_throughput_kbps); @@ -167,6 +169,10 @@ class COMPONENT_EXPORT(NETWORK_CPP) NetworkQualityTracker base::TimeDelta transport_rtt_; int32_t downlink_bandwidth_kbps_; + // True if network quality has been overridden by tests. If set to true, it + // disables all organic notifications sent to observers. + bool network_quality_overridden_for_testing_; + base::ObserverList<EffectiveConnectionTypeObserver>::Unchecked effective_connection_type_observer_list_; diff --git a/chromium/services/network/public/cpp/network_switches.cc b/chromium/services/network/public/cpp/network_switches.cc index 7790a6fa6e7..a06bd76f3c3 100644 --- a/chromium/services/network/public/cpp/network_switches.cc +++ b/chromium/services/network/public/cpp/network_switches.cc @@ -34,6 +34,12 @@ const char kIgnoreCertificateErrorsSPKIList[] = // user data directory. const char kLogNetLog[] = "log-net-log"; +// Causes SSL key material to be logged to the specified file for debugging +// purposes. See +// https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format +// for the format. +const char kSSLKeyLogFile[] = "ssl-key-log-file"; + // Don't send HTTP-Referer headers. const char kNoReferrers[] = "no-referrers"; diff --git a/chromium/services/network/public/cpp/network_switches.h b/chromium/services/network/public/cpp/network_switches.h index 3829a2d4360..ec23a353fa2 100644 --- a/chromium/services/network/public/cpp/network_switches.h +++ b/chromium/services/network/public/cpp/network_switches.h @@ -17,6 +17,7 @@ extern const char kHostResolverRules[]; COMPONENT_EXPORT(NETWORK_CPP) extern const char kIgnoreCertificateErrorsSPKIList[]; COMPONENT_EXPORT(NETWORK_CPP) extern const char kLogNetLog[]; +COMPONENT_EXPORT(NETWORK_CPP) extern const char kSSLKeyLogFile[]; COMPONENT_EXPORT(NETWORK_CPP) extern const char kNoReferrers[]; } // namespace switches diff --git a/chromium/services/network/public/cpp/proxy_config.typemap b/chromium/services/network/public/cpp/proxy_config.typemap index 8f3f5043b5a..4dce13dd903 100644 --- a/chromium/services/network/public/cpp/proxy_config.typemap +++ b/chromium/services/network/public/cpp/proxy_config.typemap @@ -17,7 +17,7 @@ type_mappings = [ "network.mojom.ProxyBypassRules=net::ProxyBypassRules", "network.mojom.ProxyList=net::ProxyList", "network.mojom.ProxyRulesType=net::ProxyConfig::ProxyRules::Type", - "network.mojom.ProxyRule=net::ProxyConfig::ProxyRule", + "network.mojom.ProxyRules=net::ProxyConfig::ProxyRules", "network.mojom.ProxyConfig=net::ProxyConfig", "network.mojom.ProxyConfigWithAnnotation=net::ProxyConfigWithAnnotation", ] diff --git a/chromium/services/network/public/cpp/resource_request.h b/chromium/services/network/public/cpp/resource_request.h index 506d4b42bdb..8481b5d6faa 100644 --- a/chromium/services/network/public/cpp/resource_request.h +++ b/chromium/services/network/public/cpp/resource_request.h @@ -70,6 +70,13 @@ struct COMPONENT_EXPORT(NETWORK_CPP_BASE) ResourceRequest { // Additional HTTP request headers. net::HttpRequestHeaders headers; + // 'X-Requested-With' header value. Some consumers want to set this header, + // but such internal headers must be ignored by CORS checks (which run inside + // Network Service), so the value is stored here (rather than in |headers|) + // and later populated in the headers after CORS check. + // TODO(toyoshim): Remove it once PPAPI is deprecated. + std::string requested_with; + // net::URLRequest load flags (0 by default). int load_flags = 0; @@ -217,8 +224,38 @@ struct COMPONENT_EXPORT(NETWORK_CPP_BASE) ResourceRequest { // HTTPS due to an Upgrade-Insecure-Requests requirement. bool upgrade_if_insecure = false; + // True when the request is revalidating. + // Some users, notably blink, has its own cache. This flag is set to exempt + // some CORS logic for a revalidating request. + bool is_revalidating = false; + // The profile ID of network conditions to throttle the network request. base::Optional<base::UnguessableToken> throttling_profile_id; + + // Headers that will be added pre and post cache if the network context uses + // the custom proxy for this request. The custom proxy is used for requests + // that match the custom proxy config, and would otherwise be made direct. + net::HttpRequestHeaders custom_proxy_pre_cache_headers; + net::HttpRequestHeaders custom_proxy_post_cache_headers; + + // See https://fetch.spec.whatwg.org/#concept-request-window + // + // This is an opaque id of the original requestor of the resource, which might + // be different to the current requestor which is |render_frame_id|. For + // example, if a navigation for window "abc" is intercepted by a service + // worker, which re-issues the request via fetch, the re-issued request has + // |render_frame_id| of MSG_ROUTING_NONE (the service worker) and |window_id| + // of "abc". This is used for, e.g., client certificate selection. It's + // important that this id be unguessable so renderers cannot impersonate + // other renderers. + // + // This may be empty when the original requestor is the current requestor or + // is not a window. When it's empty, use |render_frame_id| instead. In + // practical terms, it's empty for requests that didn't go through a service + // worker, or if the original requestor is not a window. When the request + // goes through a service worker, the id is + // ServiceWorkerProviderHost::fetch_request_window_id. + base::Optional<base::UnguessableToken> fetch_window_id; }; } // namespace network diff --git a/chromium/services/network/public/cpp/resource_response.cc b/chromium/services/network/public/cpp/resource_response.cc index d00c077a134..d3c8d8ba7fa 100644 --- a/chromium/services/network/public/cpp/resource_response.cc +++ b/chromium/services/network/public/cpp/resource_response.cc @@ -61,6 +61,9 @@ scoped_refptr<ResourceResponse> ResourceResponse::DeepCopy() const { new_response->head.async_revalidation_requested = head.async_revalidation_requested; new_response->head.did_mime_sniff = head.did_mime_sniff; + new_response->head.is_signed_exchange_inner_response = + head.is_signed_exchange_inner_response; + new_response->head.intercepted_by_plugin = head.intercepted_by_plugin; return new_response; } diff --git a/chromium/services/network/public/cpp/resource_response_info.h b/chromium/services/network/public/cpp/resource_response_info.h index 5b12bf3ddcd..5a47c03bac7 100644 --- a/chromium/services/network/public/cpp/resource_response_info.h +++ b/chromium/services/network/public/cpp/resource_response_info.h @@ -15,6 +15,7 @@ #include "base/time/time.h" #include "net/base/host_port_pair.h" #include "net/base/load_timing_info.h" +#include "net/base/proxy_server.h" #include "net/cert/ct_policy_status.h" #include "net/cert/signed_certificate_timestamp_and_status.h" #include "net/http/http_response_headers.h" @@ -114,6 +115,9 @@ struct COMPONENT_EXPORT(NETWORK_CPP_BASE) ResourceResponseInfo { // True if the response was delivered through a proxy. bool was_fetched_via_proxy; + // The proxy server used for this request, if any. + net::ProxyServer proxy_server; + // True if the response was fetched by a ServiceWorker. bool was_fetched_via_service_worker; @@ -185,6 +189,12 @@ struct COMPONENT_EXPORT(NETWORK_CPP_BASE) ResourceResponseInfo { // mime sniffing anymore. bool did_mime_sniff; + // True if the response is an inner response of a signed exchange. + bool is_signed_exchange_inner_response = false; + + // True if the response was intercepted by a plugin. + bool intercepted_by_plugin = false; + // NOTE: When adding or changing fields here, also update // ResourceResponse::DeepCopy in resource_response.cc. }; diff --git a/chromium/services/network/public/cpp/server/http_server_unittest.cc b/chromium/services/network/public/cpp/server/http_server_unittest.cc index 273aa921eee..23c06a31d59 100644 --- a/chromium/services/network/public/cpp/server/http_server_unittest.cc +++ b/chromium/services/network/public/cpp/server/http_server_unittest.cc @@ -49,9 +49,10 @@ class TestHttpClient { base::RunLoop run_loop; int net_error = net::ERR_FAILED; factory_.CreateTCPConnectedSocket( - base::nullopt, /* local address */ - addresses, TRAFFIC_ANNOTATION_FOR_TESTS, mojo::MakeRequest(&socket_), - nullptr, /* observer */ + base::nullopt /* local address */, addresses, + nullptr /* tcp_connected_socket_options */, + TRAFFIC_ANNOTATION_FOR_TESTS, mojo::MakeRequest(&socket_), + nullptr /* observer */, base::BindOnce( [](base::RunLoop* run_loop, int* result_out, mojo::ScopedDataPipeConsumerHandle* receive_pipe_handle_out, diff --git a/chromium/services/network/public/cpp/typemaps.gni b/chromium/services/network/public/cpp/typemaps.gni index be61d9491b7..3bdbeed5221 100644 --- a/chromium/services/network/public/cpp/typemaps.gni +++ b/chromium/services/network/public/cpp/typemaps.gni @@ -11,6 +11,7 @@ typemaps = [ "//services/network/public/cpp/mutable_network_traffic_annotation_tag.typemap", "//services/network/public/cpp/mutable_partial_network_traffic_annotation_tag.typemap", "//services/network/public/cpp/network_param.typemap", + "//services/network/public/cpp/network_param_android.typemap", "//services/network/public/cpp/network_types.typemap", "//services/network/public/cpp/p2p.typemap", "//services/network/public/cpp/proxy_config.typemap", diff --git a/chromium/services/network/public/cpp/url_loader_completion_status.cc b/chromium/services/network/public/cpp/url_loader_completion_status.cc index 32004ff62ec..4a6d60dbe62 100644 --- a/chromium/services/network/public/cpp/url_loader_completion_status.cc +++ b/chromium/services/network/public/cpp/url_loader_completion_status.cc @@ -29,6 +29,7 @@ bool URLLoaderCompletionStatus::operator==( extended_error_code == rhs.extended_error_code && exists_in_cache == rhs.exists_in_cache && completion_time == rhs.completion_time && + cors_preflight_timing_info == rhs.cors_preflight_timing_info && encoded_data_length == rhs.encoded_data_length && encoded_body_length == rhs.encoded_body_length && decoded_body_length == rhs.decoded_body_length && diff --git a/chromium/services/network/public/cpp/url_loader_completion_status.h b/chromium/services/network/public/cpp/url_loader_completion_status.h index 3897f291bb3..6e243cb7f09 100644 --- a/chromium/services/network/public/cpp/url_loader_completion_status.h +++ b/chromium/services/network/public/cpp/url_loader_completion_status.h @@ -13,6 +13,7 @@ #include "base/time/time.h" #include "net/ssl/ssl_info.h" #include "services/network/public/cpp/cors/cors_error_status.h" +#include "services/network/public/cpp/cors/preflight_timing_info.h" #include "services/network/public/mojom/cors.mojom-shared.h" namespace network { @@ -48,6 +49,9 @@ struct COMPONENT_EXPORT(NETWORK_CPP_BASE) URLLoaderCompletionStatus { // Time the request completed. base::TimeTicks completion_time; + // Timing info if CORS preflights were made. + std::vector<cors::PreflightTimingInfo> cors_preflight_timing_info; + // Total amount of data received from the network. int64_t encoded_data_length = 0; diff --git a/chromium/services/network/public/mojom/BUILD.gn b/chromium/services/network/public/mojom/BUILD.gn index c6334f5217b..268ccc892eb 100644 --- a/chromium/services/network/public/mojom/BUILD.gn +++ b/chromium/services/network/public/mojom/BUILD.gn @@ -82,6 +82,7 @@ mojom("mojom") { "fetch_api.mojom", "host_resolver.mojom", "http_request_headers.mojom", + "net_log.mojom", "network_change_manager.mojom", "network_context.mojom", "network_quality_estimator_manager.mojom", diff --git a/chromium/services/network/public/mojom/cookie_manager.mojom b/chromium/services/network/public/mojom/cookie_manager.mojom index 95238f047af..d991ce80c6f 100644 --- a/chromium/services/network/public/mojom/cookie_manager.mojom +++ b/chromium/services/network/public/mojom/cookie_manager.mojom @@ -15,6 +15,15 @@ struct CookieManagerParams { // Content settings for cookies. array<content_settings.mojom.ContentSettingPatternSource> settings; + + // Schemes that unconditionally allow cookies from secure origins. + array<string> secure_origin_cookies_allowed_schemes; + + // Schemes that unconditionally allow cookies from the same scheme. + array<string> matching_scheme_cookies_allowed_schemes; + + // Schemes that unconditionally allow third party cookies. + array<string> third_party_cookies_allowed_schemes; }; enum CookiePriority { diff --git a/chromium/services/network/public/mojom/cors.mojom b/chromium/services/network/public/mojom/cors.mojom index 0e436e27826..5199351f84e 100644 --- a/chromium/services/network/public/mojom/cors.mojom +++ b/chromium/services/network/public/mojom/cors.mojom @@ -95,3 +95,13 @@ enum CORSError { // Cross origin redirect location contains credentials such as 'user:pass'. kRedirectContainsCredentials, }; + +// Determine which Cors exception takes precedence when multiple matches occur. +enum CORSOriginAccessMatchPriority { + kNoMatchingOrigin, + kDefaultPriority, + kLowPriority, + kMediumPriority, + kHighPriority, + kMaxPriority +}; diff --git a/chromium/services/network/public/mojom/cors_origin_pattern.mojom b/chromium/services/network/public/mojom/cors_origin_pattern.mojom index 182abc8d137..a00e97cffc8 100644 --- a/chromium/services/network/public/mojom/cors_origin_pattern.mojom +++ b/chromium/services/network/public/mojom/cors_origin_pattern.mojom @@ -4,6 +4,8 @@ module network.mojom; +import "services/network/public/mojom/cors.mojom"; + // Parameters for representing a access origin whitelist or blacklist for CORS. struct CorsOriginPattern { // The protocol part of the destination URL. @@ -14,4 +16,10 @@ struct CorsOriginPattern { // Whether subdomains match this protocol and host pattern. bool allow_subdomains; + + // Order of preference in which the pattern is applied. Higher priority + // patterns take precedence over lower ones. In the case were both a + // allow list and block list rule of the same priority match a request, + // the block list rule takes priority. + CORSOriginAccessMatchPriority priority; }; diff --git a/chromium/services/network/public/mojom/host_resolver.mojom b/chromium/services/network/public/mojom/host_resolver.mojom index 1bbd82f46a4..28dda6efe8b 100644 --- a/chromium/services/network/public/mojom/host_resolver.mojom +++ b/chromium/services/network/public/mojom/host_resolver.mojom @@ -33,7 +33,7 @@ interface ResolveHostClient { }; // Parameter-grouping struct for additional optional parameters for -// HostResolver::CreateRequest() calls. All fields are optional and have a +// HostResolver::ResolveHost() calls. All fields are optional and have a // reasonable default. struct ResolveHostParameters { // DNS query type for a ResolveHostRequest. @@ -52,6 +52,33 @@ struct ResolveHostParameters { // The initial net priority for the host resolution request. RequestPriority initial_priority = RequestPriority.kLowest; + // Enumeration to specify the allowed results source for requests. + enum Source { + // Resolver will pick an appropriate source. Results could come from DNS, + // MulticastDNS, HOSTS file, etc. + ANY, + + // Results will only be retrieved from the system or OS, e.g. via the + // getaddrinfo() system call. + SYSTEM, + + // Results will only come from DNS queries. + DNS, + + // Results will only come from Multicast DNS queries. + MULTICAST_DNS, + }; + + // The source to use for resolved addresses. Default allows the resolver to + // pick an appropriate source. Only affects use of big external sources (eg + // calling the system for resolution or using DNS). Even if a source is + // specified, results can still come from cache, resolving "localhost" or IP + // literals, etc. + Source source = Source.ANY; + + // If |false|, results will not come from the host cache. + bool allow_cached_response = true; + // If set, the outstanding request can be controlled, eg cancelled, via the // handle. ResolveHostHandle&? control_handle; diff --git a/chromium/services/network/public/mojom/net_log.mojom b/chromium/services/network/public/mojom/net_log.mojom new file mode 100644 index 00000000000..8ccafc6dbfd --- /dev/null +++ b/chromium/services/network/public/mojom/net_log.mojom @@ -0,0 +1,53 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +module network.mojom; + +import "mojo/public/mojom/base/file.mojom"; +import "mojo/public/mojom/base/values.mojom"; + +enum NetLogCaptureMode { + // Log all events, but do not include the actual transferred bytes, and + // remove cookies and HTTP credentials and HTTP/2 GOAWAY frame debug data. + DEFAULT, + + // Log all events, but do not include the actual transferred bytes as + // parameters for bytes sent/received events. + INCLUDE_COOKIES_AND_CREDENTIALS, + + // Log everything possible, even if it is slow and memory expensive. + // Includes logging of transferred bytes. + INCLUDE_SOCKET_BYTES +}; + +// Manages export of ongoing NetLog events to a file. +// Both Start and Stop must succeed, in that order, for the export to +// be complete and have a well-formed file. You may call Start again after +// Stop's callback has been invoked, but doing things like calling Start twice +// without intervening successful stops will result in an error. +interface NetLogExporter { + const uint64 kUnlimitedFileSize = 0xFFFFFFFFFFFFFFFF; + + // Starts logging to |destination|, including definitions of |extra_constants| + // in the log in addition to the standard constants required by the log. + // Contents in |destination| might not be complete until Stop() is called + // successfully. + // + // If |max_file_size| is kUnlimitedFileSize log size will not be limited. + // + // Returns network error code. + Start( + mojo_base.mojom.File destination, + mojo_base.mojom.DictionaryValue extra_constants, + NetLogCaptureMode capture_mode, + uint64 max_file_size) => (int32 net_error); + + // Finalizes the log file. If |polled_values| is provided, it will be + // included alongside net configuration info inside the 'polledData' field + // of the log object. + // + // Returns network error code; if successful this will occur only after + // the file has been fully written. + Stop(mojo_base.mojom.DictionaryValue polled_values) => (int32 net_error); +}; diff --git a/chromium/services/network/public/mojom/network_context.mojom b/chromium/services/network/public/mojom/network_context.mojom index d5707f963c0..f5e6b59d9a3 100644 --- a/chromium/services/network/public/mojom/network_context.mojom +++ b/chromium/services/network/public/mojom/network_context.mojom @@ -4,7 +4,6 @@ module network.mojom; -import "mojo/public/mojom/base/file.mojom"; import "mojo/public/mojom/base/file_path.mojom"; import "mojo/public/mojom/base/time.mojom"; import "mojo/public/mojom/base/unguessable_token.mojom"; @@ -12,9 +11,12 @@ import "mojo/public/mojom/base/values.mojom"; import "net/interfaces/address_list.mojom"; import "net/interfaces/ip_endpoint.mojom"; import "services/network/public/mojom/cookie_manager.mojom"; +import "services/network/public/mojom/cors_origin_pattern.mojom"; import "services/network/public/mojom/ct_log_info.mojom"; import "services/network/public/mojom/host_resolver.mojom"; +import "services/network/public/mojom/http_request_headers.mojom"; import "services/network/public/mojom/mutable_network_traffic_annotation_tag.mojom"; +import "services/network/public/mojom/net_log.mojom"; import "services/network/public/mojom/network_param.mojom"; import "services/network/public/mojom/p2p.mojom"; import "services/network/public/mojom/p2p_trusted.mojom"; @@ -33,6 +35,40 @@ import "services/proxy_resolver/public/mojom/proxy_resolver.mojom"; import "url/mojom/origin.mojom"; import "url/mojom/url.mojom"; +// Config for setting a custom proxy config that will be used if a request +// matches the proxy rules and would otherwise be direct. This config allows +// headers to be set on requests to the proxies from the config before and/or +// after the caching layer. Currently only supports proxying http requests. +struct CustomProxyConfig { + // The custom proxy rules to use. Right now this is limited to proxies for + // http requests. + ProxyRules rules; + + // The custom proxy can set these headers in this config which will be added + // to all requests using the proxy. This allows setting headers that may be + // privacy/security sensitive which we don't want to send to the renderer. + // Headers that require per-request logic can be added through the + // |custom_proxy_pre_cache_headers| and |custom_proxy_post_cache_headers| + // fields in ResourceRequest. + // + // Headers that will be set before the cache for http requests. If the request + // does not use a custom proxy, these headers will be removed before sending + // to the network. If a request already has one of these headers set, it may + // be overwritten if a custom proxy is used, or removed if a custom proxy is + // not used. + HttpRequestHeaders pre_cache_headers; + + // Headers that will be set after the cache for requests that are issued + // through a custom proxy. Headers here will overwrite matching headers on the + // request if a custom proxy is used. + HttpRequestHeaders post_cache_headers; +}; + +// Client to update the custom proxy config. +interface CustomProxyConfigClient { + OnCustomProxyConfigUpdated(CustomProxyConfig proxy_config); +}; + // Parameters for constructing a network context. struct NetworkContextParams { // Name used by memory tools to identify the context. @@ -140,6 +176,11 @@ struct NetworkContextParams { ProxyConfigWithAnnotation? initial_proxy_config; ProxyConfigClient&? proxy_config_client_request; + // If |custom_proxy_config_client_request| is set, this context will listen + // for updates to the custom proxy config, and use it if applicable for + // requests which would otherwise be made direct. + CustomProxyConfigClient&? custom_proxy_config_client_request; + // If |proxy_config_client_request| is non-null, this is called during // periods of network activity, and can be used as a signal for polling-based // logic to determine the proxy config. @@ -149,6 +190,10 @@ struct NetworkContextParams { // ProxyConfigServices be modified not to need this notification? ProxyConfigPollerClient? proxy_config_poller_client; + // Optional client that will be notified of errors related to the proxy + // settings. + ProxyErrorClient? proxy_error_client; + // When PAC quick checking is enabled, DNS lookups for PAC script's host are // timed out aggressively. This prevents hanging all network request on DNS // lookups that are slow or are blockholed, at the cost of making it more @@ -203,10 +248,12 @@ struct NetworkConditions { // response received. mojo_base.mojom.TimeDelta latency; - // Maximal aggregated download throughput (bytes/sec). 0 disables download throttling. + // Maximal aggregated download throughput (bytes/sec). 0 disables download + // throttling. double download_throughput; - // Maximal aggregated upload throughput (bytes/sec). 0 disables upload throttling. + // Maximal aggregated upload throughput (bytes/sec). 0 disables upload + // throttling. double upload_throughput; }; @@ -245,52 +292,6 @@ struct NetworkUsage { int64 total_bytes_sent; }; -// Manages export of ongoing NetLog events to a file. -// Both Start and Stop must succeed, in that order, for the export to -// be complete and have a well-formed file. You may call Start again after -// Stop's callback has been invoked, but doing things like calling Start twice -// without intervening successful stops will result in an error. -interface NetLogExporter { - - enum CaptureMode { - // Log all events, but do not include the actual transferred bytes, and - // remove cookies and HTTP credentials and HTTP/2 GOAWAY frame debug data. - DEFAULT, - - // Log all events, but do not include the actual transferred bytes as - // parameters for bytes sent/received events. - INCLUDE_COOKIES_AND_CREDENTIALS, - - // Log everything possible, even if it is slow and memory expensive. - // Includes logging of transferred bytes. - INCLUDE_SOCKET_BYTES - }; - - const uint64 kUnlimitedFileSize = 0xFFFFFFFFFFFFFFFF; - - // Starts logging to |destination|, including definitions of |extra_constants| - // in the log in addition to the standard constants required by the log. - // Contents in |destination| might not be complete until Stop() is called - // successfully. - // - // If |max_file_size| is kUnlimitedFileSize log size will not be limited. - // - // Returns network error code. - Start( - mojo_base.mojom.File destination, - mojo_base.mojom.DictionaryValue extra_constants, - CaptureMode capture_mode, - uint64 max_file_size) => (int32 net_error); - - // Finalizes the log file. If |polled_values| is provided, it will be - // included alongside net configuration info inside the 'polledData' field - // of the log object. - // - // Returns network error code; if successful this will occur only after - // the file has been fully written. - Stop(mojo_base.mojom.DictionaryValue polled_values) => (int32 net_error); -}; - const uint32 kBrowserProcessId = 0; const uint32 kInvalidProcessId = 0xffffffff; @@ -440,6 +441,10 @@ interface NetworkContext { // Closes all open connections within this context. CloseAllConnections() => (); + // Close all idle connections for the HTTP network session used by + // this context. + CloseIdleConnections() => (); + // Configures network conditions for the specified throttling profile. // The throttling will be applied only to requests that have matching // throttling_profile_id. @@ -462,6 +467,18 @@ interface NetworkContext { array<string> excluded_spkis, array<string> excluded_legacy_spkis); + // Adds explicitly-specified data as if it was processed from an Expect-CT + // header. + AddExpectCT(string host, mojo_base.mojom.Time expiry, + bool enforce, url.mojom.Url report_uri) => (bool success); + + // Send a test CT report with dummy data for test purposes. + SetExpectCTTestReport(url.mojom.Url report_uri) => (bool success); + + // Retrieves the expect CT state from the associated network context + // transport security state. + GetExpectCTState(string domain) => (mojo_base.mojom.DictionaryValue state); + // Creates a UDP socket. Caller can supply a |receiver| interface pointer // to listen for incoming datagrams. A null |receiver| is acceptable if caller // is not interested in incoming data. @@ -505,14 +522,32 @@ interface NetworkContext { CreateTCPConnectedSocket( net.interfaces.IPEndPoint? local_addr, net.interfaces.AddressList remote_addr_list, + TCPConnectedSocketOptions? tcp_connected_socket_options, MutableNetworkTrafficAnnotationTag traffic_annotation, TCPConnectedSocket& socket, SocketObserver? observer) - => (int32 result, - net.interfaces.IPEndPoint? local_addr, - net.interfaces.IPEndPoint? peer_addr, - handle<data_pipe_consumer>? receive_stream, - handle<data_pipe_producer>? send_stream); + => (int32 result, + net.interfaces.IPEndPoint? local_addr, + net.interfaces.IPEndPoint? peer_addr, + handle<data_pipe_consumer>? receive_stream, + handle<data_pipe_producer>? send_stream); + + // Creates a TCPSocket bound to |local_addr|. The socket created can only be + // used for the purpose specified in |traffic_annotation|, and cannot be + // re-used for other purposes. |local_addr| is treated the same as in + // CreateTCPServerSocket(). + // + // On success, the resulting local address will be written to |local_addr_out| + // and |result| is net::OK. On failure, |result| is a network error code. + // + // It's recommended consumers use CreateTCPServerSocket() or + // CreateTCPConnectedSocket(). This method is just provided so legacy + // consumers can mimic Berkeley sockets semantics. + CreateTCPBoundSocket(net.interfaces.IPEndPoint local_addr, + MutableNetworkTrafficAnnotationTag traffic_annotation, + TCPBoundSocket& socket) + => (int32 result, + net.interfaces.IPEndPoint? local_addr); // Creates a ProxyResolvingSocketFactory that shares some configuration params // with this NetworkContext, but uses separate socket pools. @@ -532,6 +567,12 @@ interface NetworkContext { LookUpProxyForURL(url.mojom.Url url, ProxyLookupClient proxy_lookup_client); + // Forces refetching the proxy configuration, and applying it. + ForceReloadProxyConfig() => (); + + // Clears the list of bad proxy servers that has been cached. + ClearBadProxiesCache() => (); + // Create a NetLogExporter, which helps export NetLog to an existing file. // Note that the log is generally global, including all NetworkContexts // managed by the same NetworkService. The particular NetworkContext this is @@ -606,12 +647,29 @@ interface NetworkContext { // accessed via HTTPS. IsHSTSActiveForHost(string host) => (bool result); - [Sync] + // Sets allowed and blocked origins respectively for the URLLoaderFactory + // consumers to access beyond the same-origin policy. The list is managed per + // each |source_origin|, and each call will flash old set lists for the + // |source_origin|. The passed |patterns| will be set instead. If an empty + // array is given for |allow_patterns| and/or |block_patterns|, the + // |source_origin|'s origin list for each is set to empty respectively. + SetCorsOriginAccessListsForOrigin( + url.mojom.Origin source_origin, array<CorsOriginPattern> allow_patterns, + array<CorsOriginPattern> block_patterns) => (); + // Adds explicitly-specified data as if it was processed from an - // HSTS header. - AddHSTSForTesting(string host, - mojo_base.mojom.Time expiry, - bool include_subdomains) => (); + // HSTS header. Used by tests and implementation of chrome://net-internals. + AddHSTS(string host, mojo_base.mojom.Time expiry, + bool include_subdomains) => (); + + // Retrieve values from the HSTS state from the associated contexts + // transport security state. + GetHSTSState(string domain) => (mojo_base.mojom.DictionaryValue state); + + // Deletes any dynamic data stored for |host| from the transport + // security state. Returns true iff an entry was deleted. + // See net::TransportSecurityState::DeleteDynamicDataForHost for more detail. + DeleteDynamicDataForHost(string host) => (bool result); [Sync] // Will force the transaction to fail with the given error code. diff --git a/chromium/services/network/public/mojom/network_param.mojom b/chromium/services/network/public/mojom/network_param.mojom index 41f31038dfe..e8b54d75eea 100644 --- a/chromium/services/network/public/mojom/network_param.mojom +++ b/chromium/services/network/public/mojom/network_param.mojom @@ -4,6 +4,16 @@ module network.mojom; +// Mirror of base::android::ApplicationState. +[EnableIf=is_android] +enum ApplicationState { + UNKNOWN, + HAS_RUNNING_ACTIVITIES, + HAS_PAUSED_ACTIVITIES, + HAS_STOPPED_ACTIVITIES, + HAS_DESTROYED_ACTIVITIES, +}; + [Native] struct AuthChallengeInfo; diff --git a/chromium/services/network/public/mojom/network_service.mojom b/chromium/services/network/public/mojom/network_service.mojom index 2dda78fe274..48be6a73dd4 100644 --- a/chromium/services/network/public/mojom/network_service.mojom +++ b/chromium/services/network/public/mojom/network_service.mojom @@ -8,8 +8,10 @@ import "mojo/public/mojom/base/file.mojom"; import "mojo/public/mojom/base/file_path.mojom"; import "mojo/public/mojom/base/read_only_buffer.mojom"; import "mojo/public/mojom/base/string16.mojom"; +import "mojo/public/mojom/base/unguessable_token.mojom"; import "mojo/public/mojom/base/values.mojom"; import "services/network/public/mojom/cookie_manager.mojom"; +import "services/network/public/mojom/net_log.mojom"; import "services/network/public/mojom/network_change_manager.mojom"; import "services/network/public/mojom/network_context.mojom"; import "services/network/public/mojom/network_param.mojom"; @@ -80,7 +82,12 @@ interface NetworkServiceClient { // called. // 2. The request is aborted, net::URLRequest::CancelWithError() needs to be // called. - OnCertificateRequested(uint32 process_id, + // + // |window_id| or else |process_id| and |routing_id| indicates + // the frame making the request, see + // network::ResourceRequest::fetch_window_id. + OnCertificateRequested(mojo_base.mojom.UnguessableToken? window_id, + uint32 process_id, uint32 routing_id, uint32 request_id, network.mojom.SSLCertRequestInfo cert_info) => ( @@ -222,8 +229,16 @@ interface NetworkService { // for more details). Most clients will just be adding a dictionary under // the key "clientInfo". StartNetLog(mojo_base.mojom.File file, + NetLogCaptureMode capture_mode, mojo_base.mojom.DictionaryValue constants); + // Starts logging SSL key material to the |file|. This must be called before + // any SSL connections are made. (See |SSLClientSocket::SetSSLKeyLogger()| + // for more details). + // TODO(crbug.com/841001) This should pass a File which has already been + // opened to be sandbox friendly. + SetSSLKeyLogFile(mojo_base.mojom.FilePath file); + // Creates a new network context with the given parameters. CreateNetworkContext(NetworkContext& context, NetworkContextParams params); @@ -316,4 +331,8 @@ interface NetworkService { // Reverts AddCorbExceptionForPlugin. RemoveCorbExceptionForPlugin(uint32 process_id); + + // Called on state changes of the Android application. + [EnableIf=is_android] + OnApplicationStateChange(ApplicationState state); }; diff --git a/chromium/services/network/public/mojom/proxy_config_with_annotation.mojom b/chromium/services/network/public/mojom/proxy_config_with_annotation.mojom index f88039a44e5..9528833b8da 100644 --- a/chromium/services/network/public/mojom/proxy_config_with_annotation.mojom +++ b/chromium/services/network/public/mojom/proxy_config_with_annotation.mojom @@ -22,4 +22,27 @@ interface ProxyConfigClient { // it might be a good time to double-check the proxy configuration. interface ProxyConfigPollerClient { OnLazyProxyConfigPoll(); -};
\ No newline at end of file +}; + +// Called to notify error related to the configured proxy settings. +interface ProxyErrorClient { + // Called when the PAC script being used by the NetworkContext throws a + // JavaScript error or fails to execute. This error is not necessarily + // fatal for URL loading, since by default errors in a PAC script + // result in a fallback to DIRECT connections. + OnPACScriptError(int32 line_number, string details); + + // This is a best effort notification that a URL request failed due to + // a problem with the proxy settings. |net_error| is the error code that the + // request failed with. + // + // This only surfaces failures for an entire URL load, and not from + // individual proxy servers. For instance if a PAC script returned 4 proxy + // servers, and sending the request through the first three failed before + // successfulyl sending through the fourth, this method is NOT called. + // + // There is some ambiguity with how errors are classified as being a + // "proxy error". The current implementation includes a mix of + // connection and protocol errors. + OnRequestMaybeFailedDueToProxySettings(int32 net_error); +}; diff --git a/chromium/services/network/public/mojom/proxy_resolving_socket.mojom b/chromium/services/network/public/mojom/proxy_resolving_socket.mojom index 59cd312bea8..11c93585c6b 100644 --- a/chromium/services/network/public/mojom/proxy_resolving_socket.mojom +++ b/chromium/services/network/public/mojom/proxy_resolving_socket.mojom @@ -6,15 +6,34 @@ module network.mojom; import "net/interfaces/ip_endpoint.mojom"; import "services/network/public/mojom/mutable_network_traffic_annotation_tag.mojom"; +import "services/network/public/mojom/network_param.mojom"; +import "services/network/public/mojom/ssl_config.mojom"; +import "services/network/public/mojom/tcp_socket.mojom"; +import "services/network/public/mojom/tls_socket.mojom"; import "url/mojom/url.mojom"; // Represents a connected socket that respects system's proxy settings. Writes // and Reads are through the data pipes supplied upon construction. Consumer // can close the socket by destroying the interface pointer. -interface ProxyResolvingSocket{ +interface ProxyResolvingSocket { // TODO(xunjieli): Add methods to configure the socket connection and allow // consumers to specify whether they want to disconnect or return the socket // to socket pools. + + // Upgrades a proxy socket to a TLS client socket. + // IMPORTANT: Caller needs close the previous send and receive pipes before + // this method can complete asynchronously. + // + // On success, |net_error| is net::OK. Caller is to use |send_stream| to send + // data and |receive_stream| to receive data over the connection. On failure, + // |result| is a network error code. + UpgradeToTLS(HostPortPair host_port_pair, + MutableNetworkTrafficAnnotationTag traffic_annotation, + TLSClientSocket& request, + SocketObserver? observer) + => (int32 net_error, + handle<data_pipe_consumer>? receive_stream, + handle<data_pipe_producer>? send_stream); }; // Factory interface for creating ProxyResolvingSocket. Each factory instance @@ -37,7 +56,8 @@ interface ProxyResolvingSocketFactory { // when the implementation of this factory goes away. CreateProxyResolvingSocket(url.mojom.Url url, bool use_tls, MutableNetworkTrafficAnnotationTag traffic_annotation, - ProxyResolvingSocket& socket) + ProxyResolvingSocket& socket, + SocketObserver? observer) => (int32 result, net.interfaces.IPEndPoint? local_addr, net.interfaces.IPEndPoint? peer_addr, diff --git a/chromium/services/network/public/mojom/ssl_config.mojom b/chromium/services/network/public/mojom/ssl_config.mojom index 05e98db3257..a8a4122fa13 100644 --- a/chromium/services/network/public/mojom/ssl_config.mojom +++ b/chromium/services/network/public/mojom/ssl_config.mojom @@ -14,7 +14,6 @@ enum SSLVersion { // Versions of TLS 1.3 that are supported. enum TLS13Variant { kDraft23, - kDraft28, kFinal, }; @@ -27,11 +26,12 @@ struct SSLConfig { bool sha1_local_anchors_enabled = false; bool symantec_enforcement_disabled = false; - // SSL 2.0 and 3.0 are not supported. + // SSL 2.0 and 3.0 are not supported. Note these lines must be kept in sync + // with net/ssl/ssl_config.cc. SSLVersion version_min = kTLS1; SSLVersion version_max = kTLS12; - TLS13Variant tls13_variant = kDraft23; + TLS13Variant tls13_variant = kFinal; // Though cipher suites are sent in TLS as "uint8_t CipherSuite[2]", in // big-endian form, they should be declared in host byte order, with the diff --git a/chromium/services/network/public/mojom/tcp_socket.mojom b/chromium/services/network/public/mojom/tcp_socket.mojom index 92131f04fec..cb9d63b1a07 100644 --- a/chromium/services/network/public/mojom/tcp_socket.mojom +++ b/chromium/services/network/public/mojom/tcp_socket.mojom @@ -4,12 +4,60 @@ module network.mojom; +import "net/interfaces/address_list.mojom"; import "net/interfaces/ip_endpoint.mojom"; import "services/network/public/mojom/ssl_config.mojom"; import "services/network/public/mojom/tls_socket.mojom"; import "services/network/public/mojom/network_param.mojom"; import "services/network/public/mojom/mutable_network_traffic_annotation_tag.mojom"; +struct TCPConnectedSocketOptions { + // Sets the OS send buffer size (in bytes) for the socket. This is the + // SO_SNDBUF socket option. If 0, the default size is used. The value will + // be clamped to a reasonable range. + int32 send_buffer_size = 0; + + // Sets the OS receive buffer size (in bytes) for the socket. This is the + // SO_RCVBUF socket option. If 0, the default size is used. The value will + // be clamped to a reasonable range. + int32 receive_buffer_size = 0; + + // This function enables/disables buffering in the kernel. By default, on + // Linux, TCP sockets will wait up to 200ms for more data to complete a packet + // before transmitting. The network service, however, overrides the default + // setting all socket, so the kernel will not wait unless this is set to + // false. See TCP_NODELAY in `man 7 tcp`. On Windows, the Nagle implementation + // is governed by RFC 896. + bool no_delay = true; +}; + +// Represents a bound TCP socket. Once a call succeeds, cannot be reused. +interface TCPBoundSocket { + // Starts listening on the socket. |net_error| is set to net::OK on success, + // or a network error code on failure. Works just like + // NetworkContext::CreateServerSocket, except it operates on an already bound + // socket. The TCPBoundSocket will be destroyed on completion, whether the + // call succeeds or not. + Listen(uint32 backlog, TCPServerSocket& socket) + => (int32 net_error); + + // Attempts to connect the socket to |remote_addr_list|. |net_error| is set to + // net::OK on success, or a network error code on failure. Works just like + // NetworkContext::CreateTCPConnectedSocket(), except it operates on an + // already bound socket. The TCPBoundSocket will be destroyed on completion, + // whether the call succeeds or not. + Connect( + net.interfaces.AddressList remote_addr_list, + TCPConnectedSocketOptions? tcp_connected_socket_options, + TCPConnectedSocket& socket, + SocketObserver? observer) + => (int32 net_error, + net.interfaces.IPEndPoint? local_addr, + net.interfaces.IPEndPoint? peer_addr, + handle<data_pipe_consumer>? receive_stream, + handle<data_pipe_producer>? send_stream); +}; + // Represents a connected TCP socket. Writes and Reads are through the data // pipes supplied upon construction. Consumer should use // SocketObserver interface to get notified about any error occurred @@ -24,7 +72,8 @@ interface TCPConnectedSocket { // On success, |net_error| is net::OK. Caller is to use |send_stream| to send // data and |receive_stream| to receive data over the connection. On failure, // |result| is a network error code. - // |ssl_info| is only returned if |options::skip_cert_verification| is true. + // |ssl_info| is only returned if |options::unsafely_skip_cert_verification| + // is true. UpgradeToTLS(HostPortPair host_port_pair, TLSClientSocketOptions? options, MutableNetworkTrafficAnnotationTag traffic_annotation, @@ -40,13 +89,21 @@ interface TCPConnectedSocket { // the platform. Consumers do not need to set these themselves unless they // want to change the default settings. - // This function enables/disables buffering in the kernel. By default, on - // Linux, TCP sockets will wait up to 200ms for more data to complete a packet - // before transmitting. After calling this function, the kernel will not wait. - // See TCP_NODELAY in `man 7 tcp`. On Windows, the Nagle implementation is - // governed by RFC 896. SetTCPNoDelay() sets the TCP_NODELAY option. Use - // |no_delay| to enable or disable it. - // Returns whether the operation is successful. + // These set the send / receive buffer sizes on the connected socket. See the + // corresponding values in TCPConnectedSocketOptions for descriptions, + // though note that passing in "0" here will set the size to the minimum + // value, instead of restoring the default. Consumers should prefer setting + // these values on creation, as some platforms may not respect changes to + // these values on a connected socket, even if the method succeeds. These are + // present mostly for legacy consumers that expose the behavior to + // non-Chrome code. + // A network error code is returned on completion. + SetSendBufferSize(int32 send_buffer_size) => (int32 net_error); + SetReceiveBufferSize(int32 receive_buffer_size) => (int32 net_error); + + // Enables / disables TCP_NODELAY on the connected socket. See + // TCPConnectedSocketOptions::no_delay for more details. + // Returns whether the operation was successful. SetNoDelay(bool no_delay) => (bool success); // Enables or disables TCP Keep-Alive. This sets SO_KEEPALIVE on the socket. @@ -60,12 +117,12 @@ interface TCPConnectedSocket { // error, if a network error happens during a read or a write, consumer can find // out about it by implementing this interface. interface SocketObserver { - // Called when an error occurs during reading from the network. The producer - // side of |receive_stream| will be closed. + // Called when a network read fails. Called with net::OK if the socket was + // closed gracefully. The producer side of |receive_stream| will be closed. OnReadError(int32 net_error); - // Called when an error occurs during sending to the network. The consumer - // side of |send_stream| will be closed. + // Called when a network write fails. The consumer side of |send_stream| will + // be closed. OnWriteError(int32 net_error); }; @@ -82,6 +139,9 @@ interface TCPServerSocket { // |backlog| is a number that is specified when requesting TCPServerSocket. If // more than |backlog| number of Accept()s are outstanding, // net::ERR_INSUFFICIENT_RESOUCES will be returned. + // + // Accepted sockets may not be upgraded to TLS by invoking UpgradeToTLS, as + // UpgradeToTLS only supports the client part of the TLS handshake. Accept(SocketObserver? observer) => (int32 net_error, net.interfaces.IPEndPoint? remote_addr, diff --git a/chromium/services/network/public/mojom/tls_socket.mojom b/chromium/services/network/public/mojom/tls_socket.mojom index 92d30197086..35e96296a07 100644 --- a/chromium/services/network/public/mojom/tls_socket.mojom +++ b/chromium/services/network/public/mojom/tls_socket.mojom @@ -20,6 +20,11 @@ interface TLSClientSocket { struct TLSClientSocketOptions { SSLVersion version_min = kTLS1; SSLVersion version_max = kTLS12; - // If |true|, the SSLInfo will be returned in the UpgradeToTLS callback. - bool skip_cert_verification = false; + + // If true, the SSLInfo will be returned in the UpgradeToTLS callback on + // success. + bool send_ssl_info = false; + + // If true, the SSLInfo will also be returned in the UpgradeToTLS callback. + bool unsafely_skip_cert_verification = false; }; diff --git a/chromium/services/network/resource_scheduler.cc b/chromium/services/network/resource_scheduler.cc index 65fd1535f61..bf92db77881 100644 --- a/chromium/services/network/resource_scheduler.cc +++ b/chromium/services/network/resource_scheduler.cc @@ -13,6 +13,7 @@ #include "base/macros.h" #include "base/metrics/field_trial.h" #include "base/metrics/field_trial_params.h" +#include "base/metrics/histogram_functions.h" #include "base/metrics/histogram_macros.h" #include "base/optional.h" #include "base/sequenced_task_runner.h" @@ -718,6 +719,15 @@ class ResourceScheduler::Client { "ResourceScheduler.NumDelayableRequestsInFlightAtStart.NonDelayable", in_flight_delayable_count_); } + + DCHECK(!request->url_request()->creation_time().is_null()); + base::TimeDelta queuing_duration = + base::TimeTicks::Now() - request->url_request()->creation_time(); + base::UmaHistogramMediumTimes( + "ResourceScheduler.RequestQueuingDuration.Priority" + + base::IntToString(request->get_request_priority_params().priority), + queuing_duration); + InsertInFlightRequest(request); request->Start(start_mode); } diff --git a/chromium/services/network/resource_scheduler_params_manager.cc b/chromium/services/network/resource_scheduler_params_manager.cc index 644d785b9f9..f363f2ab638 100644 --- a/chromium/services/network/resource_scheduler_params_manager.cc +++ b/chromium/services/network/resource_scheduler_params_manager.cc @@ -87,8 +87,10 @@ ResourceSchedulerParamsManager:: features::kDelayRequestsOnMultiplexedConnections, "MaxEffectiveConnectionType")); - if (!max_effective_connection_type) - return result; + if (!max_effective_connection_type) { + // Use a default value if one is not set using field trial params. + max_effective_connection_type = net::EFFECTIVE_CONNECTION_TYPE_3G; + } for (int ect = net::EFFECTIVE_CONNECTION_TYPE_SLOW_2G; ect <= max_effective_connection_type.value(); ++ect) { diff --git a/chromium/services/network/resource_scheduler_params_manager_unittest.cc b/chromium/services/network/resource_scheduler_params_manager_unittest.cc index 60f71defc99..742767bde60 100644 --- a/chromium/services/network/resource_scheduler_params_manager_unittest.cc +++ b/chromium/services/network/resource_scheduler_params_manager_unittest.cc @@ -87,7 +87,6 @@ class ResourceSchedulerParamsManagerTest : public testing::Test { switch (effective_connection_type) { case net::EFFECTIVE_CONNECTION_TYPE_UNKNOWN: case net::EFFECTIVE_CONNECTION_TYPE_OFFLINE: - case net::EFFECTIVE_CONNECTION_TYPE_3G: case net::EFFECTIVE_CONNECTION_TYPE_4G: EXPECT_EQ(10u, resource_scheduler_params_manager .GetParamsForEffectiveConnectionType( @@ -102,7 +101,20 @@ class ResourceSchedulerParamsManagerTest : public testing::Test { .GetParamsForEffectiveConnectionType(effective_connection_type) .delay_requests_on_multiplexed_connections); return; - + case net::EFFECTIVE_CONNECTION_TYPE_3G: + EXPECT_EQ(10u, resource_scheduler_params_manager + .GetParamsForEffectiveConnectionType( + effective_connection_type) + .max_delayable_requests); + EXPECT_EQ(0.0, resource_scheduler_params_manager + .GetParamsForEffectiveConnectionType( + effective_connection_type) + .non_delayable_weight); + EXPECT_TRUE( + resource_scheduler_params_manager + .GetParamsForEffectiveConnectionType(effective_connection_type) + .delay_requests_on_multiplexed_connections); + return; case net::EFFECTIVE_CONNECTION_TYPE_SLOW_2G: case net::EFFECTIVE_CONNECTION_TYPE_2G: EXPECT_EQ(8u, resource_scheduler_params_manager @@ -113,7 +125,7 @@ class ResourceSchedulerParamsManagerTest : public testing::Test { .GetParamsForEffectiveConnectionType( effective_connection_type) .non_delayable_weight); - EXPECT_FALSE( + EXPECT_TRUE( resource_scheduler_params_manager .GetParamsForEffectiveConnectionType(effective_connection_type) .delay_requests_on_multiplexed_connections); @@ -184,6 +196,17 @@ TEST_F(ResourceSchedulerParamsManagerTest, .GetParamsForEffectiveConnectionType(ect) .delay_requests_on_multiplexed_connections); + } else if (effective_connection_type == net::EFFECTIVE_CONNECTION_TYPE_3G) { + EXPECT_EQ(10u, resource_scheduler_params_manager + .GetParamsForEffectiveConnectionType(ect) + .max_delayable_requests); + EXPECT_EQ(0.0, resource_scheduler_params_manager + .GetParamsForEffectiveConnectionType(ect) + .non_delayable_weight); + EXPECT_FALSE(resource_scheduler_params_manager + .GetParamsForEffectiveConnectionType(ect) + .delay_requests_on_multiplexed_connections); + } else { VerifyDefaultParams( resource_scheduler_params_manager, diff --git a/chromium/services/network/resource_scheduler_unittest.cc b/chromium/services/network/resource_scheduler_unittest.cc index bf1d9263595..15ea7c162d2 100644 --- a/chromium/services/network/resource_scheduler_unittest.cc +++ b/chromium/services/network/resource_scheduler_unittest.cc @@ -447,6 +447,8 @@ TEST_F(ResourceSchedulerTest, OneIsolatedLowRequest) { } TEST_F(ResourceSchedulerTest, OneLowLoadsUntilCriticalComplete) { + base::HistogramTester histogram_tester; + SetMaxDelayableRequests(1); std::unique_ptr<TestRequest> high( NewRequest("http://host/high", net::HIGHEST)); @@ -463,6 +465,15 @@ TEST_F(ResourceSchedulerTest, OneLowLoadsUntilCriticalComplete) { high.reset(); base::RunLoop().RunUntilIdle(); EXPECT_TRUE(low2->started()); + + histogram_tester.ExpectTotalCount( + "ResourceScheduler.RequestQueuingDuration.Priority" + + base::IntToString(net::HIGHEST), + 1); + histogram_tester.ExpectTotalCount( + "ResourceScheduler.RequestQueuingDuration.Priority" + + base::IntToString(net::LOWEST), + 2); } TEST_F(ResourceSchedulerTest, SchedulerYieldsOnSpdy) { diff --git a/chromium/services/network/socket_data_pump.cc b/chromium/services/network/socket_data_pump.cc index 75f18f5e9bf..50808083ca4 100644 --- a/chromium/services/network/socket_data_pump.cc +++ b/chromium/services/network/socket_data_pump.cc @@ -89,8 +89,13 @@ void SocketDataPump::ReceiveMore() { receive_stream_close_watcher_.ArmOrNotify(); return; } - // Handle EOF. + // Handle EOF. Has to be done here rather than in + // OnNetworkReadIfReadyCompleted because net::OK in the sync completion case + // means EOF, but in the async case just means the socket is ready to be read + // from again. if (read_result == net::OK) { + if (delegate_) + delegate_->OnNetworkReadError(read_result); ShutdownReceive(); return; } @@ -113,16 +118,20 @@ void SocketDataPump::OnReceiveStreamWritable(MojoResult result) { } void SocketDataPump::OnNetworkReadIfReadyCompleted(int result) { + // This method is called either on ReadIfReady sync completion, except in the + // EOF case, or on async completion. In the sync case, result is < 0 on error, + // or > 0 on success. In the async case, result is < 0 on error, or 0 if we + // should try to read from the socket again (And possibly get any of more + // data, an EOF, or an error). DCHECK(receive_stream_.is_valid()); if (read_if_ready_pending_) { DCHECK_GE(net::OK, result); read_if_ready_pending_ = false; } - if (result < 0 && delegate_) - delegate_->OnNetworkReadError(result); - if (result < 0) { + if (delegate_) + delegate_->OnNetworkReadError(result); ShutdownReceive(); return; } @@ -188,15 +197,14 @@ void SocketDataPump::OnNetworkWriteCompleted(int result) { DCHECK(pending_send_buffer_); DCHECK(!send_stream_.is_valid()); - if (result < 0 && delegate_) - delegate_->OnNetworkWriteError(result); - // Partial write is possible. pending_send_buffer_->CompleteRead(result >= 0 ? result : 0); send_stream_ = pending_send_buffer_->ReleaseHandle(); pending_send_buffer_ = nullptr; if (result <= 0) { + if (delegate_) + delegate_->OnNetworkWriteError(result); ShutdownSend(); return; } diff --git a/chromium/services/network/socket_data_pump_unittest.cc b/chromium/services/network/socket_data_pump_unittest.cc index 208449e7af9..8ae4b23598b 100644 --- a/chromium/services/network/socket_data_pump_unittest.cc +++ b/chromium/services/network/socket_data_pump_unittest.cc @@ -234,6 +234,27 @@ TEST_P(SocketDataPumpTest, PartialStreamSocketWrite) { EXPECT_TRUE(data_provider.AllWriteDataConsumed()); } +TEST_P(SocketDataPumpTest, ReadEof) { + net::IoMode mode = GetParam(); + net::MockRead reads[] = {net::MockRead(mode, net::OK)}; + const char kTestMsg[] = "hello!"; + net::MockWrite writes[] = { + net::MockWrite(mode, kTestMsg, strlen(kTestMsg), 0)}; + net::StaticSocketDataProvider data_provider(reads, writes); + Init(&data_provider); + EXPECT_EQ("", Read(&receive_handle_, 1)); + EXPECT_EQ(net::OK, delegate()->WaitForReadError()); + // Writes can proceed even though there is a read error. + uint32_t num_bytes = strlen(kTestMsg); + EXPECT_EQ(MOJO_RESULT_OK, send_handle_->WriteData(&kTestMsg, &num_bytes, + MOJO_WRITE_DATA_FLAG_NONE)); + EXPECT_EQ(strlen(kTestMsg), num_bytes); + + base::RunLoop().RunUntilIdle(); + EXPECT_TRUE(data_provider.AllReadDataConsumed()); + EXPECT_TRUE(data_provider.AllWriteDataConsumed()); +} + TEST_P(SocketDataPumpTest, ReadError) { net::IoMode mode = GetParam(); net::MockRead reads[] = {net::MockRead(mode, net::ERR_FAILED)}; @@ -255,6 +276,27 @@ TEST_P(SocketDataPumpTest, ReadError) { EXPECT_TRUE(data_provider.AllWriteDataConsumed()); } +TEST_P(SocketDataPumpTest, WriteEof) { + net::IoMode mode = GetParam(); + const char kTestMsg[] = "hello!"; + net::MockRead reads[] = {net::MockRead(mode, kTestMsg, strlen(kTestMsg), 0), + net::MockRead(mode, net::OK)}; + net::MockWrite writes[] = {net::MockWrite(mode, net::OK)}; + net::StaticSocketDataProvider data_provider(reads, writes); + Init(&data_provider); + uint32_t num_bytes = strlen(kTestMsg); + EXPECT_EQ(MOJO_RESULT_OK, send_handle_->WriteData(&kTestMsg, &num_bytes, + MOJO_WRITE_DATA_FLAG_NONE)); + EXPECT_EQ(strlen(kTestMsg), num_bytes); + EXPECT_EQ(net::OK, delegate()->WaitForWriteError()); + // Reads can proceed even though there is a read error. + EXPECT_EQ(kTestMsg, Read(&receive_handle_, strlen(kTestMsg))); + + base::RunLoop().RunUntilIdle(); + EXPECT_TRUE(data_provider.AllReadDataConsumed()); + EXPECT_TRUE(data_provider.AllWriteDataConsumed()); +} + TEST_P(SocketDataPumpTest, WriteError) { net::IoMode mode = GetParam(); const char kTestMsg[] = "hello!"; diff --git a/chromium/services/network/socket_factory.cc b/chromium/services/network/socket_factory.cc index 43fbb75a503..e6c3bf90572 100644 --- a/chromium/services/network/socket_factory.cc +++ b/chromium/services/network/socket_factory.cc @@ -22,43 +22,16 @@ #include "net/ssl/ssl_config_service.h" #include "net/url_request/url_request_context.h" #include "services/network/ssl_config_type_converter.h" -#include "services/network/tcp_connected_socket.h" #include "services/network/tls_client_socket.h" #include "services/network/udp_socket.h" namespace network { -namespace { -// Cert verifier which blindly accepts all certificates, regardless of validity. -class FakeCertVerifier : public net::CertVerifier { - public: - FakeCertVerifier() {} - ~FakeCertVerifier() override {} - - int Verify(const RequestParams& params, - net::CertVerifyResult* verify_result, - net::CompletionOnceCallback, - std::unique_ptr<Request>*, - const net::NetLogWithSource&) override { - verify_result->Reset(); - verify_result->verified_cert = params.certificate(); - return net::OK; - } - void SetConfig(const Config& config) override {} -}; -} // namespace SocketFactory::SocketFactory(net::NetLog* net_log, net::URLRequestContext* url_request_context) : net_log_(net_log), - ssl_client_socket_context_( - url_request_context->cert_verifier(), - nullptr, /* TODO(rkn): ChannelIDService is not thread safe. */ - url_request_context->transport_security_state(), - url_request_context->cert_transparency_verifier(), - url_request_context->ct_policy_enforcer(), - std::string() /* TODO(rsleevi): Ensure a proper unique shard. */), client_socket_factory_(nullptr), - ssl_config_service_(url_request_context->ssl_config_service()) { + tls_socket_factory_(url_request_context, nullptr /*http_context*/) { if (url_request_context->GetNetworkSessionContext()) { client_socket_factory_ = url_request_context->GetNetworkSessionContext()->client_socket_factory; @@ -97,70 +70,60 @@ void SocketFactory::CreateTCPServerSocket( void SocketFactory::CreateTCPConnectedSocket( const base::Optional<net::IPEndPoint>& local_addr, const net::AddressList& remote_addr_list, + mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options, const net::NetworkTrafficAnnotationTag& traffic_annotation, mojom::TCPConnectedSocketRequest request, mojom::SocketObserverPtr observer, mojom::NetworkContext::CreateTCPConnectedSocketCallback callback) { auto socket = std::make_unique<TCPConnectedSocket>( - std::move(observer), net_log_, this, client_socket_factory_, - traffic_annotation); + std::move(observer), net_log_, &tls_socket_factory_, + client_socket_factory_, traffic_annotation); TCPConnectedSocket* socket_raw = socket.get(); tcp_connected_socket_bindings_.AddBinding(std::move(socket), std::move(request)); - socket_raw->Connect(local_addr, remote_addr_list, std::move(callback)); + socket_raw->Connect(local_addr, remote_addr_list, + std::move(tcp_connected_socket_options), + std::move(callback)); } -void SocketFactory::CreateTLSClientSocket( - const net::HostPortPair& host_port_pair, - mojom::TLSClientSocketOptionsPtr socket_options, - mojom::TLSClientSocketRequest request, - std::unique_ptr<net::ClientSocketHandle> tcp_socket, - mojom::SocketObserverPtr observer, +void SocketFactory::CreateTCPBoundSocket( + const net::IPEndPoint& local_addr, const net::NetworkTrafficAnnotationTag& traffic_annotation, - mojom::TCPConnectedSocket::UpgradeToTLSCallback callback) { - auto socket = std::make_unique<TLSClientSocket>( - std::move(request), std::move(observer), - static_cast<net::NetworkTrafficAnnotationTag>(traffic_annotation)); - TLSClientSocket* socket_raw = socket.get(); - tls_socket_bindings_.AddBinding(std::move(socket), std::move(request)); + mojom::TCPBoundSocketRequest request, + mojom::NetworkContext::CreateTCPBoundSocketCallback callback) { + auto socket = + std::make_unique<TCPBoundSocket>(this, net_log_, traffic_annotation); + net::IPEndPoint local_addr_out; + int result = socket->Bind(local_addr, &local_addr_out); + if (result != net::OK) { + std::move(callback).Run(result, base::nullopt); + return; + } + socket->set_id(tcp_bound_socket_bindings_.AddBinding(std::move(socket), + std::move(request))); + std::move(callback).Run(result, local_addr_out); +} - net::SSLConfig ssl_config; - ssl_config_service_->GetSSLConfig(&ssl_config); - net::SSLClientSocketContext& ssl_client_socket_context = - ssl_client_socket_context_; +void SocketFactory::DestroyBoundSocket(mojo::BindingId bound_socket_id) { + tcp_bound_socket_bindings_.RemoveBinding(bound_socket_id); +} - bool send_ssl_info = false; - if (socket_options) { - ssl_config.version_min = - mojo::MojoSSLVersionToNetSSLVersion(socket_options->version_min); - ssl_config.version_max = - mojo::MojoSSLVersionToNetSSLVersion(socket_options->version_max); +void SocketFactory::OnBoundSocketListening( + mojo::BindingId bound_socket_id, + std::unique_ptr<TCPServerSocket> server_socket, + mojom::TCPServerSocketRequest server_socket_request) { + tcp_server_socket_bindings_.AddBinding(std::move(server_socket), + std::move(server_socket_request)); + tcp_bound_socket_bindings_.RemoveBinding(bound_socket_id); +} - if (socket_options->skip_cert_verification) { - if (!no_verification_cert_verifier_) { - no_verification_cert_verifier_ = base::WrapUnique(new FakeCertVerifier); - no_verification_transport_security_state_.reset( - new net::TransportSecurityState); - no_verification_cert_transparency_verifier_.reset( - new net::MultiLogCTVerifier()); - no_verification_ct_policy_enforcer_.reset( - new net::DefaultCTPolicyEnforcer()); - no_verification_ssl_client_socket_context_.cert_verifier = - no_verification_cert_verifier_.get(); - no_verification_ssl_client_socket_context_.transport_security_state = - no_verification_transport_security_state_.get(); - no_verification_ssl_client_socket_context_.cert_transparency_verifier = - no_verification_cert_transparency_verifier_.get(); - no_verification_ssl_client_socket_context_.ct_policy_enforcer = - no_verification_ct_policy_enforcer_.get(); - } - ssl_client_socket_context = no_verification_ssl_client_socket_context_; - send_ssl_info = true; - } - } - socket_raw->Connect(host_port_pair, ssl_config, std::move(tcp_socket), - ssl_client_socket_context, client_socket_factory_, - std::move(callback), send_ssl_info); +void SocketFactory::OnBoundSocketConnected( + mojo::BindingId bound_socket_id, + std::unique_ptr<TCPConnectedSocket> connected_socket, + mojom::TCPConnectedSocketRequest connected_socket_request) { + tcp_connected_socket_bindings_.AddBinding( + std::move(connected_socket), std::move(connected_socket_request)); + tcp_bound_socket_bindings_.RemoveBinding(bound_socket_id); } void SocketFactory::OnAccept(std::unique_ptr<TCPConnectedSocket> socket, diff --git a/chromium/services/network/socket_factory.h b/chromium/services/network/socket_factory.h index 31297409785..9bf1188606c 100644 --- a/chromium/services/network/socket_factory.h +++ b/chromium/services/network/socket_factory.h @@ -16,16 +16,15 @@ #include "net/traffic_annotation/network_traffic_annotation.h" #include "services/network/public/mojom/network_service.mojom.h" #include "services/network/public/mojom/tcp_socket.mojom.h" -#include "services/network/public/mojom/tls_socket.mojom.h" #include "services/network/public/mojom/udp_socket.mojom.h" +#include "services/network/tcp_bound_socket.h" #include "services/network/tcp_connected_socket.h" #include "services/network/tcp_server_socket.h" +#include "services/network/tls_socket_factory.h" namespace net { -class ClientSocketHandle; class ClientSocketFactory; class NetLog; -class SSLConfigService; } // namespace net namespace network { @@ -33,8 +32,7 @@ namespace network { // Helper class that handles socket requests. It takes care of destroying // socket implementation instances when mojo pipes are broken. class COMPONENT_EXPORT(NETWORK_SERVICE) SocketFactory - : public TCPServerSocket::Delegate, - public TCPConnectedSocket::Delegate { + : public TCPServerSocket::Delegate { public: // Constructs a SocketFactory. If |net_log| is non-null, it is used to // log NetLog events when logging is enabled. |net_log| used to must outlive @@ -43,6 +41,7 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) SocketFactory net::URLRequestContext* url_request_context); virtual ~SocketFactory(); + // These all correspond to the NetworkContext methods of the same name. void CreateUDPSocket(mojom::UDPSocketRequest request, mojom::UDPSocketReceiverPtr receiver); void CreateTCPServerSocket( @@ -54,45 +53,52 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) SocketFactory void CreateTCPConnectedSocket( const base::Optional<net::IPEndPoint>& local_addr, const net::AddressList& remote_addr_list, + mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options, const net::NetworkTrafficAnnotationTag& traffic_annotation, mojom::TCPConnectedSocketRequest request, mojom::SocketObserverPtr observer, mojom::NetworkContext::CreateTCPConnectedSocketCallback callback); + void CreateTCPBoundSocket( + const net::IPEndPoint& local_addr, + const net::NetworkTrafficAnnotationTag& traffic_annotation, + mojom::TCPBoundSocketRequest request, + mojom::NetworkContext::CreateTCPBoundSocketCallback callback); + + // Destroys the specified BoundSocket object. + void DestroyBoundSocket(mojo::BindingId bound_socket_id); + + // Invoked when a BoundSocket successfully starts listening. Destroys the + // BoundSocket object, adding a binding for the provided TCPServerSocket in + // its place. + void OnBoundSocketListening( + mojo::BindingId bound_socket_id, + std::unique_ptr<TCPServerSocket> server_socket, + mojom::TCPServerSocketRequest server_socket_request); + + // Invoked when a BoundSocket successfully establishes a connection. Destroys + // the BoundSocket object, adding a binding for the provided + // TCPConnectedSocket in its place. + void OnBoundSocketConnected( + mojo::BindingId bound_socket_id, + std::unique_ptr<TCPConnectedSocket> connected_socket, + mojom::TCPConnectedSocketRequest connected_socket_request); + + TLSSocketFactory* tls_socket_factory() { return &tls_socket_factory_; } private: // TCPServerSocket::Delegate implementation: void OnAccept(std::unique_ptr<TCPConnectedSocket> socket, mojom::TCPConnectedSocketRequest request) override; - // TCPConnectedSocket::Delegate implementation: - void CreateTLSClientSocket( - const net::HostPortPair& host_port_pair, - mojom::TLSClientSocketOptionsPtr socket_options, - mojom::TLSClientSocketRequest request, - std::unique_ptr<net::ClientSocketHandle> tcp_socket, - mojom::SocketObserverPtr observer, - const net::NetworkTrafficAnnotationTag& traffic_annotation, - mojom::TCPConnectedSocket::UpgradeToTLSCallback callback) override; - net::NetLog* const net_log_; - // The following are used when |skip_cert_verification| is specified in - // upgrade options. - net::SSLClientSocketContext no_verification_ssl_client_socket_context_; - std::unique_ptr<net::CertVerifier> no_verification_cert_verifier_; - std::unique_ptr<net::TransportSecurityState> - no_verification_transport_security_state_; - std::unique_ptr<net::CTVerifier> no_verification_cert_transparency_verifier_; - std::unique_ptr<net::CTPolicyEnforcer> no_verification_ct_policy_enforcer_; - - net::SSLClientSocketContext ssl_client_socket_context_; net::ClientSocketFactory* client_socket_factory_; - net::SSLConfigService* const ssl_config_service_; + TLSSocketFactory tls_socket_factory_; mojo::StrongBindingSet<mojom::UDPSocket> udp_socket_bindings_; mojo::StrongBindingSet<mojom::TCPServerSocket> tcp_server_socket_bindings_; mojo::StrongBindingSet<mojom::TCPConnectedSocket> tcp_connected_socket_bindings_; - mojo::StrongBindingSet<mojom::TLSClientSocket> tls_socket_bindings_; + mojo::StrongBindingSet<mojom::TCPBoundSocket> tcp_bound_socket_bindings_; DISALLOW_COPY_AND_ASSIGN(SocketFactory); }; diff --git a/chromium/services/network/ssl_config_type_converter.cc b/chromium/services/network/ssl_config_type_converter.cc index fa3460656fc..41f718d1329 100644 --- a/chromium/services/network/ssl_config_type_converter.cc +++ b/chromium/services/network/ssl_config_type_converter.cc @@ -11,8 +11,6 @@ net::TLS13Variant MojoTLS13VariantToNetTLS13Variant( switch (tls13_variant) { case network::mojom::TLS13Variant::kDraft23: return net::kTLS13VariantDraft23; - case network::mojom::TLS13Variant::kDraft28: - return net::kTLS13VariantDraft28; case network::mojom::TLS13Variant::kFinal: return net::kTLS13VariantFinal; } diff --git a/chromium/services/network/tcp_bound_socket.cc b/chromium/services/network/tcp_bound_socket.cc new file mode 100644 index 00000000000..37411575bad --- /dev/null +++ b/chromium/services/network/tcp_bound_socket.cc @@ -0,0 +1,159 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/network/tcp_bound_socket.h" + +#include <utility> + +#include "base/bind.h" +#include "base/logging.h" +#include "base/numerics/safe_conversions.h" +#include "base/optional.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/log/net_log.h" +#include "net/socket/tcp_client_socket.h" +#include "net/socket/tcp_server_socket.h" +#include "net/socket/tcp_socket.h" +#include "services/network/socket_factory.h" +#include "services/network/tcp_connected_socket.h" + +namespace network { + +TCPBoundSocket::TCPBoundSocket( + SocketFactory* socket_factory, + net::NetLog* net_log, + const net::NetworkTrafficAnnotationTag& traffic_annotation) + : socket_factory_(socket_factory), + socket_(std::make_unique<net::TCPSocket>( + nullptr /*socket_performance_watcher*/, + net_log, + net::NetLogSource())), + traffic_annotation_(traffic_annotation), + weak_factory_(this) {} + +TCPBoundSocket::~TCPBoundSocket() = default; + +int TCPBoundSocket::Bind(const net::IPEndPoint& local_addr, + net::IPEndPoint* local_addr_out) { + bind_address_ = local_addr; + + int result = socket_->Open(local_addr.GetFamily()); + if (result != net::OK) + return result; + + // This is primarily intended for use with server sockets. + result = socket_->SetDefaultOptionsForServer(); + if (result != net::OK) + return result; + + result = socket_->Bind(local_addr); + if (result != net::OK) + return result; + + return socket_->GetLocalAddress(local_addr_out); +} + +void TCPBoundSocket::Listen(uint32_t backlog, + mojom::TCPServerSocketRequest request, + ListenCallback callback) { + DCHECK(socket_->IsValid()); + + if (!socket_) { + // Drop unexpected calls on the floor. Could destroy |this|, but as this is + // currently only reachable from more trusted processes, doesn't seem too + // useful. + NOTREACHED(); + return; + } + + int result = ListenInternal(backlog); + + // Succeed or fail, pass the result back to the caller. + std::move(callback).Run(result); + + // Tear down everything on error. + if (result != net::OK) { + socket_factory_->DestroyBoundSocket(binding_id_); + return; + } + + // On success, also set up the TCPServerSocket. + std::unique_ptr<TCPServerSocket> server_socket = + std::make_unique<TCPServerSocket>( + std::make_unique<net::TCPServerSocket>(std::move(socket_)), backlog, + socket_factory_, traffic_annotation_); + socket_factory_->OnBoundSocketListening(binding_id_, std::move(server_socket), + std::move(request)); + // The above call will have destroyed |this|. +} + +void TCPBoundSocket::Connect( + const net::AddressList& remote_addr_list, + mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options, + mojom::TCPConnectedSocketRequest request, + mojom::SocketObserverPtr observer, + ConnectCallback callback) { + DCHECK(socket_->IsValid()); + + if (!socket_) { + // Drop unexpected calls on the floor. Could destroy |this|, but as this is + // currently only reachable from more trusted processes, doesn't seem too + // useful. + NOTREACHED(); + return; + } + + DCHECK(!connect_callback_); + DCHECK(!connected_socket_request_); + DCHECK(!connecting_socket_); + + connected_socket_request_ = std::move(request); + connect_callback_ = std::move(callback); + + // Create a TCPConnectedSocket and have it do the work of connecting and + // configuring the socket. This saves a bit of code, and reduces the number of + // tests this class needs, since it can rely on TCPConnectedSocket's tests for + // a lot of cases. + connecting_socket_ = std::make_unique<TCPConnectedSocket>( + std::move(observer), socket_->net_log().net_log(), + socket_factory_->tls_socket_factory(), + nullptr /* client_socket_factory */, traffic_annotation_); + connecting_socket_->ConnectWithSocket( + net::TCPClientSocket::CreateFromBoundSocket( + std::move(socket_), remote_addr_list, bind_address_), + std::move(tcp_connected_socket_options), + base::BindOnce(&TCPBoundSocket::OnConnectComplete, + base::Unretained(this))); +} + +void TCPBoundSocket::OnConnectComplete( + int result, + const base::Optional<net::IPEndPoint>& local_addr, + const base::Optional<net::IPEndPoint>& peer_addr, + mojo::ScopedDataPipeConsumerHandle receive_stream, + mojo::ScopedDataPipeProducerHandle send_stream) { + DCHECK(connecting_socket_); + DCHECK(connect_callback_); + + std::move(connect_callback_) + .Run(result, local_addr, peer_addr, std::move(receive_stream), + std::move(send_stream)); + if (result != net::OK) { + socket_factory_->DestroyBoundSocket(binding_id_); + // The above call will have destroyed |this|. + return; + } + + socket_factory_->OnBoundSocketConnected(binding_id_, + std::move(connecting_socket_), + std::move(connected_socket_request_)); + // The above call will have destroyed |this|. +} + +int TCPBoundSocket::ListenInternal(int backlog) { + return socket_->Listen(backlog); +} + +} // namespace network diff --git a/chromium/services/network/tcp_bound_socket.h b/chromium/services/network/tcp_bound_socket.h new file mode 100644 index 00000000000..eb2f7ec8fb4 --- /dev/null +++ b/chromium/services/network/tcp_bound_socket.h @@ -0,0 +1,91 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_NETWORK_TCP_BOUND_SOCKET_H_ +#define SERVICES_NETWORK_TCP_BOUND_SOCKET_H_ + +#include <memory> + +#include "base/component_export.h" +#include "base/macros.h" +#include "base/memory/ref_counted.h" +#include "base/memory/weak_ptr.h" +#include "mojo/public/cpp/bindings/binding_set.h" +#include "mojo/public/cpp/bindings/interface_request.h" +#include "net/base/ip_endpoint.h" +#include "net/socket/tcp_socket.h" +#include "net/traffic_annotation/network_traffic_annotation.h" +#include "services/network/public/mojom/tcp_socket.mojom.h" +#include "services/network/tcp_server_socket.h" + +namespace net { +class IPEndPoint; +class NetLog; +} // namespace net + +namespace network { +class SocketFactory; + +// A socket bound to an address. Can be converted into either a TCPServerSocket +// or a TCPConnectedSocket. +class COMPONENT_EXPORT(NETWORK_SERVICE) TCPBoundSocket + : public mojom::TCPBoundSocket { + public: + // Constructs a TCPBoundSocket. |socket_factory| must outlive |this|. When a + // new connection is accepted, |socket_factory| will be notified to take + // ownership of the connection. + TCPBoundSocket(SocketFactory* socket_factory, + net::NetLog* net_log, + const net::NetworkTrafficAnnotationTag& traffic_annotation); + ~TCPBoundSocket() override; + + // Attempts to bind a socket to the specified address. Returns net::OK on + // success, setting |local_addr_out| to the bound address. Returns a network + // error code on failure. Must be called before Listen() or Connect(). + int Bind(const net::IPEndPoint& local_addr, net::IPEndPoint* local_addr_out); + + // Sets the id used to remove the socket from the owning BindingSet. Must be + // called before Listen() or Connect(). + void set_id(mojo::BindingId binding_id) { binding_id_ = binding_id; } + + // mojom::TCPBoundSocket implementation. + void Listen(uint32_t backlog, + mojom::TCPServerSocketRequest request, + ListenCallback callback) override; + void Connect(const net::AddressList& remote_addr, + mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options, + mojom::TCPConnectedSocketRequest request, + mojom::SocketObserverPtr observer, + ConnectCallback callback) override; + + private: + void OnConnectComplete(int result, + const base::Optional<net::IPEndPoint>& local_addr, + const base::Optional<net::IPEndPoint>& peer_addr, + mojo::ScopedDataPipeConsumerHandle receive_stream, + mojo::ScopedDataPipeProducerHandle send_stream); + + virtual int ListenInternal(int backlog); + + net::IPEndPoint bind_address_; + + mojo::BindingId binding_id_ = -1; + SocketFactory* const socket_factory_; + std::unique_ptr<net::TCPSocket> socket_; + const net::NetworkTrafficAnnotationTag traffic_annotation_; + + mojom::TCPConnectedSocketRequest connected_socket_request_; + ConnectCallback connect_callback_; + + // Takes ownership of |socket_| if Connect() is called. + std::unique_ptr<TCPConnectedSocket> connecting_socket_; + + base::WeakPtrFactory<TCPBoundSocket> weak_factory_; + + DISALLOW_COPY_AND_ASSIGN(TCPBoundSocket); +}; + +} // namespace network + +#endif // SERVICES_NETWORK_TCP_BOUND_SOCKET_H_ diff --git a/chromium/services/network/tcp_bound_socket_unittest.cc b/chromium/services/network/tcp_bound_socket_unittest.cc new file mode 100644 index 00000000000..143de24bbc7 --- /dev/null +++ b/chromium/services/network/tcp_bound_socket_unittest.cc @@ -0,0 +1,499 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <stdint.h> + +#include <string> +#include <utility> +#include <vector> + +#include "base/bind.h" +#include "base/run_loop.h" +#include "base/strings/stringprintf.h" +#include "base/test/bind_test_util.h" +#include "base/test/scoped_task_environment.h" +#include "build/build_config.h" +#include "mojo/public/cpp/bindings/interface_request.h" +#include "mojo/public/cpp/system/data_pipe_utils.h" +#include "mojo/public/cpp/system/simple_watcher.h" +#include "mojo/public/cpp/system/wait.h" +#include "net/base/address_list.h" +#include "net/base/ip_address.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/test/embedded_test_server/embedded_test_server.h" +#include "net/test/embedded_test_server/http_request.h" +#include "net/test/embedded_test_server/http_response.h" +#include "net/traffic_annotation/network_traffic_annotation_test_helper.h" +#include "net/url_request/url_request_test_util.h" +#include "services/network/mojo_socket_test_util.h" +#include "services/network/public/mojom/tcp_socket.mojom.h" +#include "services/network/socket_factory.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace network { +namespace { + +class TCPBoundSocketTest : public testing::Test { + public: + TCPBoundSocketTest() + : scoped_task_environment_( + base::test::ScopedTaskEnvironment::MainThreadType::IO), + factory_(nullptr /* net_log */, &url_request_context_) {} + ~TCPBoundSocketTest() override {} + + SocketFactory* factory() { return &factory_; } + + int BindSocket(const net::IPEndPoint& ip_endpoint_in, + mojom::TCPBoundSocketPtr* bound_socket, + net::IPEndPoint* ip_endpoint_out) { + base::RunLoop run_loop; + int bind_result = net::ERR_IO_PENDING; + factory()->CreateTCPBoundSocket( + ip_endpoint_in, TRAFFIC_ANNOTATION_FOR_TESTS, + mojo::MakeRequest(bound_socket), + base::BindLambdaForTesting( + [&](int net_error, + const base::Optional<net::IPEndPoint>& local_addr) { + bind_result = net_error; + if (net_error == net::OK) { + *ip_endpoint_out = *local_addr; + } else { + EXPECT_FALSE(local_addr); + } + run_loop.Quit(); + })); + run_loop.Run(); + + // On error, |bound_socket| should be closed. + if (bind_result != net::OK && !bound_socket->encountered_error()) { + base::RunLoop close_pipe_run_loop; + bound_socket->set_connection_error_handler( + close_pipe_run_loop.QuitClosure()); + close_pipe_run_loop.Run(); + } + return bind_result; + } + + int Listen(mojom::TCPBoundSocketPtr bound_socket, + mojom::TCPServerSocketPtr* server_socket) { + base::RunLoop bound_socket_destroyed_run_loop; + bound_socket.set_connection_error_handler( + bound_socket_destroyed_run_loop.QuitClosure()); + + base::RunLoop run_loop; + int listen_result = net::ERR_IO_PENDING; + bound_socket->Listen(1 /* backlog */, mojo::MakeRequest(server_socket), + base::BindLambdaForTesting([&](int net_error) { + listen_result = net_error; + run_loop.Quit(); + })); + run_loop.Run(); + + // Whether Bind() fails or succeeds, |bound_socket| is destroyed. + bound_socket_destroyed_run_loop.Run(); + + // On error, |server_socket| should be closed. + if (listen_result != net::OK && !server_socket->encountered_error()) { + base::RunLoop close_pipe_run_loop; + server_socket->set_connection_error_handler( + close_pipe_run_loop.QuitClosure()); + close_pipe_run_loop.Run(); + } + + return listen_result; + } + + int Connect(mojom::TCPBoundSocketPtr bound_socket, + const net::IPEndPoint& expected_local_addr, + const net::IPEndPoint& connect_to_addr, + mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options, + mojom::TCPConnectedSocketPtr* connected_socket, + mojom::SocketObserverPtr socket_observer, + mojo::ScopedDataPipeConsumerHandle* client_socket_receive_handle, + mojo::ScopedDataPipeProducerHandle* client_socket_send_handle) { + base::RunLoop bound_socket_destroyed_run_loop; + bound_socket.set_connection_error_handler( + bound_socket_destroyed_run_loop.QuitClosure()); + + int connect_result = net::ERR_IO_PENDING; + base::RunLoop run_loop; + bound_socket->Connect( + net::AddressList(connect_to_addr), + std::move(tcp_connected_socket_options), + mojo::MakeRequest(connected_socket), std::move(socket_observer), + base::BindLambdaForTesting( + [&](int net_error, + const base::Optional<net::IPEndPoint>& local_addr, + const base::Optional<net::IPEndPoint>& remote_addr, + mojo::ScopedDataPipeConsumerHandle receive_stream, + mojo::ScopedDataPipeProducerHandle send_stream) { + connect_result = net_error; + if (net_error == net::OK) { + EXPECT_EQ(expected_local_addr, *local_addr); + EXPECT_EQ(connect_to_addr, *remote_addr); + *client_socket_receive_handle = std::move(receive_stream); + *client_socket_send_handle = std::move(send_stream); + } else { + EXPECT_FALSE(local_addr); + EXPECT_FALSE(remote_addr); + EXPECT_FALSE(receive_stream.is_valid()); + EXPECT_FALSE(send_stream.is_valid()); + } + run_loop.Quit(); + })); + run_loop.Run(); + + // Whether Bind() fails or succeeds, |bound_socket| is destroyed. + bound_socket_destroyed_run_loop.Run(); + + // On error, |connected_socket| should be closed. + if (connect_result != net::OK && !connected_socket->encountered_error()) { + base::RunLoop close_pipe_run_loop; + connected_socket->set_connection_error_handler( + close_pipe_run_loop.QuitClosure()); + close_pipe_run_loop.Run(); + } + + return connect_result; + } + + // Attempts to read exactly |expected_bytes| from |receive_handle|, or reads + // until the pipe is closed if |expected_bytes| is 0. + std::string ReadData(mojo::DataPipeConsumerHandle receive_handle, + uint32_t expected_bytes = 0) { + std::string read_data; + while (expected_bytes == 0 || read_data.size() < expected_bytes) { + const void* buffer; + uint32_t num_bytes = expected_bytes - read_data.size(); + MojoResult result = receive_handle.BeginReadData( + &buffer, &num_bytes, MOJO_READ_DATA_FLAG_NONE); + if (result == MOJO_RESULT_SHOULD_WAIT) { + scoped_task_environment_.RunUntilIdle(); + continue; + } + if (result != MOJO_RESULT_OK) { + if (expected_bytes != 0) + ADD_FAILURE() << "Read failed"; + return read_data; + } + read_data.append(static_cast<const char*>(buffer), num_bytes); + receive_handle.EndReadData(num_bytes); + } + + return read_data; + } + + static net::IPEndPoint LocalHostWithAnyPort() { + return net::IPEndPoint(net::IPAddress::IPv4Localhost(), 0 /* port */); + } + + base::test::ScopedTaskEnvironment* scoped_task_environment() { + return &scoped_task_environment_; + } + + private: + base::test::ScopedTaskEnvironment scoped_task_environment_; + net::TestURLRequestContext url_request_context_; + SocketFactory factory_; + + DISALLOW_COPY_AND_ASSIGN(TCPBoundSocketTest); +}; + +// Try to bind a socket to an address already being listened on, which should +// fail. +TEST_F(TCPBoundSocketTest, BindError) { + // Set up a listening socket. + mojom::TCPBoundSocketPtr bound_socket1; + net::IPEndPoint bound_address1; + ASSERT_EQ(net::OK, BindSocket(LocalHostWithAnyPort(), &bound_socket1, + &bound_address1)); + mojom::TCPServerSocketPtr server_socket; + ASSERT_EQ(net::OK, Listen(std::move(bound_socket1), &server_socket)); + + // Try to bind another socket to the listening socket's address. + mojom::TCPBoundSocketPtr bound_socket2; + net::IPEndPoint bound_address2; + int result = BindSocket(bound_address1, &bound_socket2, &bound_address2); + // Depending on platform, can get different errors. Some platforms can return + // either error. + EXPECT_TRUE(result == net::ERR_ADDRESS_IN_USE || + result == net::ERR_INVALID_ARGUMENT); +} + +// Test the case of a connect error. To cause a connect error, bind a socket, +// but don't listen on it, and then try connecting to it using another bound +// socket. +// +// Don't run on Apple platforms because this pattern ends in a connect timeout +// on OSX (after 25+ seconds) instead of connection refused. +#if !defined(OS_MACOSX) && !defined(OS_IOS) +TEST_F(TCPBoundSocketTest, ConnectError) { + mojom::TCPBoundSocketPtr bound_socket1; + net::IPEndPoint bound_address1; + ASSERT_EQ(net::OK, BindSocket(LocalHostWithAnyPort(), &bound_socket1, + &bound_address1)); + + // Trying to bind to an address currently being used for listening should + // fail. + mojom::TCPBoundSocketPtr bound_socket2; + net::IPEndPoint bound_address2; + ASSERT_EQ(net::OK, BindSocket(LocalHostWithAnyPort(), &bound_socket2, + &bound_address2)); + mojom::TCPConnectedSocketPtr connected_socket; + mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; + mojo::ScopedDataPipeProducerHandle client_socket_send_handle; + EXPECT_EQ(net::ERR_CONNECTION_REFUSED, + Connect(std::move(bound_socket2), bound_address2, bound_address1, + nullptr /* tcp_connected_socket_options */, + &connected_socket, mojom::SocketObserverPtr(), + &client_socket_receive_handle, &client_socket_send_handle)); +} +#endif // !defined(OS_MACOSX) && !defined(OS_IOS) + +// Test listen failure. + +// All platforms except Windows use SO_REUSEADDR on server sockets by default, +// which allows binding multiple sockets to the same port at once, as long as +// nothing is listening on it yet. +// +// Apple platforms don't allow binding multiple TCP sockets to the same port +// even with SO_REUSEADDR enabled. +#if !defined(OS_WIN) && !defined(OS_MACOSX) && !defined(OS_IOS) +TEST_F(TCPBoundSocketTest, ListenError) { + // Bind a socket. + mojom::TCPBoundSocketPtr bound_socket1; + net::IPEndPoint bound_address1; + ASSERT_EQ(net::OK, BindSocket(LocalHostWithAnyPort(), &bound_socket1, + &bound_address1)); + + // Bind another socket to the same address, which should succeed, due to + // SO_REUSEADDR. + mojom::TCPBoundSocketPtr bound_socket2; + net::IPEndPoint bound_address2; + ASSERT_EQ(net::OK, + BindSocket(bound_address1, &bound_socket2, &bound_address2)); + + // Listen on the first socket, which should also succeed. + mojom::TCPServerSocketPtr server_socket1; + ASSERT_EQ(net::OK, Listen(std::move(bound_socket1), &server_socket1)); + + // Listen on the second socket should fail. + mojom::TCPServerSocketPtr server_socket2; + int result = Listen(std::move(bound_socket2), &server_socket2); + // Depending on platform, can get different errors. Some platforms can return + // either error. + EXPECT_TRUE(result == net::ERR_ADDRESS_IN_USE || + result == net::ERR_INVALID_ARGUMENT); +} +#endif // !defined(OS_WIN) && !defined(OS_MACOSX) && !defined(OS_IOS) + +// Test the case bind succeeds, and transfer some data. +TEST_F(TCPBoundSocketTest, ReadWrite) { + // Set up a listening socket. + mojom::TCPBoundSocketPtr bound_socket1; + net::IPEndPoint server_address; + ASSERT_EQ(net::OK, BindSocket(LocalHostWithAnyPort(), &bound_socket1, + &server_address)); + mojom::TCPServerSocketPtr server_socket; + ASSERT_EQ(net::OK, Listen(std::move(bound_socket1), &server_socket)); + + // Connect to the socket with another socket. + mojom::TCPBoundSocketPtr bound_socket2; + net::IPEndPoint client_address; + ASSERT_EQ(net::OK, BindSocket(LocalHostWithAnyPort(), &bound_socket2, + &client_address)); + mojom::TCPConnectedSocketPtr client_socket; + TestSocketObserver socket_observer; + mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; + mojo::ScopedDataPipeProducerHandle client_socket_send_handle; + EXPECT_EQ(net::OK, + Connect(std::move(bound_socket2), client_address, server_address, + nullptr /* tcp_connected_socket_options */, &client_socket, + socket_observer.GetObserverPtr(), + &client_socket_receive_handle, &client_socket_send_handle)); + + base::RunLoop run_loop; + mojom::TCPConnectedSocketPtr accept_socket; + mojo::ScopedDataPipeConsumerHandle accept_socket_receive_handle; + mojo::ScopedDataPipeProducerHandle accept_socket_send_handle; + server_socket->Accept( + nullptr /* ovserver */, + base::BindLambdaForTesting( + [&](int net_error, const base::Optional<net::IPEndPoint>& remote_addr, + mojom::TCPConnectedSocketPtr connected_socket, + mojo::ScopedDataPipeConsumerHandle receive_stream, + mojo::ScopedDataPipeProducerHandle send_stream) { + EXPECT_EQ(net_error, net::OK); + EXPECT_EQ(*remote_addr, client_address); + accept_socket = std::move(connected_socket); + accept_socket_receive_handle = std::move(receive_stream); + accept_socket_send_handle = std::move(send_stream); + run_loop.Quit(); + })); + run_loop.Run(); + + const std::string kData = "Jumbo Shrimp"; + ASSERT_TRUE(mojo::BlockingCopyFromString(kData, client_socket_send_handle)); + EXPECT_EQ(kData, ReadData(accept_socket_receive_handle.get(), kData.size())); + + ASSERT_TRUE(mojo::BlockingCopyFromString(kData, accept_socket_send_handle)); + EXPECT_EQ(kData, ReadData(client_socket_receive_handle.get(), kData.size())); + + // Close the accept socket. + accept_socket.reset(); + + // Wait for read error on the client socket. + EXPECT_EQ(net::OK, socket_observer.WaitForReadError()); + + // Write data to the client socket until there's an error. + while (true) { + void* buffer = nullptr; + uint32_t buffer_num_bytes = 0; + MojoResult result = client_socket_send_handle->BeginWriteData( + &buffer, &buffer_num_bytes, MOJO_WRITE_DATA_FLAG_NONE); + if (result == MOJO_RESULT_SHOULD_WAIT) { + scoped_task_environment()->RunUntilIdle(); + continue; + } + if (result != MOJO_RESULT_OK) + break; + memset(buffer, 0, buffer_num_bytes); + client_socket_send_handle->EndWriteData(buffer_num_bytes); + } + // Wait for write error on the client socket. Don't check exact error, out of + // paranoia. + EXPECT_LT(socket_observer.WaitForWriteError(), 0); +} + +// Establish a connection while passing in some options. This test doesn't check +// that the options are actually set, since there's no API for that. +TEST_F(TCPBoundSocketTest, ConnectWithOptions) { + // Set up a listening socket. + network::mojom::TCPBoundSocketPtr bound_socket1; + net::IPEndPoint server_address; + ASSERT_EQ(net::OK, BindSocket(LocalHostWithAnyPort(), &bound_socket1, + &server_address)); + network::mojom::TCPServerSocketPtr server_socket; + ASSERT_EQ(net::OK, Listen(std::move(bound_socket1), &server_socket)); + + // Connect to the socket with another socket. + network::mojom::TCPBoundSocketPtr bound_socket2; + net::IPEndPoint client_address; + ASSERT_EQ(net::OK, BindSocket(LocalHostWithAnyPort(), &bound_socket2, + &client_address)); + network::mojom::TCPConnectedSocketPtr client_socket; + TestSocketObserver socket_observer; + mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; + mojo::ScopedDataPipeProducerHandle client_socket_send_handle; + mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options = + mojom::TCPConnectedSocketOptions::New(); + tcp_connected_socket_options->send_buffer_size = 32 * 1024; + tcp_connected_socket_options->receive_buffer_size = 64 * 1024; + tcp_connected_socket_options->no_delay = false; + + EXPECT_EQ(net::OK, + Connect(std::move(bound_socket2), client_address, server_address, + std::move(tcp_connected_socket_options), &client_socket, + socket_observer.GetObserverPtr(), + &client_socket_receive_handle, &client_socket_send_handle)); + + base::RunLoop run_loop; + network::mojom::TCPConnectedSocketPtr accept_socket; + mojo::ScopedDataPipeConsumerHandle accept_socket_receive_handle; + mojo::ScopedDataPipeProducerHandle accept_socket_send_handle; + server_socket->Accept( + nullptr /* ovserver */, + base::BindLambdaForTesting( + [&](int net_error, const base::Optional<net::IPEndPoint>& remote_addr, + network::mojom::TCPConnectedSocketPtr connected_socket, + mojo::ScopedDataPipeConsumerHandle receive_stream, + mojo::ScopedDataPipeProducerHandle send_stream) { + EXPECT_EQ(net_error, net::OK); + EXPECT_EQ(*remote_addr, client_address); + accept_socket = std::move(connected_socket); + accept_socket_receive_handle = std::move(receive_stream); + accept_socket_send_handle = std::move(send_stream); + run_loop.Quit(); + })); + run_loop.Run(); + + const std::string kData = "Jumbo Shrimp"; + ASSERT_TRUE(mojo::BlockingCopyFromString(kData, client_socket_send_handle)); + EXPECT_EQ(kData, ReadData(accept_socket_receive_handle.get(), kData.size())); + + ASSERT_TRUE(mojo::BlockingCopyFromString(kData, accept_socket_send_handle)); + EXPECT_EQ(kData, ReadData(client_socket_receive_handle.get(), kData.size())); +} + +// Test that a TCPBoundSocket can be upgraded to TLS once connected. +TEST_F(TCPBoundSocketTest, UpgradeToTLS) { + // Simplest way to set up an TLS server is to use the embedded test server. + net::test_server::EmbeddedTestServer test_server( + net::test_server::EmbeddedTestServer::TYPE_HTTPS); + test_server.RegisterRequestHandler(base::BindRepeating( + [](const net::test_server::HttpRequest& request) + -> std::unique_ptr<net::test_server::HttpResponse> { + std::unique_ptr<net::test_server::BasicHttpResponse> basic_response = + std::make_unique<net::test_server::BasicHttpResponse>(); + basic_response->set_content(request.relative_url); + return basic_response; + })); + ASSERT_TRUE(test_server.Start()); + + network::mojom::TCPBoundSocketPtr bound_socket; + net::IPEndPoint client_address; + ASSERT_EQ(net::OK, + BindSocket(LocalHostWithAnyPort(), &bound_socket, &client_address)); + network::mojom::TCPConnectedSocketPtr client_socket; + TestSocketObserver socket_observer; + mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; + mojo::ScopedDataPipeProducerHandle client_socket_send_handle; + + EXPECT_EQ(net::OK, + Connect(std::move(bound_socket), client_address, + net::IPEndPoint(net::IPAddress::IPv4Localhost(), + test_server.host_port_pair().port()), + nullptr /* tcp_connected_socket_options */, &client_socket, + socket_observer.GetObserverPtr(), + &client_socket_receive_handle, &client_socket_send_handle)); + + // Need to closed these pipes for UpgradeToTLS to complete. + client_socket_receive_handle.reset(); + client_socket_send_handle.reset(); + + base::RunLoop run_loop; + mojom::TLSClientSocketPtr tls_client_socket; + client_socket->UpgradeToTLS( + test_server.host_port_pair(), nullptr /* options */, + net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS), + mojo::MakeRequest(&tls_client_socket), nullptr /* observer */, + base::BindLambdaForTesting( + [&](int net_error, + mojo::ScopedDataPipeConsumerHandle receive_pipe_handle, + mojo::ScopedDataPipeProducerHandle send_pipe_handle, + const base::Optional<net::SSLInfo>& ssl_info) { + EXPECT_EQ(net::OK, net_error); + client_socket_receive_handle = std::move(receive_pipe_handle); + client_socket_send_handle = std::move(send_pipe_handle); + run_loop.Quit(); + })); + run_loop.Run(); + + const char kPath[] = "/foo"; + + // Send an HTTP request. + std::string request = base::StringPrintf("GET %s HTTP/1.0\r\n\r\n", kPath); + EXPECT_TRUE(mojo::BlockingCopyFromString(request, client_socket_send_handle)); + + // Read the response, and make sure it looks reasonable. + std::string response = ReadData(client_socket_receive_handle.get()); + EXPECT_EQ("HTTP/", response.substr(0, 5)); + // The response body should be the path, so make sure the response ends with + // the path. + EXPECT_EQ(kPath, response.substr(response.length() - strlen(kPath))); +} + +} // namespace +} // namespace network diff --git a/chromium/services/network/tcp_connected_socket.cc b/chromium/services/network/tcp_connected_socket.cc index e6bbbbf7fb0..80b18c939f4 100644 --- a/chromium/services/network/tcp_connected_socket.cc +++ b/chromium/services/network/tcp_connected_socket.cc @@ -7,6 +7,7 @@ #include <utility> #include "base/logging.h" +#include "base/numerics/ranges.h" #include "base/numerics/safe_conversions.h" #include "base/optional.h" #include "net/base/net_errors.h" @@ -17,16 +18,62 @@ namespace network { +namespace { + +int ClampTCPBufferSize(int requested_buffer_size) { + return base::ClampToRange(requested_buffer_size, 0, + TCPConnectedSocket::kMaxBufferSize); +} + +// Sets the initial options on a fresh socket. Assumes |socket| is currently +// configured using the default client socket options +// (TCPSocket::SetDefaultOptionsForClient()). +int ConfigureSocket( + net::TransportClientSocket* socket, + const mojom::TCPConnectedSocketOptions& tcp_connected_socket_options) { + int send_buffer_size = + ClampTCPBufferSize(tcp_connected_socket_options.send_buffer_size); + if (send_buffer_size > 0) { + int result = socket->SetSendBufferSize(send_buffer_size); + DCHECK_NE(net::ERR_IO_PENDING, result); + if (result != net::OK) + return result; + } + + int receive_buffer_size = + ClampTCPBufferSize(tcp_connected_socket_options.receive_buffer_size); + if (receive_buffer_size > 0) { + int result = socket->SetReceiveBufferSize(receive_buffer_size); + DCHECK_NE(net::ERR_IO_PENDING, result); + if (result != net::OK) + return result; + } + + // No delay is set by default, so only update the setting if it's false. + if (!tcp_connected_socket_options.no_delay) { + // Unlike the above calls, TcpSocket::SetNoDelay() returns a bool rather + // than a network error code. + if (!socket->SetNoDelay(false)) + return net::ERR_FAILED; + } + + return net::OK; +} + +} // namespace + +const int TCPConnectedSocket::kMaxBufferSize = 128 * 1024; + TCPConnectedSocket::TCPConnectedSocket( mojom::SocketObserverPtr observer, net::NetLog* net_log, - Delegate* delegate, + TLSSocketFactory* tls_socket_factory, net::ClientSocketFactory* client_socket_factory, const net::NetworkTrafficAnnotationTag& traffic_annotation) : observer_(std::move(observer)), net_log_(net_log), - delegate_(delegate), client_socket_factory_(client_socket_factory), + tls_socket_factory_(tls_socket_factory), traffic_annotation_(traffic_annotation) {} TCPConnectedSocket::TCPConnectedSocket( @@ -37,8 +84,8 @@ TCPConnectedSocket::TCPConnectedSocket( const net::NetworkTrafficAnnotationTag& traffic_annotation) : observer_(std::move(observer)), net_log_(nullptr), - delegate_(nullptr), client_socket_factory_(nullptr), + tls_socket_factory_(nullptr), socket_(std::move(socket)), traffic_annotation_(traffic_annotation) { socket_data_pump_ = std::make_unique<SocketDataPump>( @@ -60,24 +107,46 @@ TCPConnectedSocket::~TCPConnectedSocket() { void TCPConnectedSocket::Connect( const base::Optional<net::IPEndPoint>& local_addr, const net::AddressList& remote_addr_list, + mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options, mojom::NetworkContext::CreateTCPConnectedSocketCallback callback) { DCHECK(!socket_); DCHECK(callback); - auto socket = client_socket_factory_->CreateTransportClientSocket( - remote_addr_list, nullptr /*socket_performance_watcher*/, net_log_, - net::NetLogSource()); - connect_callback_ = std::move(callback); - int result = net::OK; - if (local_addr) - result = socket->Bind(local_addr.value()); - if (result == net::OK) { - result = socket->Connect(base::BindRepeating( - &TCPConnectedSocket::OnConnectCompleted, base::Unretained(this))); + std::unique_ptr<net::TransportClientSocket> socket = + client_socket_factory_->CreateTransportClientSocket( + remote_addr_list, nullptr /*socket_performance_watcher*/, net_log_, + net::NetLogSource()); + + if (local_addr) { + int result = socket->Bind(local_addr.value()); + if (result != net::OK) { + OnConnectCompleted(result); + return; + } } + + return ConnectWithSocket(std::move(socket), + std::move(tcp_connected_socket_options), + std::move(callback)); +} + +void TCPConnectedSocket::ConnectWithSocket( + std::unique_ptr<net::TransportClientSocket> socket, + mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options, + mojom::NetworkContext::CreateTCPConnectedSocketCallback callback) { socket_ = std::move(socket); + connect_callback_ = std::move(callback); + + if (tcp_connected_socket_options) { + socket_->SetBeforeConnectCallback(base::BindRepeating( + &ConfigureSocket, socket_.get(), *tcp_connected_socket_options)); + } + int result = socket_->Connect(base::BindRepeating( + &TCPConnectedSocket::OnConnectCompleted, base::Unretained(this))); + if (result == net::ERR_IO_PENDING) return; + OnConnectCompleted(result); } @@ -87,7 +156,13 @@ void TCPConnectedSocket::UpgradeToTLS( const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, mojom::TLSClientSocketRequest request, mojom::SocketObserverPtr observer, - UpgradeToTLSCallback callback) { + mojom::TCPConnectedSocket::UpgradeToTLSCallback callback) { + if (!tls_socket_factory_) { + std::move(callback).Run( + net::ERR_NOT_IMPLEMENTED, mojo::ScopedDataPipeConsumerHandle(), + mojo::ScopedDataPipeProducerHandle(), base::nullopt /* ssl_info*/); + return; + } // Wait for data pipes to be closed by the client before doing the upgrade. if (socket_data_pump_) { pending_upgrade_to_tls_callback_ = base::BindOnce( @@ -96,19 +171,33 @@ void TCPConnectedSocket::UpgradeToTLS( std::move(request), std::move(observer), std::move(callback)); return; } - if (!socket_ || !socket_->IsConnected()) { - std::move(callback).Run( - net::ERR_SOCKET_NOT_CONNECTED, mojo::ScopedDataPipeConsumerHandle(), - mojo::ScopedDataPipeProducerHandle(), base::nullopt); + tls_socket_factory_->UpgradeToTLS( + this, host_port_pair, std::move(socket_options), traffic_annotation, + std::move(request), std::move(observer), std::move(callback)); +} + +void TCPConnectedSocket::SetSendBufferSize(int send_buffer_size, + SetSendBufferSizeCallback callback) { + if (!socket_) { + // Fail is this method was called after upgrading to TLS. + std::move(callback).Run(net::ERR_UNEXPECTED); + return; + } + int result = socket_->SetSendBufferSize(ClampTCPBufferSize(send_buffer_size)); + std::move(callback).Run(result); +} + +void TCPConnectedSocket::SetReceiveBufferSize( + int send_buffer_size, + SetSendBufferSizeCallback callback) { + if (!socket_) { + // Fail is this method was called after upgrading to TLS. + std::move(callback).Run(net::ERR_UNEXPECTED); return; } - auto socket_handle = std::make_unique<net::ClientSocketHandle>(); - socket_handle->SetSocket(std::move(socket_)); - delegate_->CreateTLSClientSocket( - host_port_pair, std::move(socket_options), std::move(request), - std::move(socket_handle), std::move(observer), - static_cast<net::NetworkTrafficAnnotationTag>(traffic_annotation), - std::move(callback)); + int result = + socket_->SetReceiveBufferSize(ClampTCPBufferSize(send_buffer_size)); + std::move(callback).Run(result); } void TCPConnectedSocket::SetNoDelay(bool no_delay, @@ -176,4 +265,12 @@ void TCPConnectedSocket::OnShutdown() { std::move(pending_upgrade_to_tls_callback_).Run(); } +const net::StreamSocket* TCPConnectedSocket::BorrowSocket() { + return socket_.get(); +} + +std::unique_ptr<net::StreamSocket> TCPConnectedSocket::TakeSocket() { + return std::move(socket_); +} + } // namespace network diff --git a/chromium/services/network/tcp_connected_socket.h b/chromium/services/network/tcp_connected_socket.h index 3f37d797e25..ee89dac1a6f 100644 --- a/chromium/services/network/tcp_connected_socket.h +++ b/chromium/services/network/tcp_connected_socket.h @@ -21,11 +21,11 @@ #include "services/network/public/mojom/network_context.mojom.h" #include "services/network/public/mojom/tcp_socket.mojom.h" #include "services/network/socket_data_pump.h" +#include "services/network/tls_socket_factory.h" namespace net { class NetLog; class ClientSocketFactory; -class ClientSocketHandle; class TransportClientSocket; } // namespace net @@ -33,25 +33,19 @@ namespace network { class COMPONENT_EXPORT(NETWORK_SERVICE) TCPConnectedSocket : public mojom::TCPConnectedSocket, - public SocketDataPump::Delegate { + public SocketDataPump::Delegate, + public TLSSocketFactory::Delegate { public: - // Interface to handle a mojom::TLSClientSocketRequest. - class Delegate { - public: - // Handles a mojom::TLSClientSocketRequest. - virtual void CreateTLSClientSocket( - const net::HostPortPair& host_port_pair, - mojom::TLSClientSocketOptionsPtr socket_options, - mojom::TLSClientSocketRequest request, - std::unique_ptr<net::ClientSocketHandle> tcp_socket, - mojom::SocketObserverPtr observer, - const net::NetworkTrafficAnnotationTag& traffic_annotation, - mojom::TCPConnectedSocket::UpgradeToTLSCallback callback) = 0; - }; + // Max send/receive buffer size the consumer is allowed to set. Exposed for + // testing. + static const int kMaxBufferSize; + + // If |client_socket_factory| is nullptr, consumers must use + // ConnectWithSocket() instead of Connect(). TCPConnectedSocket( mojom::SocketObserverPtr observer, net::NetLog* net_log, - Delegate* delegate, + TLSSocketFactory* tls_socket_factory, net::ClientSocketFactory* client_socket_factory, const net::NetworkTrafficAnnotationTag& traffic_annotation); TCPConnectedSocket( @@ -61,9 +55,19 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) TCPConnectedSocket mojo::ScopedDataPipeConsumerHandle send_pipe_handle, const net::NetworkTrafficAnnotationTag& traffic_annotation); ~TCPConnectedSocket() override; + void Connect( const base::Optional<net::IPEndPoint>& local_addr, const net::AddressList& remote_addr_list, + mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options, + mojom::NetworkContext::CreateTCPConnectedSocketCallback callback); + + // Tries to connects using the provided TCPClientSocket. |socket| owns the + // list of addresses to try to connect to, so this method doesn't need any + // addresses as input. + void ConnectWithSocket( + std::unique_ptr<net::TransportClientSocket> socket, + mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options, mojom::NetworkContext::CreateTCPConnectedSocketCallback callback); // mojom::TCPConnectedSocket implementation. @@ -73,7 +77,11 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) TCPConnectedSocket const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, mojom::TLSClientSocketRequest request, mojom::SocketObserverPtr observer, - UpgradeToTLSCallback callback) override; + mojom::TCPConnectedSocket::UpgradeToTLSCallback callback) override; + void SetSendBufferSize(int send_buffer_size, + SetSendBufferSizeCallback callback) override; + void SetReceiveBufferSize(int send_buffer_size, + SetSendBufferSizeCallback callback) override; void SetNoDelay(bool no_delay, SetNoDelayCallback callback) override; void SetKeepAlive(bool enable, int32_t delay_secs, @@ -88,11 +96,15 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) TCPConnectedSocket void OnNetworkWriteError(int net_error) override; void OnShutdown() override; + // TLSSocketFactory::Delegate implementation. + const net::StreamSocket* BorrowSocket() override; + std::unique_ptr<net::StreamSocket> TakeSocket() override; + const mojom::SocketObserverPtr observer_; net::NetLog* const net_log_; - Delegate* const delegate_; net::ClientSocketFactory* const client_socket_factory_; + TLSSocketFactory* tls_socket_factory_; std::unique_ptr<net::TransportClientSocket> socket_; diff --git a/chromium/services/network/tcp_server_socket.cc b/chromium/services/network/tcp_server_socket.cc index 58ebfe815c2..a84778ff060 100644 --- a/chromium/services/network/tcp_server_socket.cc +++ b/chromium/services/network/tcp_server_socket.cc @@ -22,10 +22,20 @@ TCPServerSocket::TCPServerSocket( Delegate* delegate, net::NetLog* net_log, const net::NetworkTrafficAnnotationTag& traffic_annotation) + : TCPServerSocket( + std::make_unique<net::TCPServerSocket>(net_log, net::NetLogSource()), + 0 /*backlog*/, + delegate, + traffic_annotation) {} + +TCPServerSocket::TCPServerSocket( + std::unique_ptr<net::ServerSocket> server_socket, + int backlog, + Delegate* delegate, + const net::NetworkTrafficAnnotationTag& traffic_annotation) : delegate_(delegate), - socket_( - std::make_unique<net::TCPServerSocket>(net_log, net::NetLogSource())), - backlog_(0), + socket_(std::move(server_socket)), + backlog_(backlog), traffic_annotation_(traffic_annotation), weak_factory_(this) {} diff --git a/chromium/services/network/tcp_server_socket.h b/chromium/services/network/tcp_server_socket.h index 6fde43c99b1..42e5f044d18 100644 --- a/chromium/services/network/tcp_server_socket.h +++ b/chromium/services/network/tcp_server_socket.h @@ -47,6 +47,13 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) TCPServerSocket TCPServerSocket(Delegate* delegate, net::NetLog* net_log, const net::NetworkTrafficAnnotationTag& traffic_annotation); + + // As above, but takes an already listening socket. + TCPServerSocket(std::unique_ptr<net::ServerSocket> server_socket, + int backlog, + Delegate* delegate, + const net::NetworkTrafficAnnotationTag& traffic_annotation); + ~TCPServerSocket() override; int Listen(const net::IPEndPoint& local_addr, diff --git a/chromium/services/network/tcp_socket_unittest.cc b/chromium/services/network/tcp_socket_unittest.cc index 290ac45e9f9..d3a91f2c583 100644 --- a/chromium/services/network/tcp_socket_unittest.cc +++ b/chromium/services/network/tcp_socket_unittest.cc @@ -192,6 +192,10 @@ class TestServer { const net::IPEndPoint& server_addr() { return server_addr_; } + mojom::TCPConnectedSocket* most_recent_connected_socket() { + return connected_sockets_.back().get(); + } + private: void OnAccept(net::CompletionOnceCallback callback, int result, @@ -306,13 +310,15 @@ class TCPSocketTest : public testing::Test { const base::Optional<net::IPEndPoint>& local_addr, const net::IPEndPoint& remote_addr, mojo::ScopedDataPipeConsumerHandle* receive_pipe_handle_out, - mojo::ScopedDataPipeProducerHandle* send_pipe_handle_out) { + mojo::ScopedDataPipeProducerHandle* send_pipe_handle_out, + mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options = + nullptr) { net::AddressList remote_addr_list(remote_addr); base::RunLoop run_loop; int net_error = net::ERR_FAILED; factory_->CreateTCPConnectedSocket( - local_addr, remote_addr_list, TRAFFIC_ANNOTATION_FOR_TESTS, - std::move(request), std::move(observer), + local_addr, remote_addr_list, std::move(tcp_connected_socket_options), + TRAFFIC_ANNOTATION_FOR_TESTS, std::move(request), std::move(observer), base::BindLambdaForTesting( [&](int result, const base::Optional<net::IPEndPoint>& actual_local_addr, @@ -460,6 +466,48 @@ TEST_F(TCPSocketTest, ServerReceivesMultipleAccept) { } } +// Check that accepted sockets can't be upgraded to TLS, since UpgradeToTLS only +// supports the client side of a TLS handshake. +TEST_F(TCPSocketTest, AcceptedSocketCantUpgradeToTLS) { + TestServer server; + server.Start(1 /* backlog */); + + net::TestCompletionCallback callback; + server.AcceptOneConnection(callback.callback()); + + mojom::TCPConnectedSocketPtr client_socket; + mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; + mojo::ScopedDataPipeProducerHandle client_socket_send_handle; + EXPECT_EQ(net::OK, + CreateTCPConnectedSocketSync( + mojo::MakeRequest(&client_socket), nullptr /*observer*/, + base::nullopt /*local_addr*/, server.server_addr(), + &client_socket_receive_handle, &client_socket_send_handle)); + + EXPECT_EQ(net::OK, callback.WaitForResult()); + + // Consumers generally close these before attempting to upgrade the socket, + // since TCPConnectedSocket waits for the pipes to close before upgrading the + // connection. + client_socket_receive_handle.reset(); + client_socket_send_handle.reset(); + + base::RunLoop run_loop; + mojom::TLSClientSocketPtr tls_client_socket; + server.most_recent_connected_socket()->UpgradeToTLS( + net::HostPortPair("foopy", 443), nullptr /* options */, + net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS), + mojo::MakeRequest(&tls_client_socket), nullptr /* observer */, + base::BindLambdaForTesting( + [&](int net_error, + mojo::ScopedDataPipeConsumerHandle receive_pipe_handle, + mojo::ScopedDataPipeProducerHandle send_pipe_handle, + const base::Optional<net::SSLInfo>& ssl_info) { + EXPECT_EQ(net::ERR_NOT_IMPLEMENTED, net_error); + run_loop.Quit(); + })); +} + // Tests that if a socket is closed, the other side can observe that the pipes // are broken. TEST_F(TCPSocketTest, SocketClosed) { @@ -973,18 +1021,195 @@ TEST_P(TCPSocketWithMockSocketTest, WriteError) { EXPECT_TRUE(data_provider.AllWriteDataConsumed()); } -TEST_F(TCPSocketWithMockSocketTest, SetNoDelayAndKeepAlive) { - // Populate with some mock reads, so UpgradeToTLS() won't error out because of - // a closed receive pipe. - const net::MockRead kReads[] = { - net::MockRead(net::ASYNC, "hello", 5 /* length */), - net::MockRead(net::ASYNC, net::OK)}; - net::StaticSocketDataProvider data_provider(kReads, - base::span<net::MockWrite>()); - net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::ERR_FAILED); +TEST_P(TCPSocketWithMockSocketTest, InitialTCPConnectedSocketOptions) { + net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); + mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; + mojo::ScopedDataPipeProducerHandle client_socket_send_handle; + for (int receive_buffer_size : + {-1, 0, 1024, TCPConnectedSocket::kMaxBufferSize, + TCPConnectedSocket::kMaxBufferSize + 1}) { + for (int send_buffer_size : + {-1, 0, 2048, TCPConnectedSocket::kMaxBufferSize, + TCPConnectedSocket::kMaxBufferSize + 1}) { + for (int no_delay : {false, true}) { + mojom::TCPConnectedSocketPtr client_socket; + net::StaticSocketDataProvider data_provider; + data_provider.set_connect_data( + net::MockConnect(GetParam(), net::OK, server_addr)); + mock_client_socket_factory_.AddSocketDataProvider(&data_provider); + + mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options = + mojom::TCPConnectedSocketOptions::New(); + tcp_connected_socket_options->receive_buffer_size = receive_buffer_size; + tcp_connected_socket_options->send_buffer_size = send_buffer_size; + tcp_connected_socket_options->no_delay = no_delay; + EXPECT_EQ(net::OK, + CreateTCPConnectedSocketSync( + mojo::MakeRequest(&client_socket), nullptr /*observer*/, + base::nullopt /*local_addr*/, server_addr, + &client_socket_receive_handle, &client_socket_send_handle, + std::move(tcp_connected_socket_options))); + + if (receive_buffer_size <= 0) { + EXPECT_EQ(-1, data_provider.receive_buffer_size()); + } else if (receive_buffer_size <= TCPConnectedSocket::kMaxBufferSize) { + EXPECT_EQ(receive_buffer_size, data_provider.receive_buffer_size()); + } else { + EXPECT_EQ(TCPConnectedSocket::kMaxBufferSize, + data_provider.receive_buffer_size()); + } + + if (send_buffer_size <= 0) { + EXPECT_EQ(-1, data_provider.send_buffer_size()); + } else if (send_buffer_size <= TCPConnectedSocket::kMaxBufferSize) { + EXPECT_EQ(send_buffer_size, data_provider.send_buffer_size()); + } else { + EXPECT_EQ(TCPConnectedSocket::kMaxBufferSize, + data_provider.send_buffer_size()); + } + + EXPECT_EQ(no_delay, data_provider.no_delay()); + } + } + } +} + +TEST_P(TCPSocketWithMockSocketTest, InitialTCPConnectedSocketOptionsFails) { + net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); + mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; + mojo::ScopedDataPipeProducerHandle client_socket_send_handle; + + enum class FailedCall { + SET_RECEIVE_BUFFER_SIZE, + SET_SEND_BUFFER_SIZE, + SET_NO_DELAY, + }; + for (const auto& failed_call : + {FailedCall::SET_RECEIVE_BUFFER_SIZE, FailedCall::SET_SEND_BUFFER_SIZE, + FailedCall::SET_NO_DELAY}) { + mojom::TCPConnectedSocketPtr client_socket; + net::StaticSocketDataProvider data_provider; + data_provider.set_connect_data( + net::MockConnect(GetParam(), net::OK, server_addr)); + switch (failed_call) { + case FailedCall::SET_RECEIVE_BUFFER_SIZE: + data_provider.set_set_receive_buffer_size_result(net::ERR_FAILED); + break; + case FailedCall::SET_SEND_BUFFER_SIZE: + data_provider.set_set_send_buffer_size_result(net::ERR_FAILED); + break; + case FailedCall::SET_NO_DELAY: + data_provider.set_set_no_delay_result(false); + break; + } + mock_client_socket_factory_.AddSocketDataProvider(&data_provider); + + mojom::TCPConnectedSocketOptionsPtr tcp_connected_socket_options = + mojom::TCPConnectedSocketOptions::New(); + tcp_connected_socket_options->receive_buffer_size = 1; + tcp_connected_socket_options->send_buffer_size = 2; + tcp_connected_socket_options->no_delay = false; + EXPECT_EQ(net::ERR_FAILED, + CreateTCPConnectedSocketSync( + mojo::MakeRequest(&client_socket), nullptr /*observer*/, + base::nullopt /*local_addr*/, server_addr, + &client_socket_receive_handle, &client_socket_send_handle, + std::move(tcp_connected_socket_options))); + } +} + +TEST_P(TCPSocketWithMockSocketTest, SetBufferSizes) { + mojom::TCPConnectedSocketPtr client_socket; + net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); + mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; + mojo::ScopedDataPipeProducerHandle client_socket_send_handle; + + net::StaticSocketDataProvider data_provider; + data_provider.set_connect_data( + net::MockConnect(GetParam(), net::OK, server_addr)); + mock_client_socket_factory_.AddSocketDataProvider(&data_provider); + + EXPECT_EQ(net::OK, + CreateTCPConnectedSocketSync( + mojo::MakeRequest(&client_socket), nullptr /*observer*/, + base::nullopt /*local_addr*/, server_addr, + &client_socket_receive_handle, &client_socket_send_handle)); + + EXPECT_EQ(-1, data_provider.receive_buffer_size()); + + net::TestCompletionCallback callback; + // Setting a buffer size < 0 is replaced by setting a buffer size of 0. + client_socket->SetReceiveBufferSize(-1, callback.callback()); + EXPECT_EQ(net::OK, callback.WaitForResult()); + EXPECT_EQ(0, data_provider.receive_buffer_size()); + + client_socket->SetReceiveBufferSize(1024, callback.callback()); + EXPECT_EQ(net::OK, callback.WaitForResult()); + EXPECT_EQ(1024, data_provider.receive_buffer_size()); + + client_socket->SetReceiveBufferSize(TCPConnectedSocket::kMaxBufferSize + 1, + callback.callback()); + EXPECT_EQ(net::OK, callback.WaitForResult()); + EXPECT_EQ(TCPConnectedSocket::kMaxBufferSize, + data_provider.receive_buffer_size()); + + client_socket->SetReceiveBufferSize(0, callback.callback()); + EXPECT_EQ(net::OK, callback.WaitForResult()); + EXPECT_EQ(0, data_provider.receive_buffer_size()); + + EXPECT_EQ(-1, data_provider.send_buffer_size()); + + // Setting a buffer size < 0 is replaced by setting a buffer size of 0. + client_socket->SetSendBufferSize(-1, callback.callback()); + EXPECT_EQ(net::OK, callback.WaitForResult()); + EXPECT_EQ(0, data_provider.send_buffer_size()); + + client_socket->SetSendBufferSize(1024, callback.callback()); + EXPECT_EQ(net::OK, callback.WaitForResult()); + EXPECT_EQ(1024, data_provider.send_buffer_size()); + + client_socket->SetSendBufferSize(TCPConnectedSocket::kMaxBufferSize + 1, + callback.callback()); + EXPECT_EQ(net::OK, callback.WaitForResult()); + EXPECT_EQ(TCPConnectedSocket::kMaxBufferSize, + data_provider.send_buffer_size()); + + client_socket->SetSendBufferSize(0, callback.callback()); + EXPECT_EQ(net::OK, callback.WaitForResult()); + EXPECT_EQ(0, data_provider.send_buffer_size()); +} + +TEST_P(TCPSocketWithMockSocketTest, SetBufferSizesFails) { + mojom::TCPConnectedSocketPtr client_socket; + net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); + mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; + mojo::ScopedDataPipeProducerHandle client_socket_send_handle; + + net::StaticSocketDataProvider data_provider; + data_provider.set_connect_data( + net::MockConnect(GetParam(), net::OK, server_addr)); + data_provider.set_set_receive_buffer_size_result(net::ERR_FAILED); + data_provider.set_set_send_buffer_size_result(net::ERR_UNEXPECTED); + mock_client_socket_factory_.AddSocketDataProvider(&data_provider); + + EXPECT_EQ(net::OK, + CreateTCPConnectedSocketSync( + mojo::MakeRequest(&client_socket), nullptr /*observer*/, + base::nullopt /*local_addr*/, server_addr, + &client_socket_receive_handle, &client_socket_send_handle)); + + net::TestCompletionCallback callback; + client_socket->SetReceiveBufferSize(1024, callback.callback()); + EXPECT_EQ(net::ERR_FAILED, callback.WaitForResult()); + + client_socket->SetSendBufferSize(1024, callback.callback()); + EXPECT_EQ(net::ERR_UNEXPECTED, callback.WaitForResult()); +} + +TEST_F(TCPSocketWithMockSocketTest, SetNoDelayAndKeepAlive) { + net::StaticSocketDataProvider data_provider; mock_client_socket_factory_.AddSocketDataProvider(&data_provider); - mock_client_socket_factory_.AddSSLSocketDataProvider(&ssl_socket); mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; mojo::ScopedDataPipeProducerHandle client_socket_send_handle; @@ -996,6 +1221,18 @@ TEST_F(TCPSocketWithMockSocketTest, SetNoDelayAndKeepAlive) { mojo::MakeRequest(&client_socket), nullptr /*observer*/, base::nullopt, server_addr, &client_socket_receive_handle, &client_socket_send_handle)); + + EXPECT_TRUE(data_provider.no_delay()); + { + base::RunLoop run_loop; + client_socket->SetNoDelay(false /* no_delay */, + base::BindLambdaForTesting([&](bool success) { + EXPECT_TRUE(success); + run_loop.Quit(); + })); + run_loop.Run(); + EXPECT_FALSE(data_provider.no_delay()); + } { base::RunLoop run_loop; client_socket->SetNoDelay(true /* no_delay */, @@ -1004,16 +1241,108 @@ TEST_F(TCPSocketWithMockSocketTest, SetNoDelayAndKeepAlive) { run_loop.Quit(); })); run_loop.Run(); + EXPECT_TRUE(data_provider.no_delay()); } + { + const int kKeepAliveDelay = 123; base::RunLoop run_loop; - client_socket->SetKeepAlive(true /* enable */, 123 /* delay */, + client_socket->SetKeepAlive(true /* enable */, kKeepAliveDelay, + base::BindLambdaForTesting([&](bool success) { + EXPECT_TRUE(success); + run_loop.Quit(); + })); + run_loop.Run(); + EXPECT_TRUE(data_provider.keep_alive_enabled()); + EXPECT_EQ(kKeepAliveDelay, data_provider.keep_alive_delay()); + } + + { + base::RunLoop run_loop; + client_socket->SetKeepAlive(false /* enable */, 0 /* delay */, + base::BindLambdaForTesting([&](bool success) { + EXPECT_TRUE(success); + run_loop.Quit(); + })); + run_loop.Run(); + EXPECT_FALSE(data_provider.keep_alive_enabled()); + } + + { + const int kKeepAliveDelay = 1234; + base::RunLoop run_loop; + client_socket->SetKeepAlive(true /* enable */, kKeepAliveDelay, base::BindLambdaForTesting([&](bool success) { EXPECT_TRUE(success); run_loop.Quit(); })); run_loop.Run(); + EXPECT_TRUE(data_provider.keep_alive_enabled()); + EXPECT_EQ(kKeepAliveDelay, data_provider.keep_alive_delay()); } +} + +TEST_F(TCPSocketWithMockSocketTest, SetNoDelayFails) { + net::StaticSocketDataProvider data_provider; + data_provider.set_set_no_delay_result(false); + data_provider.set_set_keep_alive_result(false); + mock_client_socket_factory_.AddSocketDataProvider(&data_provider); + + mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; + mojo::ScopedDataPipeProducerHandle client_socket_send_handle; + + mojom::TCPConnectedSocketPtr client_socket; + net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); + EXPECT_EQ(net::OK, + CreateTCPConnectedSocketSync( + mojo::MakeRequest(&client_socket), nullptr /*observer*/, + base::nullopt, server_addr, &client_socket_receive_handle, + &client_socket_send_handle)); + + { + base::RunLoop run_loop; + client_socket->SetNoDelay(false /* no_delay */, + base::BindLambdaForTesting([&](bool success) { + EXPECT_FALSE(success); + run_loop.Quit(); + })); + run_loop.Run(); + } + + { + base::RunLoop run_loop; + client_socket->SetKeepAlive(true /* enable */, 123 /* delay */, + base::BindLambdaForTesting([&](bool success) { + EXPECT_FALSE(success); + run_loop.Quit(); + })); + run_loop.Run(); + } +} + +TEST_F(TCPSocketWithMockSocketTest, SetOptionsAfterTLSUpgrade) { + // Populate with some mock reads, so UpgradeToTLS() won't error out because of + // a closed receive pipe. + const net::MockRead kReads[] = { + net::MockRead(net::ASYNC, "hello", 5 /* length */), + net::MockRead(net::ASYNC, net::OK)}; + net::StaticSocketDataProvider data_provider(kReads, + base::span<net::MockWrite>()); + net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::ERR_FAILED); + + mock_client_socket_factory_.AddSocketDataProvider(&data_provider); + mock_client_socket_factory_.AddSSLSocketDataProvider(&ssl_socket); + + mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle; + mojo::ScopedDataPipeProducerHandle client_socket_send_handle; + + mojom::TCPConnectedSocketPtr client_socket; + net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); + EXPECT_EQ(net::OK, + CreateTCPConnectedSocketSync( + mojo::MakeRequest(&client_socket), nullptr /*observer*/, + base::nullopt, server_addr, &client_socket_receive_handle, + &client_socket_send_handle)); // UpgradeToTLS will destroy network::TCPConnectedSocket::|socket_|. Calling // SetNoDelay and SetKeepAlive should error out. @@ -1037,14 +1366,25 @@ TEST_F(TCPSocketWithMockSocketTest, SetNoDelayAndKeepAlive) { })); run_loop.Run(); } + + net::TestCompletionCallback callback; + client_socket->SetReceiveBufferSize(1024, callback.callback()); + EXPECT_EQ(net::ERR_UNEXPECTED, callback.WaitForResult()); + EXPECT_EQ(-1, data_provider.receive_buffer_size()); + + client_socket->SetSendBufferSize(1024, callback.callback()); + EXPECT_EQ(net::ERR_UNEXPECTED, callback.WaitForResult()); + EXPECT_EQ(-1, data_provider.send_buffer_size()); + { base::RunLoop run_loop; - client_socket->SetNoDelay(true /* no_delay */, + client_socket->SetNoDelay(false /* no_delay */, base::BindLambdaForTesting([&](bool success) { EXPECT_FALSE(success); run_loop.Quit(); })); run_loop.Run(); + EXPECT_TRUE(data_provider.no_delay()); } { base::RunLoop run_loop; @@ -1071,7 +1411,8 @@ TEST_F(TCPSocketWithMockSocketTest, SocketDestroyedBeforeConnectCompletes) { int net_error = net::OK; base::RunLoop run_loop; factory()->CreateTCPConnectedSocket( - base::nullopt, remote_addr_list, TRAFFIC_ANNOTATION_FOR_TESTS, + base::nullopt, remote_addr_list, + nullptr /* tcp_connected_socket_options */, TRAFFIC_ANNOTATION_FOR_TESTS, mojo::MakeRequest(&client_socket), nullptr, base::BindLambdaForTesting( [&](int result, diff --git a/chromium/services/network/throttling/throttling_controller_unittest.cc b/chromium/services/network/throttling/throttling_controller_unittest.cc index c23ad02f5a4..1f7b1e546ec 100644 --- a/chromium/services/network/throttling/throttling_controller_unittest.cc +++ b/chromium/services/network/throttling/throttling_controller_unittest.cc @@ -62,7 +62,7 @@ class ThrottlingControllerTestHelper { completion_callback_(base::BindRepeating(&TestCallback::Run, base::Unretained(&callback_))), mock_transaction_(kSimpleGET_Transaction), - buffer_(new net::IOBuffer(64)), + buffer_(base::MakeRefCounted<net::IOBuffer>(64)), net_log_(std::make_unique<net::NetLog>()), net_log_with_source_( net::NetLogWithSource::Make(net_log_.get(), diff --git a/chromium/services/network/tls_client_socket_unittest.cc b/chromium/services/network/tls_client_socket_unittest.cc index 8a27192c824..7d869b7448e 100644 --- a/chromium/services/network/tls_client_socket_unittest.cc +++ b/chromium/services/network/tls_client_socket_unittest.cc @@ -25,6 +25,7 @@ #include "net/traffic_annotation/network_traffic_annotation_test_helper.h" #include "net/url_request/url_request_test_util.h" #include "services/network/mojo_socket_test_util.h" +#include "services/network/proxy_resolving_socket_factory_mojo.h" #include "services/network/public/mojom/network_service.mojom.h" #include "services/network/socket_factory.h" #include "testing/gtest/include/gtest/gtest.h" @@ -43,24 +44,49 @@ const size_t kSecretMsgSize = strlen(kSecretMsg); class TLSClientSocketTestBase { public: - TLSClientSocketTestBase() - : scoped_task_environment_( + enum Mode { kDirect, kProxyResolving }; + + explicit TLSClientSocketTestBase(Mode mode) + : mode_(mode), + scoped_task_environment_( base::test::ScopedTaskEnvironment::MainThreadType::IO), url_request_context_(true) {} - ~TLSClientSocketTestBase() {} + virtual ~TLSClientSocketTestBase() {} + + Mode mode() { return mode_; } protected: + // One of the two fields will be set, depending on the mode. + struct SocketHandle { + mojom::TCPConnectedSocketPtr tcp_socket; + mojom::ProxyResolvingSocketPtr proxy_socket; + }; + + struct SocketRequest { + mojom::TCPConnectedSocketRequest tcp_socket_request; + mojom::ProxyResolvingSocketRequest proxy_socket_request; + }; + // Initializes the test fixture. If |use_mock_sockets|, mock client socket // factory will be used. - void Init(bool use_mock_sockets) { + void Init(bool use_mock_sockets, bool configure_proxy) { if (use_mock_sockets) { mock_client_socket_factory_.set_enable_read_if_ready(true); url_request_context_.set_client_socket_factory( &mock_client_socket_factory_); } + if (configure_proxy) { + proxy_resolution_service_ = net::ProxyResolutionService::CreateFixed( + "http://proxy:8080", TRAFFIC_ANNOTATION_FOR_TESTS); + url_request_context_.set_proxy_resolution_service( + proxy_resolution_service_.get()); + } url_request_context_.Init(); factory_ = std::make_unique<SocketFactory>(nullptr /*net_log*/, &url_request_context_); + proxy_resolving_factory_ = + std::make_unique<ProxyResolvingSocketFactoryMojo>( + &url_request_context_); } // Reads |num_bytes| from |handle| or reads until an error occurs. Returns the @@ -83,6 +109,33 @@ class TLSClientSocketTestBase { return received_contents; } + SocketRequest MakeRequest(SocketHandle* handle) { + SocketRequest result; + if (mode_ == kDirect) + result.tcp_socket_request = mojo::MakeRequest(&handle->tcp_socket); + else + result.proxy_socket_request = mojo::MakeRequest(&handle->proxy_socket); + return result; + } + + void ResetSocket(SocketHandle* handle) { + if (mode_ == kDirect) + handle->tcp_socket.reset(); + else + handle->proxy_socket.reset(); + } + + int CreateSocketSync(SocketRequest request, + const net::IPEndPoint& remote_addr) { + if (mode_ == kDirect) { + return CreateTCPConnectedSocketSync(std::move(request.tcp_socket_request), + remote_addr); + } else { + return CreateProxyResolvingSocketSync( + std::move(request.proxy_socket_request), remote_addr); + } + } + int CreateTCPConnectedSocketSync(mojom::TCPConnectedSocketRequest request, const net::IPEndPoint& remote_addr) { net::AddressList remote_addr_list(remote_addr); @@ -90,6 +143,7 @@ class TLSClientSocketTestBase { int net_error = net::ERR_FAILED; factory_->CreateTCPConnectedSocket( base::nullopt /* local_addr */, remote_addr_list, + nullptr /* tcp_connected_socket_options */, TRAFFIC_ANNOTATION_FOR_TESTS, std::move(request), pre_tls_observer()->GetObserverPtr(), base::BindLambdaForTesting( @@ -107,21 +161,86 @@ class TLSClientSocketTestBase { return net_error; } - void UpgradeToTLS(mojom::TCPConnectedSocket* client_socket, + int CreateProxyResolvingSocketSync(mojom::ProxyResolvingSocketRequest request, + const net::IPEndPoint& remote_addr) { + GURL url("https://" + remote_addr.ToString()); + base::RunLoop run_loop; + int net_error = net::ERR_FAILED; + proxy_resolving_factory_->CreateProxyResolvingSocket( + url, false /* use_tls */, + net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS), + std::move(request), nullptr /* observer */, + base::BindLambdaForTesting( + [&](int result, + const base::Optional<net::IPEndPoint>& actual_local_addr, + const base::Optional<net::IPEndPoint>& peer_addr, + mojo::ScopedDataPipeConsumerHandle receive_pipe_handle, + mojo::ScopedDataPipeProducerHandle send_pipe_handle) { + net_error = result; + pre_tls_recv_handle_ = std::move(receive_pipe_handle); + pre_tls_send_handle_ = std::move(send_pipe_handle); + run_loop.Quit(); + })); + run_loop.Run(); + return net_error; + } + + void UpgradeToTLS(SocketHandle* handle, const net::HostPortPair& host_port_pair, mojom::TLSClientSocketRequest request, net::CompletionOnceCallback callback) { + if (mode_ == kDirect) { + UpgradeTCPConnectedSocketToTLS(handle->tcp_socket.get(), host_port_pair, + nullptr /* options */, std::move(request), + std::move(callback)); + } else { + UpgradeProxyResolvingSocketToTLS(handle->proxy_socket.get(), + host_port_pair, std::move(request), + std::move(callback)); + } + } + + void UpgradeTCPConnectedSocketToTLS(mojom::TCPConnectedSocket* client_socket, + const net::HostPortPair& host_port_pair, + mojom::TLSClientSocketOptionsPtr options, + mojom::TLSClientSocketRequest request, + net::CompletionOnceCallback callback) { client_socket->UpgradeToTLS( - host_port_pair, nullptr /* ssl_config_ptr */, + host_port_pair, std::move(options), net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS), std::move(request), post_tls_observer()->GetObserverPtr(), base::BindOnce( [](net::CompletionOnceCallback cb, - mojo::ScopedDataPipeConsumerHandle* consumer_handle, - mojo::ScopedDataPipeProducerHandle* producer_handle, int result, + mojo::ScopedDataPipeConsumerHandle* consumer_handle_out, + mojo::ScopedDataPipeProducerHandle* producer_handle_out, + base::Optional<net::SSLInfo>* ssl_info_out, int result, mojo::ScopedDataPipeConsumerHandle receive_pipe_handle, mojo::ScopedDataPipeProducerHandle send_pipe_handle, const base::Optional<net::SSLInfo>& ssl_info) { + *consumer_handle_out = std::move(receive_pipe_handle); + *producer_handle_out = std::move(send_pipe_handle); + *ssl_info_out = ssl_info; + std::move(cb).Run(result); + }, + std::move(callback), &post_tls_recv_handle_, &post_tls_send_handle_, + &ssl_info_)); + } + + void UpgradeProxyResolvingSocketToTLS( + mojom::ProxyResolvingSocket* client_socket, + const net::HostPortPair& host_port_pair, + mojom::TLSClientSocketRequest request, + net::CompletionOnceCallback callback) { + client_socket->UpgradeToTLS( + host_port_pair, + net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS), + std::move(request), post_tls_observer()->GetObserverPtr(), + base::BindOnce( + [](net::CompletionOnceCallback cb, + mojo::ScopedDataPipeConsumerHandle* consumer_handle, + mojo::ScopedDataPipeProducerHandle* producer_handle, int result, + mojo::ScopedDataPipeConsumerHandle receive_pipe_handle, + mojo::ScopedDataPipeProducerHandle send_pipe_handle) { *consumer_handle = std::move(receive_pipe_handle); *producer_handle = std::move(send_pipe_handle); std::move(cb).Run(result); @@ -149,11 +268,16 @@ class TLSClientSocketTestBase { return &post_tls_send_handle_; } + const base::Optional<net::SSLInfo>& ssl_info() { return ssl_info_; } + net::MockClientSocketFactory* mock_client_socket_factory() { return &mock_client_socket_factory_; } + Mode mode() const { return mode_; } + private: + Mode mode_; base::test::ScopedTaskEnvironment scoped_task_environment_; // Mojo data handles obtained from CreateTCPConnectedSocket. @@ -164,9 +288,14 @@ class TLSClientSocketTestBase { mojo::ScopedDataPipeConsumerHandle post_tls_recv_handle_; mojo::ScopedDataPipeProducerHandle post_tls_send_handle_; + // SSLInfo obtained from UpgradeToTLS. + base::Optional<net::SSLInfo> ssl_info_; + + std::unique_ptr<net::ProxyResolutionService> proxy_resolution_service_; net::TestURLRequestContext url_request_context_; net::MockClientSocketFactory mock_client_socket_factory_; std::unique_ptr<SocketFactory> factory_; + std::unique_ptr<ProxyResolvingSocketFactoryMojo> proxy_resolving_factory_; TestSocketObserver pre_tls_observer_; TestSocketObserver post_tls_observer_; mojo::StrongBindingSet<mojom::TCPServerSocket> tcp_server_socket_bindings_; @@ -176,12 +305,12 @@ class TLSClientSocketTestBase { DISALLOW_COPY_AND_ASSIGN(TLSClientSocketTestBase); }; -} // namespace -class TLSClientSocketTest : public TLSClientSocketTestBase, - public testing::Test { +class TLSClientSocketTest + : public ::testing::TestWithParam<TLSClientSocketTestBase::Mode>, + public TLSClientSocketTestBase { public: - TLSClientSocketTest() : TLSClientSocketTestBase() { - Init(true /* use_mock_sockets */); + TLSClientSocketTest() : TLSClientSocketTestBase(GetParam()) { + Init(true /* use_mock_sockets */, false /* configure_proxy */); } ~TLSClientSocketTest() override {} @@ -192,7 +321,7 @@ class TLSClientSocketTest : public TLSClientSocketTestBase, // Basic test to call UpgradeToTLS, and then read/write after UpgradeToTLS is // successful. -TEST_F(TLSClientSocketTest, UpgradeToTLS) { +TEST_P(TLSClientSocketTest, UpgradeToTLS) { const net::MockRead kReads[] = {net::MockRead(net::ASYNC, kMsg, kMsgSize, 1), net::MockRead(net::SYNCHRONOUS, net::OK, 2)}; const net::MockWrite kWrites[] = { @@ -203,20 +332,20 @@ TEST_F(TLSClientSocketTest, UpgradeToTLS) { net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK); mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket); - mojom::TCPConnectedSocketPtr client_socket; + SocketHandle client_socket; net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); - EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( - mojo::MakeRequest(&client_socket), server_addr)); + EXPECT_EQ(net::OK, + CreateSocketSync(MakeRequest(&client_socket), server_addr)); net::HostPortPair host_port_pair("example.org", 443); pre_tls_recv_handle()->reset(); pre_tls_send_handle()->reset(); net::TestCompletionCallback callback; mojom::TLSClientSocketPtr tls_socket; - UpgradeToTLS(client_socket.get(), host_port_pair, - mojo::MakeRequest(&tls_socket), callback.callback()); + UpgradeToTLS(&client_socket, host_port_pair, mojo::MakeRequest(&tls_socket), + callback.callback()); ASSERT_EQ(net::OK, callback.WaitForResult()); - client_socket.reset(); + ResetSocket(&client_socket); uint32_t num_bytes = strlen(kMsg); EXPECT_EQ(MOJO_RESULT_OK, post_tls_send_handle()->get().WriteData( @@ -230,7 +359,7 @@ TEST_F(TLSClientSocketTest, UpgradeToTLS) { // Same as the UpgradeToTLS test above, except this test calls // base::RunLoop().RunUntilIdle() after destroying the pre-tls data pipes. -TEST_F(TLSClientSocketTest, ClosePipesRunUntilIdleAndUpgradeToTLS) { +TEST_P(TLSClientSocketTest, ClosePipesRunUntilIdleAndUpgradeToTLS) { const net::MockRead kReads[] = {net::MockRead(net::ASYNC, kMsg, kMsgSize, 1), net::MockRead(net::SYNCHRONOUS, net::OK, 2)}; const net::MockWrite kWrites[] = { @@ -241,10 +370,10 @@ TEST_F(TLSClientSocketTest, ClosePipesRunUntilIdleAndUpgradeToTLS) { net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK); mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket); - mojom::TCPConnectedSocketPtr client_socket; + SocketHandle client_socket; net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); - EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( - mojo::MakeRequest(&client_socket), server_addr)); + EXPECT_EQ(net::OK, + CreateSocketSync(MakeRequest(&client_socket), server_addr)); net::HostPortPair host_port_pair("example.org", 443); @@ -256,10 +385,10 @@ TEST_F(TLSClientSocketTest, ClosePipesRunUntilIdleAndUpgradeToTLS) { net::TestCompletionCallback callback; mojom::TLSClientSocketPtr tls_socket; - UpgradeToTLS(client_socket.get(), host_port_pair, - mojo::MakeRequest(&tls_socket), callback.callback()); + UpgradeToTLS(&client_socket, host_port_pair, mojo::MakeRequest(&tls_socket), + callback.callback()); ASSERT_EQ(net::OK, callback.WaitForResult()); - client_socket.reset(); + ResetSocket(&client_socket); uint32_t num_bytes = strlen(kMsg); EXPECT_EQ(MOJO_RESULT_OK, post_tls_send_handle()->get().WriteData( @@ -273,7 +402,7 @@ TEST_F(TLSClientSocketTest, ClosePipesRunUntilIdleAndUpgradeToTLS) { // Calling UpgradeToTLS on the same TCPConnectedSocketPtr is illegal and should // receive an error. -TEST_F(TLSClientSocketTest, UpgradeToTLSTwice) { +TEST_P(TLSClientSocketTest, UpgradeToTLSTwice) { const net::MockRead kReads[] = {net::MockRead(net::ASYNC, net::OK, 0)}; net::SequencedSocketData data_provider(kReads, base::span<net::MockWrite>()); data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK)); @@ -281,10 +410,10 @@ TEST_F(TLSClientSocketTest, UpgradeToTLSTwice) { net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK); mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket); - mojom::TCPConnectedSocketPtr client_socket; + SocketHandle client_socket; net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); - EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( - mojo::MakeRequest(&client_socket), server_addr)); + EXPECT_EQ(net::OK, + CreateSocketSync(MakeRequest(&client_socket), server_addr)); net::HostPortPair host_port_pair("example.org", 443); pre_tls_recv_handle()->reset(); @@ -293,26 +422,40 @@ TEST_F(TLSClientSocketTest, UpgradeToTLSTwice) { // First UpgradeToTLS should complete successfully. net::TestCompletionCallback callback; mojom::TLSClientSocketPtr tls_socket; - UpgradeToTLS(client_socket.get(), host_port_pair, - mojo::MakeRequest(&tls_socket), callback.callback()); + UpgradeToTLS(&client_socket, host_port_pair, mojo::MakeRequest(&tls_socket), + callback.callback()); ASSERT_EQ(net::OK, callback.WaitForResult()); // Second time UpgradeToTLS is called, it should fail. mojom::TLSClientSocketPtr tls_socket2; base::RunLoop run_loop; int net_error = net::ERR_FAILED; - client_socket->UpgradeToTLS( - host_port_pair, nullptr /* ssl_config_ptr */, - net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS), - mojo::MakeRequest(&tls_socket2), nullptr /*observer */, - base::BindLambdaForTesting( - [&](int result, - mojo::ScopedDataPipeConsumerHandle receive_pipe_handle, - mojo::ScopedDataPipeProducerHandle send_pipe_handle, - const base::Optional<net::SSLInfo>& ssl_info) { - net_error = result; - run_loop.Quit(); - })); + if (mode() == kDirect) { + auto upgrade2_callback = base::BindLambdaForTesting( + [&](int result, mojo::ScopedDataPipeConsumerHandle receive_pipe_handle, + mojo::ScopedDataPipeProducerHandle send_pipe_handle, + const base::Optional<net::SSLInfo>& ssl_info) { + net_error = result; + run_loop.Quit(); + }); + client_socket.tcp_socket->UpgradeToTLS( + host_port_pair, nullptr /* ssl_config_ptr */, + net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS), + mojo::MakeRequest(&tls_socket2), nullptr /*observer */, + std::move(upgrade2_callback)); + } else { + auto upgrade2_callback = base::BindLambdaForTesting( + [&](int result, mojo::ScopedDataPipeConsumerHandle receive_pipe_handle, + mojo::ScopedDataPipeProducerHandle send_pipe_handle) { + net_error = result; + run_loop.Quit(); + }); + client_socket.proxy_socket->UpgradeToTLS( + host_port_pair, + net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS), + mojo::MakeRequest(&tls_socket2), nullptr /*observer */, + std::move(upgrade2_callback)); + } run_loop.Run(); ASSERT_EQ(net::ERR_SOCKET_NOT_CONNECTED, net_error); @@ -322,7 +465,10 @@ TEST_F(TLSClientSocketTest, UpgradeToTLSTwice) { EXPECT_TRUE(data_provider.AllWriteDataConsumed()); } -TEST_F(TLSClientSocketTest, UpgradeToTLSWithCustomSSLConfig) { +TEST_P(TLSClientSocketTest, UpgradeToTLSWithCustomSSLConfig) { + // No custom options in the proxy-resolving case. + if (mode() != kDirect) + return; const net::MockRead kReads[] = {net::MockRead(net::ASYNC, net::OK, 0)}; net::SequencedSocketData data_provider(kReads, base::span<net::MockWrite>()); data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK)); @@ -332,10 +478,10 @@ TEST_F(TLSClientSocketTest, UpgradeToTLSWithCustomSSLConfig) { ssl_socket.expected_ssl_version_max = net::SSL_PROTOCOL_VERSION_TLS1_2; mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket); - mojom::TCPConnectedSocketPtr client_socket; + SocketHandle client_socket; net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); - EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( - mojo::MakeRequest(&client_socket), server_addr)); + EXPECT_EQ(net::OK, + CreateSocketSync(MakeRequest(&client_socket), server_addr)); net::HostPortPair host_port_pair("example.org", 443); pre_tls_recv_handle()->reset(); @@ -348,18 +494,18 @@ TEST_F(TLSClientSocketTest, UpgradeToTLSWithCustomSSLConfig) { options->version_min = mojom::SSLVersion::kTLS11; options->version_max = mojom::SSLVersion::kTLS12; int net_error = net::ERR_FAILED; - client_socket->UpgradeToTLS( + auto upgrade_callback = base::BindLambdaForTesting( + [&](int result, mojo::ScopedDataPipeConsumerHandle receive_pipe_handle, + mojo::ScopedDataPipeProducerHandle send_pipe_handle, + const base::Optional<net::SSLInfo>& ssl_info) { + net_error = result; + run_loop.Quit(); + }); + client_socket.tcp_socket->UpgradeToTLS( host_port_pair, std::move(options), net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS), mojo::MakeRequest(&tls_socket), nullptr /*observer */, - base::BindLambdaForTesting( - [&](int result, - mojo::ScopedDataPipeConsumerHandle receive_pipe_handle, - mojo::ScopedDataPipeProducerHandle send_pipe_handle, - const base::Optional<net::SSLInfo>& ssl_info) { - net_error = result; - run_loop.Quit(); - })); + std::move(upgrade_callback)); run_loop.Run(); ASSERT_EQ(net::OK, net_error); @@ -371,7 +517,7 @@ TEST_F(TLSClientSocketTest, UpgradeToTLSWithCustomSSLConfig) { // Same as the UpgradeToTLS test, except this also reads and writes to the tcp // connection before UpgradeToTLS is called. -TEST_F(TLSClientSocketTest, ReadWriteBeforeUpgradeToTLS) { +TEST_P(TLSClientSocketTest, ReadWriteBeforeUpgradeToTLS) { const net::MockRead kReads[] = { net::MockRead(net::SYNCHRONOUS, kMsg, kMsgSize, 0), net::MockRead(net::ASYNC, kSecretMsg, kSecretMsgSize, 3), @@ -386,10 +532,10 @@ TEST_F(TLSClientSocketTest, ReadWriteBeforeUpgradeToTLS) { net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK); mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket); - mojom::TCPConnectedSocketPtr client_socket; + SocketHandle client_socket; net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); - EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( - mojo::MakeRequest(&client_socket), server_addr)); + EXPECT_EQ(net::OK, + CreateSocketSync(MakeRequest(&client_socket), server_addr)); EXPECT_EQ(kMsg, Read(pre_tls_recv_handle(), kMsgSize)); @@ -402,10 +548,10 @@ TEST_F(TLSClientSocketTest, ReadWriteBeforeUpgradeToTLS) { pre_tls_send_handle()->reset(); net::TestCompletionCallback callback; mojom::TLSClientSocketPtr tls_socket; - UpgradeToTLS(client_socket.get(), host_port_pair, - mojo::MakeRequest(&tls_socket), callback.callback()); + UpgradeToTLS(&client_socket, host_port_pair, mojo::MakeRequest(&tls_socket), + callback.callback()); ASSERT_EQ(net::OK, callback.WaitForResult()); - client_socket.reset(); + ResetSocket(&client_socket); num_bytes = strlen(kSecretMsg); EXPECT_EQ(MOJO_RESULT_OK, @@ -420,7 +566,7 @@ TEST_F(TLSClientSocketTest, ReadWriteBeforeUpgradeToTLS) { // Tests that a read error is encountered after UpgradeToTLS completes // successfully. -TEST_F(TLSClientSocketTest, ReadErrorAfterUpgradeToTLS) { +TEST_P(TLSClientSocketTest, ReadErrorAfterUpgradeToTLS) { const net::MockRead kReads[] = { net::MockRead(net::ASYNC, kSecretMsg, kSecretMsgSize, 1), net::MockRead(net::SYNCHRONOUS, net::ERR_CONNECTION_CLOSED, 2)}; @@ -432,20 +578,20 @@ TEST_F(TLSClientSocketTest, ReadErrorAfterUpgradeToTLS) { net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK); mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket); - mojom::TCPConnectedSocketPtr client_socket; + SocketHandle client_socket; net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); - EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( - mojo::MakeRequest(&client_socket), server_addr)); + EXPECT_EQ(net::OK, + CreateSocketSync(MakeRequest(&client_socket), server_addr)); net::HostPortPair host_port_pair("example.org", 443); pre_tls_recv_handle()->reset(); pre_tls_send_handle()->reset(); net::TestCompletionCallback callback; mojom::TLSClientSocketPtr tls_socket; - UpgradeToTLS(client_socket.get(), host_port_pair, - mojo::MakeRequest(&tls_socket), callback.callback()); + UpgradeToTLS(&client_socket, host_port_pair, mojo::MakeRequest(&tls_socket), + callback.callback()); ASSERT_EQ(net::OK, callback.WaitForResult()); - client_socket.reset(); + ResetSocket(&client_socket); uint32_t num_bytes = strlen(kSecretMsg); EXPECT_EQ(MOJO_RESULT_OK, @@ -463,7 +609,7 @@ TEST_F(TLSClientSocketTest, ReadErrorAfterUpgradeToTLS) { // Tests that a read error is encountered after UpgradeToTLS completes // successfully. -TEST_F(TLSClientSocketTest, WriteErrorAfterUpgradeToTLS) { +TEST_P(TLSClientSocketTest, WriteErrorAfterUpgradeToTLS) { const net::MockRead kReads[] = {net::MockRead(net::ASYNC, net::OK, 0)}; const net::MockWrite kWrites[] = { net::MockWrite(net::SYNCHRONOUS, net::ERR_CONNECTION_CLOSED, 1)}; @@ -473,20 +619,20 @@ TEST_F(TLSClientSocketTest, WriteErrorAfterUpgradeToTLS) { net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK); mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket); - mojom::TCPConnectedSocketPtr client_socket; + SocketHandle client_socket; net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); - EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( - mojo::MakeRequest(&client_socket), server_addr)); + EXPECT_EQ(net::OK, + CreateSocketSync(MakeRequest(&client_socket), server_addr)); net::HostPortPair host_port_pair("example.org", 443); pre_tls_recv_handle()->reset(); pre_tls_send_handle()->reset(); net::TestCompletionCallback callback; mojom::TLSClientSocketPtr tls_socket; - UpgradeToTLS(client_socket.get(), host_port_pair, - mojo::MakeRequest(&tls_socket), callback.callback()); + UpgradeToTLS(&client_socket, host_port_pair, mojo::MakeRequest(&tls_socket), + callback.callback()); ASSERT_EQ(net::OK, callback.WaitForResult()); - client_socket.reset(); + ResetSocket(&client_socket); uint32_t num_bytes = strlen(kSecretMsg); EXPECT_EQ(MOJO_RESULT_OK, @@ -503,7 +649,7 @@ TEST_F(TLSClientSocketTest, WriteErrorAfterUpgradeToTLS) { // Tests that reading from the pre-tls data pipe is okay even after UpgradeToTLS // is called. -TEST_F(TLSClientSocketTest, ReadFromPreTlsDataPipeAfterUpgradeToTLS) { +TEST_P(TLSClientSocketTest, ReadFromPreTlsDataPipeAfterUpgradeToTLS) { const net::MockRead kReads[] = { net::MockRead(net::ASYNC, kMsg, kMsgSize, 0), net::MockRead(net::ASYNC, kSecretMsg, kSecretMsgSize, 2), @@ -516,17 +662,17 @@ TEST_F(TLSClientSocketTest, ReadFromPreTlsDataPipeAfterUpgradeToTLS) { net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK); mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket); - mojom::TCPConnectedSocketPtr client_socket; + SocketHandle client_socket; net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); - EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( - mojo::MakeRequest(&client_socket), server_addr)); + EXPECT_EQ(net::OK, + CreateSocketSync(MakeRequest(&client_socket), server_addr)); net::HostPortPair host_port_pair("example.org", 443); pre_tls_send_handle()->reset(); net::TestCompletionCallback callback; mojom::TLSClientSocketPtr tls_socket; - UpgradeToTLS(client_socket.get(), host_port_pair, - mojo::MakeRequest(&tls_socket), callback.callback()); + UpgradeToTLS(&client_socket, host_port_pair, mojo::MakeRequest(&tls_socket), + callback.callback()); base::RunLoop().RunUntilIdle(); EXPECT_EQ(kMsg, Read(pre_tls_recv_handle(), kMsgSize)); @@ -534,7 +680,7 @@ TEST_F(TLSClientSocketTest, ReadFromPreTlsDataPipeAfterUpgradeToTLS) { // Reset pre-tls receive pipe now and UpgradeToTLS should complete. pre_tls_recv_handle()->reset(); ASSERT_EQ(net::OK, callback.WaitForResult()); - client_socket.reset(); + ResetSocket(&client_socket); uint32_t num_bytes = strlen(kSecretMsg); EXPECT_EQ(MOJO_RESULT_OK, @@ -549,7 +695,7 @@ TEST_F(TLSClientSocketTest, ReadFromPreTlsDataPipeAfterUpgradeToTLS) { // Tests that writing to the pre-tls data pipe is okay even after UpgradeToTLS // is called. -TEST_F(TLSClientSocketTest, WriteToPreTlsDataPipeAfterUpgradeToTLS) { +TEST_P(TLSClientSocketTest, WriteToPreTlsDataPipeAfterUpgradeToTLS) { const net::MockRead kReads[] = { net::MockRead(net::ASYNC, kSecretMsg, kSecretMsgSize, 2), net::MockRead(net::SYNCHRONOUS, net::OK, 3)}; @@ -562,17 +708,17 @@ TEST_F(TLSClientSocketTest, WriteToPreTlsDataPipeAfterUpgradeToTLS) { net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK); mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket); - mojom::TCPConnectedSocketPtr client_socket; + SocketHandle client_socket; net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); - EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( - mojo::MakeRequest(&client_socket), server_addr)); + EXPECT_EQ(net::OK, + CreateSocketSync(MakeRequest(&client_socket), server_addr)); net::HostPortPair host_port_pair("example.org", 443); pre_tls_recv_handle()->reset(); net::TestCompletionCallback callback; mojom::TLSClientSocketPtr tls_socket; - UpgradeToTLS(client_socket.get(), host_port_pair, - mojo::MakeRequest(&tls_socket), callback.callback()); + UpgradeToTLS(&client_socket, host_port_pair, mojo::MakeRequest(&tls_socket), + callback.callback()); base::RunLoop().RunUntilIdle(); uint32_t num_bytes = strlen(kMsg); @@ -582,7 +728,7 @@ TEST_F(TLSClientSocketTest, WriteToPreTlsDataPipeAfterUpgradeToTLS) { // Reset pre-tls send pipe now and UpgradeToTLS should complete. pre_tls_send_handle()->reset(); ASSERT_EQ(net::OK, callback.WaitForResult()); - client_socket.reset(); + ResetSocket(&client_socket); num_bytes = strlen(kSecretMsg); EXPECT_EQ(MOJO_RESULT_OK, @@ -597,7 +743,7 @@ TEST_F(TLSClientSocketTest, WriteToPreTlsDataPipeAfterUpgradeToTLS) { // Tests that reading from and writing to pre-tls data pipe is okay even after // UpgradeToTLS is called. -TEST_F(TLSClientSocketTest, ReadAndWritePreTlsDataPipeAfterUpgradeToTLS) { +TEST_P(TLSClientSocketTest, ReadAndWritePreTlsDataPipeAfterUpgradeToTLS) { const net::MockRead kReads[] = { net::MockRead(net::ASYNC, kMsg, kMsgSize, 0), net::MockRead(net::ASYNC, kSecretMsg, kSecretMsgSize, 3), @@ -611,17 +757,17 @@ TEST_F(TLSClientSocketTest, ReadAndWritePreTlsDataPipeAfterUpgradeToTLS) { net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK); mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket); - mojom::TCPConnectedSocketPtr client_socket; + SocketHandle client_socket; net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); - EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( - mojo::MakeRequest(&client_socket), server_addr)); + EXPECT_EQ(net::OK, + CreateSocketSync(MakeRequest(&client_socket), server_addr)); net::HostPortPair host_port_pair("example.org", 443); base::RunLoop run_loop; net::TestCompletionCallback callback; mojom::TLSClientSocketPtr tls_socket; - UpgradeToTLS(client_socket.get(), host_port_pair, - mojo::MakeRequest(&tls_socket), callback.callback()); + UpgradeToTLS(&client_socket, host_port_pair, mojo::MakeRequest(&tls_socket), + callback.callback()); EXPECT_EQ(kMsg, Read(pre_tls_recv_handle(), kMsgSize)); uint32_t num_bytes = strlen(kMsg); EXPECT_EQ(MOJO_RESULT_OK, pre_tls_send_handle()->get().WriteData( @@ -631,7 +777,7 @@ TEST_F(TLSClientSocketTest, ReadAndWritePreTlsDataPipeAfterUpgradeToTLS) { pre_tls_recv_handle()->reset(); pre_tls_send_handle()->reset(); ASSERT_EQ(net::OK, callback.WaitForResult()); - client_socket.reset(); + ResetSocket(&client_socket); num_bytes = strlen(kSecretMsg); EXPECT_EQ(MOJO_RESULT_OK, @@ -645,7 +791,11 @@ TEST_F(TLSClientSocketTest, ReadAndWritePreTlsDataPipeAfterUpgradeToTLS) { } // Tests that a read error is encountered before UpgradeToTLS completes. -TEST_F(TLSClientSocketTest, ReadErrorBeforeUpgradeToTLS) { +TEST_P(TLSClientSocketTest, ReadErrorBeforeUpgradeToTLS) { + // This requires pre_tls_observer(), which is not provided by proxy resolving + // sockets. + if (mode() != kDirect) + return; const net::MockRead kReads[] = { net::MockRead(net::ASYNC, kMsg, kMsgSize, 0), net::MockRead(net::SYNCHRONOUS, net::ERR_CONNECTION_CLOSED, 1)}; @@ -653,17 +803,17 @@ TEST_F(TLSClientSocketTest, ReadErrorBeforeUpgradeToTLS) { data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK)); mock_client_socket_factory()->AddSocketDataProvider(&data_provider); - mojom::TCPConnectedSocketPtr client_socket; + SocketHandle client_socket; net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); - EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( - mojo::MakeRequest(&client_socket), server_addr)); + EXPECT_EQ(net::OK, + CreateSocketSync(MakeRequest(&client_socket), server_addr)); net::HostPortPair host_port_pair("example.org", 443); pre_tls_send_handle()->reset(); net::TestCompletionCallback callback; mojom::TLSClientSocketPtr tls_socket; - UpgradeToTLS(client_socket.get(), host_port_pair, - mojo::MakeRequest(&tls_socket), callback.callback()); + UpgradeToTLS(&client_socket, host_port_pair, mojo::MakeRequest(&tls_socket), + callback.callback()); EXPECT_EQ(kMsg, Read(pre_tls_recv_handle(), kMsgSize)); EXPECT_EQ(net::ERR_CONNECTION_CLOSED, pre_tls_observer()->WaitForReadError()); @@ -671,7 +821,7 @@ TEST_F(TLSClientSocketTest, ReadErrorBeforeUpgradeToTLS) { // Reset pre-tls receive pipe now and UpgradeToTLS should complete. pre_tls_recv_handle()->reset(); ASSERT_EQ(net::ERR_SOCKET_NOT_CONNECTED, callback.WaitForResult()); - client_socket.reset(); + ResetSocket(&client_socket); base::RunLoop().RunUntilIdle(); EXPECT_TRUE(data_provider.AllReadDataConsumed()); @@ -679,7 +829,12 @@ TEST_F(TLSClientSocketTest, ReadErrorBeforeUpgradeToTLS) { } // Tests that a write error is encountered before UpgradeToTLS completes. -TEST_F(TLSClientSocketTest, WriteErrorBeforeUpgradeToTLS) { +TEST_P(TLSClientSocketTest, WriteErrorBeforeUpgradeToTLS) { + // This requires pre_tls_observer(), which is not provided by proxy resolving + // sockets. + if (mode() != kDirect) + return; + const net::MockRead kReads[] = {net::MockRead(net::ASYNC, net::OK, 1)}; const net::MockWrite kWrites[] = { net::MockWrite(net::SYNCHRONOUS, net::ERR_CONNECTION_CLOSED, 0)}; @@ -687,17 +842,17 @@ TEST_F(TLSClientSocketTest, WriteErrorBeforeUpgradeToTLS) { data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK)); mock_client_socket_factory()->AddSocketDataProvider(&data_provider); - mojom::TCPConnectedSocketPtr client_socket; + SocketHandle client_socket; net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); - EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( - mojo::MakeRequest(&client_socket), server_addr)); + EXPECT_EQ(net::OK, + CreateSocketSync(MakeRequest(&client_socket), server_addr)); net::HostPortPair host_port_pair("example.org", 443); pre_tls_recv_handle()->reset(); net::TestCompletionCallback callback; mojom::TLSClientSocketPtr tls_socket; - UpgradeToTLS(client_socket.get(), host_port_pair, - mojo::MakeRequest(&tls_socket), callback.callback()); + UpgradeToTLS(&client_socket, host_port_pair, mojo::MakeRequest(&tls_socket), + callback.callback()); uint32_t num_bytes = strlen(kMsg); EXPECT_EQ(MOJO_RESULT_OK, pre_tls_send_handle()->get().WriteData( &kMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE)); @@ -707,7 +862,7 @@ TEST_F(TLSClientSocketTest, WriteErrorBeforeUpgradeToTLS) { // Reset pre-tls send pipe now and UpgradeToTLS should complete. pre_tls_send_handle()->reset(); ASSERT_EQ(net::ERR_SOCKET_NOT_CONNECTED, callback.WaitForResult()); - client_socket.reset(); + ResetSocket(&client_socket); base::RunLoop().RunUntilIdle(); // Write failed before the mock read can be consumed. @@ -715,25 +870,91 @@ TEST_F(TLSClientSocketTest, WriteErrorBeforeUpgradeToTLS) { EXPECT_TRUE(data_provider.AllWriteDataConsumed()); } -class TLSClientSocketParameterizedTest - : public TLSClientSocketTestBase, - public testing::TestWithParam<net::IoMode> { +INSTANTIATE_TEST_CASE_P( + /* no prefix */, + TLSClientSocketTest, + ::testing::Values(TLSClientSocketTestBase::kDirect, + TLSClientSocketTestBase::kProxyResolving)); + +// Tests with proxy resolving socket and a proxy actually configured. +class TLSCLientSocketProxyTest : public ::testing::Test, + public TLSClientSocketTestBase { public: - TLSClientSocketParameterizedTest() : TLSClientSocketTestBase() { - Init(true /* use_mock_sockets*/); + TLSCLientSocketProxyTest() + : TLSClientSocketTestBase(TLSClientSocketTestBase::kProxyResolving) { + Init(true /* use_mock_sockets*/, true /* configure_proxy */); } - ~TLSClientSocketParameterizedTest() override {} + ~TLSCLientSocketProxyTest() override {} private: - DISALLOW_COPY_AND_ASSIGN(TLSClientSocketParameterizedTest); + DISALLOW_COPY_AND_ASSIGN(TLSCLientSocketProxyTest); +}; + +TEST_F(TLSCLientSocketProxyTest, UpgradeToTLS) { + const char kConnectRequest[] = + "CONNECT 127.0.0.1:1234 HTTP/1.1\r\n" + "Host: 127.0.0.1:1234\r\n" + "Proxy-Connection: keep-alive\r\n\r\n"; + const char kConnectResponse[] = "HTTP/1.1 200 OK\r\n\r\n"; + + const net::MockRead kReads[] = { + net::MockRead(net::ASYNC, kConnectResponse, strlen(kConnectResponse), 1), + net::MockRead(net::ASYNC, kMsg, kMsgSize, 3), + net::MockRead(net::SYNCHRONOUS, net::OK, 4)}; + const net::MockWrite kWrites[] = { + net::MockWrite(net::ASYNC, kConnectRequest, strlen(kConnectRequest), 0), + net::MockWrite(net::SYNCHRONOUS, kMsg, kMsgSize, 2)}; + net::SequencedSocketData data_provider(kReads, kWrites); + data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK)); + mock_client_socket_factory()->AddSocketDataProvider(&data_provider); + net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK); + mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket); + + SocketHandle client_socket; + net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); + EXPECT_EQ(net::OK, + CreateSocketSync(MakeRequest(&client_socket), server_addr)); + + net::HostPortPair host_port_pair("example.org", 443); + pre_tls_recv_handle()->reset(); + pre_tls_send_handle()->reset(); + net::TestCompletionCallback callback; + mojom::TLSClientSocketPtr tls_socket; + UpgradeToTLS(&client_socket, host_port_pair, mojo::MakeRequest(&tls_socket), + callback.callback()); + ASSERT_EQ(net::OK, callback.WaitForResult()); + ResetSocket(&client_socket); + + uint32_t num_bytes = strlen(kMsg); + EXPECT_EQ(MOJO_RESULT_OK, post_tls_send_handle()->get().WriteData( + &kMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE)); + EXPECT_EQ(kMsg, Read(post_tls_recv_handle(), kMsgSize)); + base::RunLoop().RunUntilIdle(); + EXPECT_TRUE(ssl_socket.ConnectDataConsumed()); + EXPECT_TRUE(data_provider.AllReadDataConsumed()); + EXPECT_TRUE(data_provider.AllWriteDataConsumed()); +} + +class TLSClientSocketIoModeTest : public TLSClientSocketTestBase, + public testing::TestWithParam<net::IoMode> { + public: + TLSClientSocketIoModeTest() + : TLSClientSocketTestBase(TLSClientSocketTestBase::kDirect) { + Init(true /* use_mock_sockets*/, false /* configure_proxy */); + } + + ~TLSClientSocketIoModeTest() override {} + + private: + DISALLOW_COPY_AND_ASSIGN(TLSClientSocketIoModeTest); }; INSTANTIATE_TEST_CASE_P(/* no prefix */, - TLSClientSocketParameterizedTest, + TLSClientSocketIoModeTest, testing::Values(net::SYNCHRONOUS, net::ASYNC)); -TEST_P(TLSClientSocketParameterizedTest, MultipleWriteToTLSSocket) { +TEST_P(TLSClientSocketIoModeTest, MultipleWriteToTLSSocket) { const int kNumIterations = 3; std::vector<net::MockRead> reads; std::vector<net::MockWrite> writes; @@ -753,7 +974,7 @@ TEST_P(TLSClientSocketParameterizedTest, MultipleWriteToTLSSocket) { } } net::SequencedSocketData data_provider(reads, writes); - data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK)); + data_provider.set_connect_data(net::MockConnect(GetParam(), net::OK)); mock_client_socket_factory()->AddSocketDataProvider(&data_provider); net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK); mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket); @@ -768,10 +989,12 @@ TEST_P(TLSClientSocketParameterizedTest, MultipleWriteToTLSSocket) { pre_tls_send_handle()->reset(); net::TestCompletionCallback callback; mojom::TLSClientSocketPtr tls_socket; - UpgradeToTLS(client_socket.get(), host_port_pair, - mojo::MakeRequest(&tls_socket), callback.callback()); + UpgradeTCPConnectedSocketToTLS( + client_socket.get(), host_port_pair, nullptr /* options */, + mojo::MakeRequest(&tls_socket), callback.callback()); ASSERT_EQ(net::OK, callback.WaitForResult()); client_socket.reset(); + EXPECT_FALSE(ssl_info()); // Loop kNumIterations times to test that writes can follow reads, and reads // can follow writes. @@ -793,75 +1016,249 @@ TEST_P(TLSClientSocketParameterizedTest, MultipleWriteToTLSSocket) { EXPECT_TRUE(data_provider.AllWriteDataConsumed()); } -class TLSClientSocketTestWithEmbeddedTestServer - : public TLSClientSocketTestBase, - public testing::Test { +// Check SSLInfo is provided in both sync and async cases. +TEST_P(TLSClientSocketIoModeTest, SSLInfo) { + // End of file. Reads don't matter, only the handshake does. + std::vector<net::MockRead> reads = {net::MockRead(net::SYNCHRONOUS, net::OK)}; + std::vector<net::MockWrite> writes; + net::SequencedSocketData data_provider(reads, writes); + data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK)); + mock_client_socket_factory()->AddSocketDataProvider(&data_provider); + net::SSLSocketDataProvider ssl_socket(GetParam(), net::OK); + // Set a value on SSLInfo to make sure it's correctly received. + ssl_socket.ssl_info.is_issued_by_known_root = true; + mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket); + + mojom::TCPConnectedSocketPtr client_socket; + net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), 1234); + EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( + mojo::MakeRequest(&client_socket), server_addr)); + + net::HostPortPair host_port_pair("example.org", 443); + pre_tls_recv_handle()->reset(); + pre_tls_send_handle()->reset(); + net::TestCompletionCallback callback; + mojom::TLSClientSocketPtr tls_socket; + mojom::TLSClientSocketOptionsPtr options = + mojom::TLSClientSocketOptions::New(); + options->send_ssl_info = true; + UpgradeTCPConnectedSocketToTLS( + client_socket.get(), host_port_pair, std::move(options), + mojo::MakeRequest(&tls_socket), callback.callback()); + ASSERT_EQ(net::OK, callback.WaitForResult()); + ASSERT_TRUE(ssl_info()); + EXPECT_TRUE(ssl_socket.ssl_info.is_issued_by_known_root); + EXPECT_FALSE(ssl_socket.ssl_info.is_fatal_cert_error); +} + +class TLSClientSocketTestWithEmbeddedTestServerBase + : public TLSClientSocketTestBase { public: - TLSClientSocketTestWithEmbeddedTestServer() : TLSClientSocketTestBase() { - Init(false /* use_mock_sockets */); + explicit TLSClientSocketTestWithEmbeddedTestServerBase(Mode mode) + : TLSClientSocketTestBase(mode), + server_(net::EmbeddedTestServer::TYPE_HTTPS) { + Init(false /* use_mock_sockets */, false /* configure_proxy */); + } + + ~TLSClientSocketTestWithEmbeddedTestServerBase() override {} + + // Starts the test server using the specified certificate. + bool StartTestServer(net::EmbeddedTestServer::ServerCertificate certificate) + WARN_UNUSED_RESULT { + server_.RegisterRequestHandler( + base::BindRepeating([](const net::test_server::HttpRequest& request) { + if (base::StartsWith(request.relative_url, "/secret", + base::CompareCase::INSENSITIVE_ASCII)) { + return std::unique_ptr<net::test_server::HttpResponse>( + new net::test_server::RawHttpResponse("HTTP/1.1 200 OK", + "Hello There!")); + } + return std::unique_ptr<net::test_server::HttpResponse>(); + })); + server_.SetSSLConfig(certificate); + return server_.Start(); + } + + // Attempts to eastablish a TLS connection to the test server by first + // establishing a TCP connection, and then upgrading it. Returns the + // resulting network error code. + int CreateTLSSocket() WARN_UNUSED_RESULT { + SocketHandle client_socket; + net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), + server_.port()); + EXPECT_EQ(net::OK, + CreateSocketSync(MakeRequest(&client_socket), server_addr)); + + pre_tls_recv_handle()->reset(); + pre_tls_send_handle()->reset(); + net::TestCompletionCallback callback; + UpgradeToTLS(&client_socket, server_.host_port_pair(), + mojo::MakeRequest(&tls_socket_), callback.callback()); + int result = callback.WaitForResult(); + ResetSocket(&client_socket); + return result; + } + + int CreateTLSSocketWithOptions(mojom::TLSClientSocketOptionsPtr options) + WARN_UNUSED_RESULT { + // Proxy connections don't support TLSClientSocketOptions. + DCHECK_EQ(kDirect, mode()); + + mojom::TCPConnectedSocketPtr tcp_socket; + net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), + server_.port()); + EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( + mojo::MakeRequest(&tcp_socket), server_addr)); + + pre_tls_recv_handle()->reset(); + pre_tls_send_handle()->reset(); + net::TestCompletionCallback callback; + UpgradeTCPConnectedSocketToTLS( + tcp_socket.get(), server_.host_port_pair(), std::move(options), + mojo::MakeRequest(&tls_socket_), callback.callback()); + int result = callback.WaitForResult(); + tcp_socket.reset(); + return result; + } + + void TestTlsSocket() { + ASSERT_TRUE(tls_socket_.is_bound()); + const char kTestMsg[] = "GET /secret HTTP/1.1\r\n\r\n"; + uint32_t num_bytes = strlen(kTestMsg); + const char kResponse[] = "HTTP/1.1 200 OK\n\n"; + EXPECT_EQ(MOJO_RESULT_OK, + post_tls_send_handle()->get().WriteData( + &kTestMsg, &num_bytes, MOJO_WRITE_DATA_FLAG_NONE)); + EXPECT_EQ(kResponse, Read(post_tls_recv_handle(), strlen(kResponse))); } + net::EmbeddedTestServer* server() { return &server_; } + + private: + net::EmbeddedTestServer server_; + + mojom::TLSClientSocketPtr tls_socket_; + + DISALLOW_COPY_AND_ASSIGN(TLSClientSocketTestWithEmbeddedTestServerBase); +}; + +class TLSClientSocketTestWithEmbeddedTestServer + : public TLSClientSocketTestWithEmbeddedTestServerBase, + public ::testing::TestWithParam<TLSClientSocketTestBase::Mode> { + public: + TLSClientSocketTestWithEmbeddedTestServer() + : TLSClientSocketTestWithEmbeddedTestServerBase(GetParam()) {} ~TLSClientSocketTestWithEmbeddedTestServer() override {} private: DISALLOW_COPY_AND_ASSIGN(TLSClientSocketTestWithEmbeddedTestServer); }; -TEST_F(TLSClientSocketTestWithEmbeddedTestServer, Basic) { - net::EmbeddedTestServer server(net::EmbeddedTestServer::TYPE_HTTPS); - server.RegisterRequestHandler( - base::BindRepeating([](const net::test_server::HttpRequest& request) { - if (base::StartsWith(request.relative_url, "/secret", - base::CompareCase::INSENSITIVE_ASCII)) { - return std::unique_ptr<net::test_server::HttpResponse>( - new net::test_server::RawHttpResponse("HTTP/1.1 200 OK", - "Hello There!")); - } - return std::unique_ptr<net::test_server::HttpResponse>(); - })); - server.SetSSLConfig(net::EmbeddedTestServer::CERT_OK); - ASSERT_TRUE(server.Start()); +TEST_P(TLSClientSocketTestWithEmbeddedTestServer, Basic) { + ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_OK)); + ASSERT_EQ(net::OK, CreateTLSSocket()); + // No SSLInfo should be received by default. SSLInfo is only supported in the + // kDirect case, but it doesn't hurt to check it's null it in the + // kProxyResolving case. + EXPECT_FALSE(ssl_info()); + TestTlsSocket(); +} - mojom::TCPConnectedSocketPtr client_socket; - net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), server.port()); - EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( - mojo::MakeRequest(&client_socket), server_addr)); +TEST_P(TLSClientSocketTestWithEmbeddedTestServer, ServerCertError) { + ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_MISMATCHED_NAME)); + ASSERT_EQ(net::ERR_CERT_COMMON_NAME_INVALID, CreateTLSSocket()); + // No SSLInfo should be received by default. SSLInfo is only supported in the + // kDirect case, but it doesn't hurt to check it's null in the kProxyResolving + // case. + EXPECT_FALSE(ssl_info()); + + // Handles should be invalid. + EXPECT_FALSE(post_tls_recv_handle()->is_valid()); + EXPECT_FALSE(post_tls_send_handle()->is_valid()); +} - pre_tls_recv_handle()->reset(); - pre_tls_send_handle()->reset(); - net::TestCompletionCallback callback; - mojom::TLSClientSocketPtr tls_socket; - UpgradeToTLS(client_socket.get(), server.host_port_pair(), - mojo::MakeRequest(&tls_socket), callback.callback()); - ASSERT_EQ(net::OK, callback.WaitForResult()); - client_socket.reset(); +INSTANTIATE_TEST_CASE_P( + /* no prefix */, + TLSClientSocketTestWithEmbeddedTestServer, + ::testing::Values(TLSClientSocketTestBase::kDirect, + TLSClientSocketTestBase::kProxyResolving)); - const char kTestMsg[] = "GET /secret HTTP/1.1\r\n\r\n"; - uint32_t num_bytes = strlen(kTestMsg); - const char kResponse[] = "HTTP/1.1 200 OK\n\n"; - EXPECT_EQ(MOJO_RESULT_OK, - post_tls_send_handle()->get().WriteData(&kTestMsg, &num_bytes, - MOJO_WRITE_DATA_FLAG_NONE)); - EXPECT_EQ(kResponse, Read(post_tls_recv_handle(), strlen(kResponse))); +class TLSClientSocketDirectTestWithEmbeddedTestServer + : public TLSClientSocketTestWithEmbeddedTestServerBase, + public testing::Test { + public: + TLSClientSocketDirectTestWithEmbeddedTestServer() + : TLSClientSocketTestWithEmbeddedTestServerBase(kDirect) {} + ~TLSClientSocketDirectTestWithEmbeddedTestServer() override {} + + private: + DISALLOW_COPY_AND_ASSIGN(TLSClientSocketDirectTestWithEmbeddedTestServer); +}; + +TEST_F(TLSClientSocketDirectTestWithEmbeddedTestServer, SSLInfo) { + ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_OK)); + mojom::TLSClientSocketOptionsPtr options = + mojom::TLSClientSocketOptions::New(); + options->send_ssl_info = true; + ASSERT_EQ(net::OK, CreateTLSSocketWithOptions(std::move(options))); + + ASSERT_TRUE(ssl_info()); + EXPECT_TRUE(ssl_info()->is_valid()); + EXPECT_FALSE(ssl_info()->is_fatal_cert_error); + + TestTlsSocket(); } -TEST_F(TLSClientSocketTestWithEmbeddedTestServer, ServerCertError) { - net::EmbeddedTestServer server(net::EmbeddedTestServer::TYPE_HTTPS); - server.SetSSLConfig(net::EmbeddedTestServer::CERT_MISMATCHED_NAME); - ASSERT_TRUE(server.Start()); +TEST_F(TLSClientSocketDirectTestWithEmbeddedTestServer, + SSLInfoServerCertError) { + ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_MISMATCHED_NAME)); + mojom::TLSClientSocketOptionsPtr options = + mojom::TLSClientSocketOptions::New(); + options->send_ssl_info = true; + // Requesting SSLInfo should not bypass cert verification. + ASSERT_EQ(net::ERR_CERT_COMMON_NAME_INVALID, + CreateTLSSocketWithOptions(std::move(options))); - mojom::TCPConnectedSocketPtr client_socket; - net::IPEndPoint server_addr(net::IPAddress::IPv4Localhost(), server.port()); - EXPECT_EQ(net::OK, CreateTCPConnectedSocketSync( - mojo::MakeRequest(&client_socket), server_addr)); + // No SSLInfo should be provided on error. + EXPECT_FALSE(ssl_info()); - pre_tls_recv_handle()->reset(); - pre_tls_send_handle()->reset(); - net::TestCompletionCallback callback; - mojom::TLSClientSocketPtr tls_socket; - UpgradeToTLS(client_socket.get(), server.host_port_pair(), - mojo::MakeRequest(&tls_socket), callback.callback()); - ASSERT_EQ(net::ERR_CERT_COMMON_NAME_INVALID, callback.WaitForResult()); + // Handles should be invalid. + EXPECT_FALSE(post_tls_recv_handle()->is_valid()); + EXPECT_FALSE(post_tls_send_handle()->is_valid()); } +// Check skipping cert verification always received SSLInfo, even with valid +// certs. +TEST_F(TLSClientSocketDirectTestWithEmbeddedTestServer, + UnsafelySkipCertVerification) { + ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_OK)); + mojom::TLSClientSocketOptionsPtr options = + mojom::TLSClientSocketOptions::New(); + options->unsafely_skip_cert_verification = true; + ASSERT_EQ(net::OK, CreateTLSSocketWithOptions(std::move(options))); + + ASSERT_TRUE(ssl_info()); + EXPECT_TRUE(ssl_info()->is_valid()); + EXPECT_FALSE(ssl_info()->is_fatal_cert_error); + + TestTlsSocket(); +} + +TEST_F(TLSClientSocketDirectTestWithEmbeddedTestServer, + UnsafelySkipCertVerificationServerCertError) { + ASSERT_TRUE(StartTestServer(net::EmbeddedTestServer::CERT_MISMATCHED_NAME)); + mojom::TLSClientSocketOptionsPtr options = + mojom::TLSClientSocketOptions::New(); + options->unsafely_skip_cert_verification = true; + ASSERT_EQ(net::OK, CreateTLSSocketWithOptions(std::move(options))); + + ASSERT_TRUE(ssl_info()); + EXPECT_TRUE(ssl_info()->is_valid()); + EXPECT_FALSE(ssl_info()->is_fatal_cert_error); + + TestTlsSocket(); +} + +} // namespace + } // namespace network diff --git a/chromium/services/network/tls_socket_factory.cc b/chromium/services/network/tls_socket_factory.cc new file mode 100644 index 00000000000..4934404f0ea --- /dev/null +++ b/chromium/services/network/tls_socket_factory.cc @@ -0,0 +1,153 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/network/tls_socket_factory.h" + +#include <string> +#include <utility> + +#include "base/optional.h" +#include "mojo/public/cpp/bindings/binding.h" +#include "mojo/public/cpp/bindings/type_converter.h" +#include "net/base/completion_once_callback.h" +#include "net/base/net_errors.h" +#include "net/cert/cert_verifier.h" +#include "net/cert/ct_policy_enforcer.h" +#include "net/cert/multi_log_ct_verifier.h" +#include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" +#include "net/ssl/ssl_config.h" +#include "net/ssl/ssl_config_service.h" +#include "net/url_request/url_request_context.h" +#include "services/network/ssl_config_type_converter.h" +#include "services/network/tls_client_socket.h" + +namespace network { +namespace { +// Cert verifier which blindly accepts all certificates, regardless of validity. +class FakeCertVerifier : public net::CertVerifier { + public: + FakeCertVerifier() {} + ~FakeCertVerifier() override {} + + int Verify(const RequestParams& params, + net::CertVerifyResult* verify_result, + net::CompletionOnceCallback, + std::unique_ptr<Request>*, + const net::NetLogWithSource&) override { + verify_result->Reset(); + verify_result->verified_cert = params.certificate(); + return net::OK; + } + void SetConfig(const Config& config) override {} +}; +} // namespace + +TLSSocketFactory::TLSSocketFactory( + net::URLRequestContext* url_request_context, + const net::HttpNetworkSession::Context* http_context) + : ssl_client_socket_context_( + url_request_context->cert_verifier(), + nullptr, /* TODO(rkn): ChannelIDService is not thread safe. */ + url_request_context->transport_security_state(), + url_request_context->cert_transparency_verifier(), + url_request_context->ct_policy_enforcer(), + std::string() /* TODO(rsleevi): Ensure a proper unique shard. */), + client_socket_factory_(nullptr), + ssl_config_service_(url_request_context->ssl_config_service()) { + if (http_context) { + client_socket_factory_ = http_context->client_socket_factory; + } + + if (!client_socket_factory_ && + url_request_context->GetNetworkSessionContext()) { + client_socket_factory_ = + url_request_context->GetNetworkSessionContext()->client_socket_factory; + } + if (!client_socket_factory_) + client_socket_factory_ = net::ClientSocketFactory::GetDefaultFactory(); +} + +TLSSocketFactory::~TLSSocketFactory() {} + +void TLSSocketFactory::UpgradeToTLS( + Delegate* socket_delegate, + const net::HostPortPair& host_port_pair, + mojom::TLSClientSocketOptionsPtr socket_options, + const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, + mojom::TLSClientSocketRequest request, + mojom::SocketObserverPtr observer, + UpgradeToTLSCallback callback) { + const net::StreamSocket* socket = socket_delegate->BorrowSocket(); + if (!socket || !socket->IsConnected()) { + std::move(callback).Run( + net::ERR_SOCKET_NOT_CONNECTED, mojo::ScopedDataPipeConsumerHandle(), + mojo::ScopedDataPipeProducerHandle(), base::nullopt); + return; + } + auto socket_handle = std::make_unique<net::ClientSocketHandle>(); + socket_handle->SetSocket(socket_delegate->TakeSocket()); + CreateTLSClientSocket( + host_port_pair, std::move(socket_options), std::move(request), + std::move(socket_handle), std::move(observer), + static_cast<net::NetworkTrafficAnnotationTag>(traffic_annotation), + std::move(callback)); +} + +void TLSSocketFactory::CreateTLSClientSocket( + const net::HostPortPair& host_port_pair, + mojom::TLSClientSocketOptionsPtr socket_options, + mojom::TLSClientSocketRequest request, + std::unique_ptr<net::ClientSocketHandle> underlying_socket, + mojom::SocketObserverPtr observer, + const net::NetworkTrafficAnnotationTag& traffic_annotation, + mojom::TCPConnectedSocket::UpgradeToTLSCallback callback) { + auto socket = std::make_unique<TLSClientSocket>( + std::move(request), std::move(observer), + static_cast<net::NetworkTrafficAnnotationTag>(traffic_annotation)); + TLSClientSocket* socket_raw = socket.get(); + tls_socket_bindings_.AddBinding(std::move(socket), std::move(request)); + + net::SSLConfig ssl_config; + ssl_config_service_->GetSSLConfig(&ssl_config); + net::SSLClientSocketContext& ssl_client_socket_context = + ssl_client_socket_context_; + + bool send_ssl_info = false; + if (socket_options) { + ssl_config.version_min = + mojo::MojoSSLVersionToNetSSLVersion(socket_options->version_min); + ssl_config.version_max = + mojo::MojoSSLVersionToNetSSLVersion(socket_options->version_max); + + send_ssl_info = socket_options->send_ssl_info; + + if (socket_options->unsafely_skip_cert_verification) { + if (!no_verification_cert_verifier_) { + no_verification_cert_verifier_ = base::WrapUnique(new FakeCertVerifier); + no_verification_transport_security_state_.reset( + new net::TransportSecurityState); + no_verification_cert_transparency_verifier_.reset( + new net::MultiLogCTVerifier()); + no_verification_ct_policy_enforcer_.reset( + new net::DefaultCTPolicyEnforcer()); + no_verification_ssl_client_socket_context_.cert_verifier = + no_verification_cert_verifier_.get(); + no_verification_ssl_client_socket_context_.transport_security_state = + no_verification_transport_security_state_.get(); + no_verification_ssl_client_socket_context_.cert_transparency_verifier = + no_verification_cert_transparency_verifier_.get(); + no_verification_ssl_client_socket_context_.ct_policy_enforcer = + no_verification_ct_policy_enforcer_.get(); + } + ssl_client_socket_context = no_verification_ssl_client_socket_context_; + send_ssl_info = true; + } + } + socket_raw->Connect(host_port_pair, ssl_config, std::move(underlying_socket), + ssl_client_socket_context, client_socket_factory_, + std::move(callback), send_ssl_info); +} + +} // namespace network diff --git a/chromium/services/network/tls_socket_factory.h b/chromium/services/network/tls_socket_factory.h new file mode 100644 index 00000000000..76509df2833 --- /dev/null +++ b/chromium/services/network/tls_socket_factory.h @@ -0,0 +1,96 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_NETWORK_TLS_SOCKET_FACTORY_H_ +#define SERVICES_NETWORK_TLS_SOCKET_FACTORY_H_ + +#include <memory> +#include <vector> + +#include "base/component_export.h" +#include "base/macros.h" +#include "mojo/public/cpp/bindings/strong_binding_set.h" +#include "net/http/http_network_session.h" +#include "net/socket/ssl_client_socket.h" +#include "net/traffic_annotation/network_traffic_annotation.h" +#include "services/network/public/mojom/network_service.mojom.h" +#include "services/network/public/mojom/tls_socket.mojom.h" + +namespace net { +class ClientSocketHandle; +class ClientSocketFactory; +class SSLConfigService; +} // namespace net + +namespace network { + +// Helper class that handles TLS socket requests. +class COMPONENT_EXPORT(NETWORK_SERVICE) TLSSocketFactory { + public: + class Delegate { + public: + virtual const net::StreamSocket* BorrowSocket() = 0; + virtual std::unique_ptr<net::StreamSocket> TakeSocket() = 0; + }; + + // See documentation of UpgradeToTLS in tcp_socket.mojom for + // the semantics of the results. + using UpgradeToTLSCallback = + base::OnceCallback<void(int32_t net_error, + mojo::ScopedDataPipeConsumerHandle receive_stream, + mojo::ScopedDataPipeProducerHandle send_stream, + const base::Optional<net::SSLInfo>& ssl_info)>; + + // Constructs a TLSSocketFactory. If |net_log| is non-null, it is used to + // log NetLog events when logging is enabled. |net_log| used to must outlive + // |this|. Sockets will be created using, the earliest available from: + // 1) A ClientSocketFactory set on a non-null |http_context|. + // 2) A ClientSocketFactory set on |url_request_context|'s + // HttpNetworkSession::Context + // 3) The default ClientSocketFactory. + TLSSocketFactory(net::URLRequestContext* url_request_context, + const net::HttpNetworkSession::Context* http_context); + virtual ~TLSSocketFactory(); + + // Upgrades an existing socket to TLS. The previous pipes and data pump + // must already have been destroyed before the call to this method. + void UpgradeToTLS( + Delegate* socket_delegate, + const net::HostPortPair& host_port_pair, + mojom::TLSClientSocketOptionsPtr socket_options, + const net::MutableNetworkTrafficAnnotationTag& traffic_annotation, + mojom::TLSClientSocketRequest request, + mojom::SocketObserverPtr observer, + UpgradeToTLSCallback callback); + + private: + void CreateTLSClientSocket( + const net::HostPortPair& host_port_pair, + mojom::TLSClientSocketOptionsPtr socket_options, + mojom::TLSClientSocketRequest request, + std::unique_ptr<net::ClientSocketHandle> socket, + mojom::SocketObserverPtr observer, + const net::NetworkTrafficAnnotationTag& traffic_annotation, + mojom::TCPConnectedSocket::UpgradeToTLSCallback callback); + + // The following are used when |unsafely_skip_cert_verification| is specified + // in upgrade options. + net::SSLClientSocketContext no_verification_ssl_client_socket_context_; + std::unique_ptr<net::CertVerifier> no_verification_cert_verifier_; + std::unique_ptr<net::TransportSecurityState> + no_verification_transport_security_state_; + std::unique_ptr<net::CTVerifier> no_verification_cert_transparency_verifier_; + std::unique_ptr<net::CTPolicyEnforcer> no_verification_ct_policy_enforcer_; + + net::SSLClientSocketContext ssl_client_socket_context_; + net::ClientSocketFactory* client_socket_factory_; + net::SSLConfigService* const ssl_config_service_; + mojo::StrongBindingSet<mojom::TLSClientSocket> tls_socket_bindings_; + + DISALLOW_COPY_AND_ASSIGN(TLSSocketFactory); +}; + +} // namespace network + +#endif // SERVICES_NETWORK_SOCKET_FACTORY_H_ diff --git a/chromium/services/network/udp_socket.cc b/chromium/services/network/udp_socket.cc index 6eea7f10167..c2f6e707fd4 100644 --- a/chromium/services/network/udp_socket.cc +++ b/chromium/services/network/udp_socket.cc @@ -26,7 +26,7 @@ const uint32_t kMaxReadSize = 64 * 1024; // IPv6. const uint32_t kMaxPacketSize = kMaxReadSize - 1; -int ClampBufferSize(int requested_buffer_size) { +int ClampUDPBufferSize(int requested_buffer_size) { constexpr int kMinBufferSize = 0; constexpr int kMaxBufferSize = 128 * 1024; return base::ClampToRange(requested_buffer_size, kMinBufferSize, @@ -83,10 +83,11 @@ class SocketWrapperImpl : public UDPSocket::SocketWrapper { return socket_.SetBroadcast(broadcast); } int SetSendBufferSize(int send_buffer_size) override { - return socket_.SetSendBufferSize(ClampBufferSize(send_buffer_size)); + return socket_.SetSendBufferSize(ClampUDPBufferSize(send_buffer_size)); } int SetReceiveBufferSize(int receive_buffer_size) override { - return socket_.SetReceiveBufferSize(ClampBufferSize(receive_buffer_size)); + return socket_.SetReceiveBufferSize( + ClampUDPBufferSize(receive_buffer_size)); } int JoinGroup(const net::IPAddress& group_address) override { return socket_.JoinGroup(group_address); @@ -129,11 +130,11 @@ class SocketWrapperImpl : public UDPSocket::SocketWrapper { } if (result == net::OK && options->receive_buffer_size != 0) { result = socket_.SetReceiveBufferSize( - ClampBufferSize(options->receive_buffer_size)); + ClampUDPBufferSize(options->receive_buffer_size)); } if (result == net::OK && options->send_buffer_size != 0) { - result = - socket_.SetSendBufferSize(ClampBufferSize(options->send_buffer_size)); + result = socket_.SetSendBufferSize( + ClampUDPBufferSize(options->send_buffer_size)); } return result; } @@ -332,7 +333,8 @@ void UDPSocket::DoRecvFrom(uint32_t buffer_size) { DCHECK_GT(remaining_recv_slots_, 0u); DCHECK_GE(kMaxReadSize, buffer_size); - recvfrom_buffer_ = new net::IOBuffer(static_cast<size_t>(buffer_size)); + recvfrom_buffer_ = + base::MakeRefCounted<net::IOBuffer>(static_cast<size_t>(buffer_size)); // base::Unretained(this) is safe because socket is owned by |this|. int net_result = wrapped_socket_->RecvFrom( @@ -360,9 +362,8 @@ void UDPSocket::DoSendToOrWrite( // |data| points to a range of bytes in the received message and will be // freed when this method returns, so copy out the bytes now. - scoped_refptr<net::IOBufferWithSize> buffer = - new net::IOBufferWithSize(data.size()); - memcpy(buffer.get()->data(), data.begin(), data.size()); + auto buffer = base::MakeRefCounted<net::IOBufferWithSize>(data.size()); + memcpy(buffer.get()->data(), data.data(), data.size()); if (send_buffer_.get()) { auto request = std::make_unique<PendingSendRequest>(); diff --git a/chromium/services/network/url_loader.cc b/chromium/services/network/url_loader.cc index 1594e7ebc19..335f445e3bc 100644 --- a/chromium/services/network/url_loader.cc +++ b/chromium/services/network/url_loader.cc @@ -68,6 +68,7 @@ void PopulateResourceResponse(net::URLRequest* request, response->head.socket_address = response_info.socket_address; response->head.was_fetched_via_cache = request->was_cached(); response->head.was_fetched_via_proxy = request->was_fetched_via_proxy(); + response->head.proxy_server = request->proxy_server(); response->head.network_accessed = response_info.network_accessed; response->head.async_revalidation_requested = response_info.async_revalidation_requested; @@ -320,6 +321,9 @@ URLLoader::URLLoader( keepalive_statistics_recorder_(std::move(keepalive_statistics_recorder)), network_usage_accumulator_(std::move(network_usage_accumulator)), first_auth_attempt_(true), + custom_proxy_pre_cache_headers_(request.custom_proxy_pre_cache_headers), + custom_proxy_post_cache_headers_(request.custom_proxy_post_cache_headers), + fetch_window_id_(request.fetch_window_id), weak_ptr_factory_(this) { DCHECK(delete_callback_); if (!base::FeatureList::IsEnabled(features::kNetworkService)) { @@ -346,6 +350,11 @@ URLLoader::URLLoader( url_request_->SetReferrer(ComputeReferrer(request.referrer)); url_request_->set_referrer_policy(request.referrer_policy); url_request_->SetExtraRequestHeaders(request.headers); + if (!request.requested_with.empty()) { + // X-Requested-With header must be set here to avoid breaking CORS checks. + url_request_->SetExtraRequestHeaderByName("X-Requested-With", + request.requested_with, true); + } url_request_->set_upgrade_if_insecure(request.upgrade_if_insecure); url_request_->SetUserData(kUserDataKey, @@ -631,10 +640,19 @@ void URLLoader::OnCertificateRequested(net::URLRequest* unused, return; } - network_service_client_->OnCertificateRequested( - factory_params_->process_id, render_frame_id_, request_id_, cert_info, - base::BindOnce(&URLLoader::OnCertificateRequestedResponse, - weak_ptr_factory_.GetWeakPtr())); + if (fetch_window_id_) { + network_service_client_->OnCertificateRequested( + fetch_window_id_, -1 /* process_id */, -1 /* routing_id */, request_id_, + cert_info, + base::BindOnce(&URLLoader::OnCertificateRequestedResponse, + weak_ptr_factory_.GetWeakPtr())); + } else { + network_service_client_->OnCertificateRequested( + base::nullopt /* window_id */, factory_params_->process_id, + render_frame_id_, request_id_, cert_info, + base::BindOnce(&URLLoader::OnCertificateRequestedResponse, + weak_ptr_factory_.GetWeakPtr())); + } } void URLLoader::OnSSLCertificateError(net::URLRequest* request, @@ -665,6 +683,20 @@ void URLLoader::OnResponseStarted(net::URLRequest* url_request, int net_error) { return; } + MojoCreateDataPipeOptions options; + options.struct_size = sizeof(MojoCreateDataPipeOptions); + options.flags = MOJO_CREATE_DATA_PIPE_FLAG_NONE; + options.element_num_bytes = 1; + options.capacity_num_bytes = kDefaultAllocationSize; + MojoResult result = + mojo::CreateDataPipe(&options, &response_body_stream_, &consumer_handle_); + if (result != MOJO_RESULT_OK) { + NotifyCompleted(net::ERR_INSUFFICIENT_RESOURCES); + return; + } + DCHECK(response_body_stream_.is_valid()); + DCHECK(consumer_handle_.is_valid()); + // Do not account header bytes when reporting received body bytes to client. reported_total_encoded_bytes_ = url_request_->GetTotalReceivedBytes(); @@ -684,9 +716,6 @@ void URLLoader::OnResponseStarted(net::URLRequest* url_request, int net_error) { raw_response_headers_ = nullptr; } - mojo::DataPipe data_pipe(kDefaultAllocationSize); - response_body_stream_ = std::move(data_pipe.producer_handle); - consumer_handle_ = std::move(data_pipe.consumer_handle); peer_closed_handle_watcher_.Watch( response_body_stream_.get(), MOJO_HANDLE_SIGNAL_PEER_CLOSED, base::Bind(&URLLoader::OnResponseBodyStreamConsumerClosed, diff --git a/chromium/services/network/url_loader.h b/chromium/services/network/url_loader.h index 1b14cd70a56..e30c8041c96 100644 --- a/chromium/services/network/url_loader.h +++ b/chromium/services/network/url_loader.h @@ -15,6 +15,7 @@ #include "base/memory/ref_counted.h" #include "base/memory/weak_ptr.h" #include "base/optional.h" +#include "base/unguessable_token.h" #include "mojo/public/cpp/bindings/binding.h" #include "mojo/public/cpp/system/data_pipe.h" #include "mojo/public/cpp/system/simple_watcher.h" @@ -103,6 +104,14 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) URLLoader uint32_t GetRenderFrameId() const; uint32_t GetProcessId() const; + const net::HttpRequestHeaders& custom_proxy_pre_cache_headers() const { + return custom_proxy_pre_cache_headers_; + } + + const net::HttpRequestHeaders& custom_proxy_post_cache_headers() const { + return custom_proxy_post_cache_headers_; + } + // Gets the URLLoader associated with this request. static URLLoader* ForRequest(const net::URLRequest& request); @@ -255,6 +264,13 @@ class COMPONENT_EXPORT(NETWORK_SERVICE) URLLoader std::unique_ptr<ScopedThrottlingToken> throttling_token_; + net::HttpRequestHeaders custom_proxy_pre_cache_headers_; + net::HttpRequestHeaders custom_proxy_post_cache_headers_; + + // Indicates the originating frame of the request, see + // network::ResourceRequest::fetch_window_id for details. + base::Optional<base::UnguessableToken> fetch_window_id_; + base::WeakPtrFactory<URLLoader> weak_ptr_factory_; DISALLOW_COPY_AND_ASSIGN(URLLoader); diff --git a/chromium/services/network/url_loader_unittest.cc b/chromium/services/network/url_loader_unittest.cc index 162d6fe4fcb..0211387904c 100644 --- a/chromium/services/network/url_loader_unittest.cc +++ b/chromium/services/network/url_loader_unittest.cc @@ -30,6 +30,7 @@ #include "base/time/time.h" #include "build/build_config.h" #include "mojo/public/c/system/data_pipe.h" +#include "mojo/public/cpp/bindings/strong_binding.h" #include "mojo/public/cpp/system/data_pipe_utils.h" #include "mojo/public/cpp/system/wait.h" #include "net/base/io_buffer.h" @@ -37,6 +38,7 @@ #include "net/base/mime_sniffer.h" #include "net/base/net_errors.h" #include "net/http/http_response_info.h" +#include "net/ssl/client_cert_identity_test_util.h" #include "net/test/cert_test_util.h" #include "net/test/embedded_test_server/controllable_http_response.h" #include "net/test/embedded_test_server/embedded_test_server.h" @@ -328,28 +330,6 @@ class SimulatedCacheInterceptor : public net::URLRequestInterceptor { DISALLOW_COPY_AND_ASSIGN(SimulatedCacheInterceptor); }; -class RequestInterceptor : public net::URLRequestInterceptor { - public: - using InterceptCallback = base::Callback<void(net::URLRequest*)>; - - explicit RequestInterceptor(const InterceptCallback& callback) - : callback_(callback) {} - ~RequestInterceptor() override {} - - // URLRequestInterceptor implementation: - net::URLRequestJob* MaybeInterceptRequest( - net::URLRequest* request, - net::NetworkDelegate* network_delegate) const override { - callback_.Run(request); - return nullptr; - } - - private: - InterceptCallback callback_; - - DISALLOW_COPY_AND_ASSIGN(RequestInterceptor); -}; - // Returns whether monitoring was successfully set up. If yes, // StopMonitorBodyReadFromNetBeforePausedHistogram() needs to be called later to // stop monitoring. @@ -539,15 +519,6 @@ class URLLoaderTest : public testing::Test { EXPECT_EQ(actual_body, expected_body); } - // Adds an interceptor that can examine the URLRequest object. - void AddRequestObserver( - const GURL& url, - const RequestInterceptor::InterceptCallback& callback) { - net::URLRequestFilter::GetInstance()->AddUrlInterceptor( - url, std::unique_ptr<net::URLRequestInterceptor>( - new RequestInterceptor(callback))); - } - net::EmbeddedTestServer* test_server() { return &test_server_; } net::URLRequestContext* context() { return context_.get(); } TestURLLoaderClient* client() { return &client_; } @@ -689,6 +660,8 @@ class URLLoaderTest : public testing::Test { return sent_request_; } + void RunUntilIdle() { scoped_task_environment_.RunUntilIdle(); } + static constexpr int kProcessId = 4; static constexpr int kRouteId = 8; @@ -2094,12 +2067,43 @@ TEST_F(URLLoaderTest, EnterSuspendDiskCacheWriteQueued) { EXPECT_EQ('a', simulated_cache_dest->data()[0]); } -// A mock NetworkServiceClient that responds auth challenges with previously -// set credentials. -class TestAuthNetworkServiceClient : public mojom::NetworkServiceClient { +class FakeSSLPrivateKeyImpl : public network::mojom::SSLPrivateKey { public: - TestAuthNetworkServiceClient() = default; - ~TestAuthNetworkServiceClient() override = default; + explicit FakeSSLPrivateKeyImpl( + scoped_refptr<net::SSLPrivateKey> ssl_private_key) + : ssl_private_key_(std::move(ssl_private_key)) {} + ~FakeSSLPrivateKeyImpl() override {} + + // network::mojom::SSLPrivateKey: + void Sign(uint16_t algorithm, + const std::vector<uint8_t>& input, + network::mojom::SSLPrivateKey::SignCallback callback) override { + base::span<const uint8_t> input_span(input); + ssl_private_key_->Sign( + algorithm, input_span, + base::BindOnce(&FakeSSLPrivateKeyImpl::Callback, base::Unretained(this), + std::move(callback))); + } + + private: + void Callback(network::mojom::SSLPrivateKey::SignCallback callback, + net::Error net_error, + const std::vector<uint8_t>& signature) { + std::move(callback).Run(static_cast<int32_t>(net_error), signature); + } + + scoped_refptr<net::SSLPrivateKey> ssl_private_key_; + + DISALLOW_COPY_AND_ASSIGN(FakeSSLPrivateKeyImpl); +}; + +// A mock NetworkServiceClient that does the following: +// 1. Responds auth challenges with previously set credentials. +// 2. Responds certificate request with previously set responses. +class MockNetworkServiceClient : public mojom::NetworkServiceClient { + public: + MockNetworkServiceClient() = default; + ~MockNetworkServiceClient() override = default; enum class CredentialsResponse { NO_CREDENTIALS, @@ -2107,6 +2111,15 @@ class TestAuthNetworkServiceClient : public mojom::NetworkServiceClient { INCORRECT_CREDENTIALS_THEN_CORRECT_ONES, }; + enum class CertificateResponse { + INVALID = -1, + URL_LOADER_REQUEST_CANCELLED, + CANCEL_CERTIFICATE_SELECTION, + NULL_CERTIFICATE, + VALID_CERTIFICATE_SIGNATURE, + INVALID_CERTIFICATE_SIGNATURE, + }; + // mojom::NetworkServiceClient: void OnAuthRequired( uint32_t process_id, @@ -2139,13 +2152,37 @@ class TestAuthNetworkServiceClient : public mojom::NetworkServiceClient { } void OnCertificateRequested( + const base::Optional<base::UnguessableToken>& window_id, uint32_t process_id, uint32_t routing_id, uint32_t request_id, const scoped_refptr<net::SSLCertRequestInfo>& cert_info, mojom::NetworkServiceClient::OnCertificateRequestedCallback callback) override { - NOTREACHED(); + switch (certificate_response_) { + case CertificateResponse::INVALID: + NOTREACHED(); + break; + case CertificateResponse::URL_LOADER_REQUEST_CANCELLED: + ASSERT_TRUE(url_loader_ptr_); + url_loader_ptr_->reset(); + break; + case CertificateResponse::CANCEL_CERTIFICATE_SELECTION: + std::move(callback).Run(nullptr, std::vector<uint16_t>(), nullptr, + true /* cancel_certificate_selection */); + break; + case CertificateResponse::NULL_CERTIFICATE: + std::move(callback).Run(nullptr, std::vector<uint16_t>(), nullptr, + false /* cancel_certificate_selection */); + break; + case CertificateResponse::VALID_CERTIFICATE_SIGNATURE: + case CertificateResponse::INVALID_CERTIFICATE_SIGNATURE: + std::move(callback).Run(std::move(certificate_), algorithm_preferences_, + std::move(ssl_private_key_ptr_), + false /* cancel_certificate_selection */); + break; + } + ++on_certificate_requested_counter_; } void OnSSLCertificateError(uint32_t process_id, @@ -2206,19 +2243,51 @@ class TestAuthNetworkServiceClient : public mojom::NetworkServiceClient { return last_seen_response_headers_.get(); } + void set_certificate_response(CertificateResponse certificate_response) { + certificate_response_ = certificate_response; + } + + void set_url_loader_ptr(mojom::URLLoaderPtr* url_loader_ptr) { + url_loader_ptr_ = url_loader_ptr; + } + + void set_private_key(scoped_refptr<net::SSLPrivateKey> ssl_private_key) { + ssl_private_key_ = std::move(ssl_private_key); + algorithm_preferences_ = ssl_private_key_->GetAlgorithmPreferences(); + auto ssl_private_key_request = mojo::MakeRequest(&ssl_private_key_ptr_); + mojo::MakeStrongBinding( + std::make_unique<FakeSSLPrivateKeyImpl>(std::move(ssl_private_key_)), + std::move(ssl_private_key_request)); + } + + void set_certificate(scoped_refptr<net::X509Certificate> certificate) { + certificate_ = std::move(certificate); + } + + int on_certificate_requested_counter() { + return on_certificate_requested_counter_; + } + private: CredentialsResponse credentials_response_; base::Optional<net::AuthCredentials> auth_credentials_; int on_auth_required_call_counter_ = 0; scoped_refptr<net::HttpResponseHeaders> last_seen_response_headers_; - - DISALLOW_COPY_AND_ASSIGN(TestAuthNetworkServiceClient); + CertificateResponse certificate_response_ = CertificateResponse::INVALID; + mojom::URLLoaderPtr* url_loader_ptr_ = nullptr; + scoped_refptr<net::SSLPrivateKey> ssl_private_key_; + scoped_refptr<net::X509Certificate> certificate_; + network::mojom::SSLPrivateKeyPtr ssl_private_key_ptr_; + std::vector<uint16_t> algorithm_preferences_; + int on_certificate_requested_counter_ = 0; + + DISALLOW_COPY_AND_ASSIGN(MockNetworkServiceClient); }; TEST_F(URLLoaderTest, SetAuth) { - TestAuthNetworkServiceClient network_service_client; + MockNetworkServiceClient network_service_client; network_service_client.set_credentials_response( - TestAuthNetworkServiceClient::CredentialsResponse::CORRECT_CREDENTIALS); + MockNetworkServiceClient::CredentialsResponse::CORRECT_CREDENTIALS); ResourceRequest request = CreateResourceRequest("GET", test_server()->GetURL(kTestAuthURL)); @@ -2257,9 +2326,9 @@ TEST_F(URLLoaderTest, SetAuth) { } TEST_F(URLLoaderTest, CancelAuth) { - TestAuthNetworkServiceClient network_service_client; + MockNetworkServiceClient network_service_client; network_service_client.set_credentials_response( - TestAuthNetworkServiceClient::CredentialsResponse::NO_CREDENTIALS); + MockNetworkServiceClient::CredentialsResponse::NO_CREDENTIALS); ResourceRequest request = CreateResourceRequest("GET", test_server()->GetURL(kTestAuthURL)); @@ -2298,9 +2367,9 @@ TEST_F(URLLoaderTest, CancelAuth) { } TEST_F(URLLoaderTest, TwoChallenges) { - TestAuthNetworkServiceClient network_service_client; + MockNetworkServiceClient network_service_client; network_service_client.set_credentials_response( - TestAuthNetworkServiceClient::CredentialsResponse:: + MockNetworkServiceClient::CredentialsResponse:: INCORRECT_CREDENTIALS_THEN_CORRECT_ONES); ResourceRequest request = @@ -2342,9 +2411,9 @@ TEST_F(URLLoaderTest, TwoChallenges) { TEST_F(URLLoaderTest, NoAuthRequiredForFavicon) { constexpr char kFaviconTestPage[] = "/has_favicon.html"; - TestAuthNetworkServiceClient network_service_client; + MockNetworkServiceClient network_service_client; network_service_client.set_credentials_response( - TestAuthNetworkServiceClient::CredentialsResponse::CORRECT_CREDENTIALS); + MockNetworkServiceClient::CredentialsResponse::CORRECT_CREDENTIALS); ResourceRequest request = CreateResourceRequest("GET", test_server()->GetURL(kFaviconTestPage)); @@ -2384,9 +2453,9 @@ TEST_F(URLLoaderTest, NoAuthRequiredForFavicon) { } TEST_F(URLLoaderTest, HttpAuthResponseHeadersAvailable) { - TestAuthNetworkServiceClient network_service_client; + MockNetworkServiceClient network_service_client; network_service_client.set_credentials_response( - TestAuthNetworkServiceClient::CredentialsResponse::CORRECT_CREDENTIALS); + MockNetworkServiceClient::CredentialsResponse::CORRECT_CREDENTIALS); ResourceRequest request = CreateResourceRequest("GET", test_server()->GetURL(kTestAuthURL)); @@ -2564,4 +2633,276 @@ TEST_F(URLLoaderTest, FollowRedirectTwice) { delete_run_loop.Run(); } +class TestSSLPrivateKey : public net::SSLPrivateKey { + public: + explicit TestSSLPrivateKey(scoped_refptr<net::SSLPrivateKey> key) + : key_(std::move(key)) {} + + void set_fail_signing(bool fail_signing) { fail_signing_ = fail_signing; } + int sign_count() const { return sign_count_; } + + std::vector<uint16_t> GetAlgorithmPreferences() override { + return key_->GetAlgorithmPreferences(); + } + void Sign(uint16_t algorithm, + base::span<const uint8_t> input, + SignCallback callback) override { + sign_count_++; + if (fail_signing_) { + base::ThreadTaskRunnerHandle::Get()->PostTask( + FROM_HERE, base::BindOnce(std::move(callback), + net::ERR_SSL_CLIENT_AUTH_SIGNATURE_FAILED, + std::vector<uint8_t>())); + } else { + key_->Sign(algorithm, input, std::move(callback)); + } + } + + private: + ~TestSSLPrivateKey() override = default; + + scoped_refptr<net::SSLPrivateKey> key_; + bool fail_signing_ = false; + int sign_count_ = 0; + + DISALLOW_COPY_AND_ASSIGN(TestSSLPrivateKey); +}; + +#if !defined(OS_IOS) +TEST_F(URLLoaderTest, ClientAuthCancelConnection) { + net::EmbeddedTestServer test_server(net::EmbeddedTestServer::TYPE_HTTPS); + net::SSLServerConfig ssl_config; + ssl_config.client_cert_type = + net::SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT; + test_server.SetSSLConfig(net::EmbeddedTestServer::CERT_OK, ssl_config); + test_server.AddDefaultHandlers( + base::FilePath(FILE_PATH_LITERAL("services/test/data"))); + ASSERT_TRUE(test_server.Start()); + + MockNetworkServiceClient network_service_client; + network_service_client.set_certificate_response( + MockNetworkServiceClient::CertificateResponse:: + URL_LOADER_REQUEST_CANCELLED); + + ResourceRequest request = + CreateResourceRequest("GET", test_server.GetURL("/defaultresponse")); + base::RunLoop delete_run_loop; + mojom::URLLoaderPtr loader; + mojom::URLLoaderFactoryParams params; + params.process_id = kProcessId; + params.is_corb_enabled = false; + std::unique_ptr<URLLoader> url_loader = std::make_unique<URLLoader>( + context(), &network_service_client, + DeleteLoaderCallback(&delete_run_loop, &url_loader), + mojo::MakeRequest(&loader), 0, request, false, + client()->CreateInterfacePtr(), TRAFFIC_ANNOTATION_FOR_TESTS, ¶ms, + 0 /* request_id */, resource_scheduler_client(), nullptr, + nullptr /* network_usage_accumulator */); + network_service_client.set_url_loader_ptr(&loader); + + RunUntilIdle(); + ASSERT_TRUE(url_loader); + + client()->RunUntilComplete(); + + EXPECT_EQ(net::ERR_FAILED, client()->completion_status().error_code); +} + +TEST_F(URLLoaderTest, ClientAuthCancelCertificateSelection) { + net::EmbeddedTestServer test_server(net::EmbeddedTestServer::TYPE_HTTPS); + net::SSLServerConfig ssl_config; + ssl_config.client_cert_type = + net::SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT; + test_server.SetSSLConfig(net::EmbeddedTestServer::CERT_OK, ssl_config); + test_server.AddDefaultHandlers( + base::FilePath(FILE_PATH_LITERAL("services/test/data"))); + ASSERT_TRUE(test_server.Start()); + + MockNetworkServiceClient network_service_client; + network_service_client.set_certificate_response( + MockNetworkServiceClient::CertificateResponse:: + CANCEL_CERTIFICATE_SELECTION); + + ResourceRequest request = + CreateResourceRequest("GET", test_server.GetURL("/defaultresponse")); + base::RunLoop delete_run_loop; + mojom::URLLoaderPtr loader; + mojom::URLLoaderFactoryParams params; + params.process_id = kProcessId; + params.is_corb_enabled = false; + std::unique_ptr<URLLoader> url_loader = std::make_unique<URLLoader>( + context(), &network_service_client, + DeleteLoaderCallback(&delete_run_loop, &url_loader), + mojo::MakeRequest(&loader), 0, request, false, + client()->CreateInterfacePtr(), TRAFFIC_ANNOTATION_FOR_TESTS, ¶ms, + 0 /* request_id */, resource_scheduler_client(), nullptr, + nullptr /* network_usage_accumulator */); + + RunUntilIdle(); + ASSERT_TRUE(url_loader); + + EXPECT_EQ(0, network_service_client.on_certificate_requested_counter()); + + client()->RunUntilComplete(); + + EXPECT_EQ(1, network_service_client.on_certificate_requested_counter()); + EXPECT_EQ(net::ERR_SSL_CLIENT_AUTH_CERT_NEEDED, + client()->completion_status().error_code); +} + +TEST_F(URLLoaderTest, ClientAuthNoCertificate) { + net::EmbeddedTestServer test_server(net::EmbeddedTestServer::TYPE_HTTPS); + net::SSLServerConfig ssl_config; + ssl_config.client_cert_type = + net::SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT; + test_server.SetSSLConfig(net::EmbeddedTestServer::CERT_OK, ssl_config); + test_server.AddDefaultHandlers( + base::FilePath(FILE_PATH_LITERAL("services/test/data"))); + ASSERT_TRUE(test_server.Start()); + + MockNetworkServiceClient network_service_client; + network_service_client.set_certificate_response( + MockNetworkServiceClient::CertificateResponse::NULL_CERTIFICATE); + + ResourceRequest request = + CreateResourceRequest("GET", test_server.GetURL("/defaultresponse")); + base::RunLoop delete_run_loop; + mojom::URLLoaderPtr loader; + mojom::URLLoaderFactoryParams params; + params.process_id = kProcessId; + params.is_corb_enabled = false; + std::unique_ptr<URLLoader> url_loader = std::make_unique<URLLoader>( + context(), &network_service_client, + DeleteLoaderCallback(&delete_run_loop, &url_loader), + mojo::MakeRequest(&loader), 0, request, false, + client()->CreateInterfacePtr(), TRAFFIC_ANNOTATION_FOR_TESTS, ¶ms, + 0 /* request_id */, resource_scheduler_client(), nullptr, + nullptr /* network_usage_accumulator */); + + RunUntilIdle(); + ASSERT_TRUE(url_loader); + + EXPECT_EQ(0, network_service_client.on_certificate_requested_counter()); + + client()->RunUntilComplete(); + + EXPECT_EQ(1, network_service_client.on_certificate_requested_counter()); + EXPECT_EQ(net::ERR_BAD_SSL_CLIENT_AUTH_CERT, + client()->completion_status().error_code); +} + +TEST_F(URLLoaderTest, ClientAuthCertificateWithValidSignature) { + std::unique_ptr<net::FakeClientCertIdentity> identity = + net::FakeClientCertIdentity::CreateFromCertAndKeyFiles( + net::GetTestCertsDirectory(), "client_1.pem", "client_1.pk8"); + ASSERT_TRUE(identity); + scoped_refptr<TestSSLPrivateKey> private_key = + base::MakeRefCounted<TestSSLPrivateKey>(identity->ssl_private_key()); + TestSSLPrivateKey* private_key_ptr = private_key.get(); + + net::EmbeddedTestServer test_server(net::EmbeddedTestServer::TYPE_HTTPS); + net::SSLServerConfig ssl_config; + ssl_config.client_cert_type = + net::SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT; + test_server.SetSSLConfig(net::EmbeddedTestServer::CERT_OK, ssl_config); + test_server.AddDefaultHandlers( + base::FilePath(FILE_PATH_LITERAL("services/test/data"))); + ASSERT_TRUE(test_server.Start()); + + MockNetworkServiceClient network_service_client; + network_service_client.set_certificate_response( + MockNetworkServiceClient::CertificateResponse:: + VALID_CERTIFICATE_SIGNATURE); + network_service_client.set_private_key(std::move(private_key)); + scoped_refptr<net::X509Certificate> certificate = + test_server.GetCertificate(); + network_service_client.set_certificate(std::move(certificate)); + + ResourceRequest request = + CreateResourceRequest("GET", test_server.GetURL("/defaultresponse")); + base::RunLoop delete_run_loop; + mojom::URLLoaderPtr loader; + mojom::URLLoaderFactoryParams params; + params.process_id = kProcessId; + params.is_corb_enabled = false; + std::unique_ptr<URLLoader> url_loader = std::make_unique<URLLoader>( + context(), &network_service_client, + DeleteLoaderCallback(&delete_run_loop, &url_loader), + mojo::MakeRequest(&loader), 0, request, false, + client()->CreateInterfacePtr(), TRAFFIC_ANNOTATION_FOR_TESTS, ¶ms, + 0 /* request_id */, resource_scheduler_client(), nullptr, + nullptr /* network_usage_accumulator */); + + RunUntilIdle(); + ASSERT_TRUE(url_loader); + + EXPECT_EQ(0, network_service_client.on_certificate_requested_counter()); + EXPECT_EQ(0, private_key_ptr->sign_count()); + + client()->RunUntilComplete(); + + EXPECT_EQ(1, network_service_client.on_certificate_requested_counter()); + // The private key should have been used. + EXPECT_EQ(1, private_key_ptr->sign_count()); +} + +TEST_F(URLLoaderTest, ClientAuthCertificateWithInvalidSignature) { + std::unique_ptr<net::FakeClientCertIdentity> identity = + net::FakeClientCertIdentity::CreateFromCertAndKeyFiles( + net::GetTestCertsDirectory(), "client_1.pem", "client_1.pk8"); + ASSERT_TRUE(identity); + scoped_refptr<TestSSLPrivateKey> private_key = + base::MakeRefCounted<TestSSLPrivateKey>(identity->ssl_private_key()); + private_key->set_fail_signing(true); + TestSSLPrivateKey* private_key_ptr = private_key.get(); + + net::EmbeddedTestServer test_server(net::EmbeddedTestServer::TYPE_HTTPS); + net::SSLServerConfig ssl_config; + ssl_config.client_cert_type = + net::SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT; + test_server.SetSSLConfig(net::EmbeddedTestServer::CERT_OK, ssl_config); + test_server.AddDefaultHandlers( + base::FilePath(FILE_PATH_LITERAL("services/test/data"))); + ASSERT_TRUE(test_server.Start()); + + MockNetworkServiceClient network_service_client; + network_service_client.set_certificate_response( + MockNetworkServiceClient::CertificateResponse:: + VALID_CERTIFICATE_SIGNATURE); + network_service_client.set_private_key(std::move(private_key)); + scoped_refptr<net::X509Certificate> certificate = + test_server.GetCertificate(); + network_service_client.set_certificate(std::move(certificate)); + + ResourceRequest request = + CreateResourceRequest("GET", test_server.GetURL("/defaultresponse")); + base::RunLoop delete_run_loop; + mojom::URLLoaderPtr loader; + mojom::URLLoaderFactoryParams params; + params.process_id = kProcessId; + params.is_corb_enabled = false; + std::unique_ptr<URLLoader> url_loader = std::make_unique<URLLoader>( + context(), &network_service_client, + DeleteLoaderCallback(&delete_run_loop, &url_loader), + mojo::MakeRequest(&loader), 0, request, false, + client()->CreateInterfacePtr(), TRAFFIC_ANNOTATION_FOR_TESTS, ¶ms, + 0 /* request_id */, resource_scheduler_client(), nullptr, + nullptr /* network_usage_accumulator */); + + RunUntilIdle(); + ASSERT_TRUE(url_loader); + + EXPECT_EQ(0, network_service_client.on_certificate_requested_counter()); + EXPECT_EQ(0, private_key_ptr->sign_count()); + + client()->RunUntilComplete(); + + EXPECT_EQ(1, network_service_client.on_certificate_requested_counter()); + // The private key should have been used. + EXPECT_EQ(1, private_key_ptr->sign_count()); + EXPECT_EQ(net::ERR_SSL_CLIENT_AUTH_SIGNATURE_FAILED, + client()->completion_status().error_code); +} +#endif // !defined(OS_IOS) + } // namespace network diff --git a/chromium/services/network/url_request_context_builder_mojo.cc b/chromium/services/network/url_request_context_builder_mojo.cc index 0013f022860..5a9f95fca8a 100644 --- a/chromium/services/network/url_request_context_builder_mojo.cc +++ b/chromium/services/network/url_request_context_builder_mojo.cc @@ -10,10 +10,7 @@ #include "net/proxy_resolution/proxy_config_service.h" #include "services/network/network_context.h" #include "services/network/public/cpp/features.h" - -#if !defined(OS_IOS) #include "services/network/proxy_service_mojo.h" -#endif namespace network { @@ -43,7 +40,6 @@ URLRequestContextBuilderMojo::CreateProxyResolutionService( DCHECK(url_request_context); DCHECK(host_resolver); -#if !defined(OS_IOS) if (mojo_proxy_resolver_factory_) { std::unique_ptr<net::DhcpPacFileFetcher> dhcp_pac_file_fetcher = dhcp_fetcher_factory_->Create(url_request_context); @@ -63,7 +59,6 @@ URLRequestContextBuilderMojo::CreateProxyResolutionService( std::move(dhcp_pac_file_fetcher), host_resolver, net_log, network_delegate); } -#endif return net::URLRequestContextBuilder::CreateProxyResolutionService( std::move(proxy_config_service), url_request_context, host_resolver, diff --git a/chromium/services/network/websocket.cc b/chromium/services/network/websocket.cc index ed70e6ce12e..55e853a4d27 100644 --- a/chromium/services/network/websocket.cc +++ b/chromium/services/network/websocket.cc @@ -385,7 +385,7 @@ void WebSocket::SendFrame(bool fin, DCHECK(IsKnownEnumValue(type)); // TODO(darin): Avoid this copy. - scoped_refptr<net::IOBuffer> data_to_pass(new net::IOBuffer(data.size())); + auto data_to_pass = base::MakeRefCounted<net::IOBuffer>(data.size()); std::copy(data.begin(), data.end(), data_to_pass->data()); channel_->SendFrame(fin, MessageTypeToOpCode(type), std::move(data_to_pass), diff --git a/chromium/services/preferences/BUILD.gn b/chromium/services/preferences/BUILD.gn index 9e6abf3953f..d11e48492e7 100644 --- a/chromium/services/preferences/BUILD.gn +++ b/chromium/services/preferences/BUILD.gn @@ -21,6 +21,7 @@ source_set("preferences") { ":*", "//services/preferences/public/cpp:service_main", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] deps = [ "//components/prefs", "//services/preferences/public/cpp", diff --git a/chromium/services/preferences/public/cpp/BUILD.gn b/chromium/services/preferences/public/cpp/BUILD.gn index 307dd8fde62..7643b134e54 100644 --- a/chromium/services/preferences/public/cpp/BUILD.gn +++ b/chromium/services/preferences/public/cpp/BUILD.gn @@ -20,6 +20,8 @@ source_set("cpp") { "scoped_pref_update.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + public_deps = [ "//base", "//components/prefs", @@ -47,4 +49,5 @@ source_set("service_main") { "pref_service_main.cc", "pref_service_main.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] } diff --git a/chromium/services/preferences/tracked/OWNERS b/chromium/services/preferences/tracked/OWNERS index 4939fa68ce2..f1e684933f5 100644 --- a/chromium/services/preferences/tracked/OWNERS +++ b/chromium/services/preferences/tracked/OWNERS @@ -1,3 +1,2 @@ -bauerb@chromium.org gab@chromium.org proberge@chromium.org diff --git a/chromium/services/preferences/tracked/pref_hash_store_impl.cc b/chromium/services/preferences/tracked/pref_hash_store_impl.cc index 477996997e8..429a068a68e 100644 --- a/chromium/services/preferences/tracked/pref_hash_store_impl.cc +++ b/chromium/services/preferences/tracked/pref_hash_store_impl.cc @@ -10,6 +10,7 @@ #include "base/logging.h" #include "base/macros.h" #include "base/metrics/histogram_macros.h" +#include "base/no_destructor.h" #include "services/preferences/tracked/device_id.h" #include "services/preferences/tracked/hash_store_contents.h" @@ -20,9 +21,9 @@ using ValueState = // Returns a deterministic ID for this machine. std::string GenerateDeviceId() { - static std::string cached_device_id; - if (!cached_device_id.empty()) - return cached_device_id; + static base::NoDestructor<std::string> cached_device_id; + if (!cached_device_id->empty()) + return *cached_device_id; std::string device_id; MachineIdStatus status = GetDeterministicMachineSpecificId(&device_id); @@ -30,7 +31,7 @@ std::string GenerateDeviceId() { status == MachineIdStatus::SUCCESS); if (status == MachineIdStatus::SUCCESS) { - cached_device_id = device_id; + *cached_device_id = device_id; return device_id; } diff --git a/chromium/services/proxy_resolver/BUILD.gn b/chromium/services/proxy_resolver/BUILD.gn index bb6c6127528..20660e40a03 100644 --- a/chromium/services/proxy_resolver/BUILD.gn +++ b/chromium/services/proxy_resolver/BUILD.gn @@ -15,6 +15,8 @@ source_set("lib") { "proxy_resolver_service.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ "//base", "//mojo/public/cpp/bindings", diff --git a/chromium/services/resource_coordinator/BUILD.gn b/chromium/services/resource_coordinator/BUILD.gn index 4f2152e68d6..a12720fdccd 100644 --- a/chromium/services/resource_coordinator/BUILD.gn +++ b/chromium/services/resource_coordinator/BUILD.gn @@ -55,8 +55,12 @@ source_set("lib") { "resource_coordinator_clock.h", "resource_coordinator_service.cc", "resource_coordinator_service.h", + "webui_graph_dump_impl.cc", + "webui_graph_dump_impl.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + public_deps = [ "//base", "//mojo/public/cpp/bindings", @@ -109,6 +113,7 @@ source_set("tests") { "public/cpp/memory_instrumentation/memory_instrumentation_mojom_traits_unittest.cc", "public/cpp/memory_instrumentation/os_metrics_unittest.cc", "public/cpp/memory_instrumentation/tracing_integration_unittest.cc", + "webui_graph_dump_impl_unittest.cc", ] if (!is_android) { diff --git a/chromium/services/resource_coordinator/coordination_unit/coordination_unit_graph.cc b/chromium/services/resource_coordinator/coordination_unit/coordination_unit_graph.cc index 5c1771d4f3a..80505505e65 100644 --- a/chromium/services/resource_coordinator/coordination_unit/coordination_unit_graph.cc +++ b/chromium/services/resource_coordinator/coordination_unit/coordination_unit_graph.cc @@ -127,25 +127,19 @@ CoordinationUnitGraph::GetProcessCoordinationUnitByPid(base::ProcessId pid) { return ProcessCoordinationUnitImpl::FromCoordinationUnitBase(it->second); } -std::vector<CoordinationUnitBase*> -CoordinationUnitGraph::GetCoordinationUnitsOfType(CoordinationUnitType type) { - std::vector<CoordinationUnitBase*> results; - for (const auto& el : coordination_units_) { - if (el.first.type == type) - results.push_back(el.second.get()); - } - return results; -} - std::vector<ProcessCoordinationUnitImpl*> CoordinationUnitGraph::GetAllProcessCoordinationUnits() { - auto cus = GetCoordinationUnitsOfType(CoordinationUnitType::kProcess); - std::vector<ProcessCoordinationUnitImpl*> process_cus; - for (auto* process_cu : cus) { - process_cus.push_back( - ProcessCoordinationUnitImpl::FromCoordinationUnitBase(process_cu)); - } - return process_cus; + return GetAllCoordinationUnitsOfType<ProcessCoordinationUnitImpl>(); +} + +std::vector<FrameCoordinationUnitImpl*> +CoordinationUnitGraph::GetAllFrameCoordinationUnits() { + return GetAllCoordinationUnitsOfType<FrameCoordinationUnitImpl>(); +} + +std::vector<PageCoordinationUnitImpl*> +CoordinationUnitGraph::GetAllPageCoordinationUnits() { + return GetAllCoordinationUnitsOfType<PageCoordinationUnitImpl>(); } CoordinationUnitBase* CoordinationUnitGraph::AddNewCoordinationUnit( @@ -183,4 +177,15 @@ void CoordinationUnitGraph::BeforeProcessPidChange( processes_by_pid_[new_pid] = process; } +template <typename CUType> +std::vector<CUType*> CoordinationUnitGraph::GetAllCoordinationUnitsOfType() { + const auto type = CUType::Type(); + std::vector<CUType*> ret; + for (const auto& el : coordination_units_) { + if (el.first.type == type) + ret.push_back(CUType::FromCoordinationUnitBase(el.second.get())); + } + return ret; +} + } // namespace resource_coordinator diff --git a/chromium/services/resource_coordinator/coordination_unit/coordination_unit_graph.h b/chromium/services/resource_coordinator/coordination_unit/coordination_unit_graph.h index 649cfbc5b9c..1d3df94a53c 100644 --- a/chromium/services/resource_coordinator/coordination_unit/coordination_unit_graph.h +++ b/chromium/services/resource_coordinator/coordination_unit/coordination_unit_graph.h @@ -71,11 +71,10 @@ class CoordinationUnitGraph { SystemCoordinationUnitImpl* FindOrCreateSystemCoordinationUnit( std::unique_ptr<service_manager::ServiceContextRef> service_ref); - // Search functions for type and ID queries. - std::vector<CoordinationUnitBase*> GetCoordinationUnitsOfType( - CoordinationUnitType type); - std::vector<ProcessCoordinationUnitImpl*> GetAllProcessCoordinationUnits(); + std::vector<FrameCoordinationUnitImpl*> GetAllFrameCoordinationUnits(); + std::vector<PageCoordinationUnitImpl*> GetAllPageCoordinationUnits(); + // Retrieves the process CU with PID |pid|, if any. ProcessCoordinationUnitImpl* GetProcessCoordinationUnitByPid( base::ProcessId pid); @@ -103,6 +102,9 @@ class CoordinationUnitGraph { void BeforeProcessPidChange(ProcessCoordinationUnitImpl* process, base::ProcessId new_pid); + template <typename CUType> + std::vector<CUType*> GetAllCoordinationUnitsOfType(); + CoordinationUnitID system_coordination_unit_id_; CUIDMap coordination_units_; ProcessByPidMap processes_by_pid_; diff --git a/chromium/services/resource_coordinator/coordination_unit/coordination_unit_graph_unittest.cc b/chromium/services/resource_coordinator/coordination_unit/coordination_unit_graph_unittest.cc index 16d0e00583c..84929d351ea 100644 --- a/chromium/services/resource_coordinator/coordination_unit/coordination_unit_graph_unittest.cc +++ b/chromium/services/resource_coordinator/coordination_unit/coordination_unit_graph_unittest.cc @@ -3,6 +3,9 @@ // found in the LICENSE file. #include "services/resource_coordinator/coordination_unit/coordination_unit_graph.h" + +#include "services/resource_coordinator/coordination_unit/frame_coordination_unit_impl.h" +#include "services/resource_coordinator/coordination_unit/mock_coordination_unit_graphs.h" #include "services/resource_coordinator/coordination_unit/process_coordination_unit_impl.h" #include "services/resource_coordinator/coordination_unit/system_coordination_unit_impl.h" #include "testing/gtest/include/gtest/gtest.h" @@ -92,4 +95,26 @@ TEST(CoordinationUnitGraphTest, PIDReuse) { EXPECT_EQ(process2, graph.GetProcessCoordinationUnitByPid(kPid)); } +TEST(CoordinationUnitGraphTest, GetAllCUsByType) { + CoordinationUnitGraph graph; + MockMultiplePagesInSingleProcessCoordinationUnitGraph cu_graph(&graph); + + std::vector<ProcessCoordinationUnitImpl*> processes = + graph.GetAllProcessCoordinationUnits(); + ASSERT_EQ(1u, processes.size()); + EXPECT_NE(nullptr, processes[0]); + + std::vector<FrameCoordinationUnitImpl*> frames = + graph.GetAllFrameCoordinationUnits(); + ASSERT_EQ(2u, frames.size()); + EXPECT_NE(nullptr, frames[0]); + EXPECT_NE(nullptr, frames[1]); + + std::vector<PageCoordinationUnitImpl*> pages = + graph.GetAllPageCoordinationUnits(); + ASSERT_EQ(2u, pages.size()); + EXPECT_NE(nullptr, pages[0]); + EXPECT_NE(nullptr, pages[1]); +} + } // namespace resource_coordinator diff --git a/chromium/services/resource_coordinator/coordination_unit/coordination_unit_introspector_impl.cc b/chromium/services/resource_coordinator/coordination_unit/coordination_unit_introspector_impl.cc index bfb041254a2..7e834158e27 100644 --- a/chromium/services/resource_coordinator/coordination_unit/coordination_unit_introspector_impl.cc +++ b/chromium/services/resource_coordinator/coordination_unit/coordination_unit_introspector_impl.cc @@ -35,12 +35,7 @@ void CoordinationUnitIntrospectorImpl::GetProcessToURLMap( mojom::ProcessInfoPtr process_info(mojom::ProcessInfo::New()); process_info->pid = base::checked_cast<base::ProcessId>(pid); DCHECK_NE(base::kNullProcessId, process_info->pid); - - int64_t launch_time; - if (process_cu->GetProperty(mojom::PropertyType::kLaunchTime, - &launch_time)) { - process_info->launch_time = base::Time::FromTimeT(launch_time); - } + process_info->launch_time = process_cu->launch_time(); std::set<PageCoordinationUnitImpl*> page_cus = process_cu->GetAssociatedPageCoordinationUnits(); diff --git a/chromium/services/resource_coordinator/coordination_unit/frame_coordination_unit_impl.cc b/chromium/services/resource_coordinator/coordination_unit/frame_coordination_unit_impl.cc index 62182a23ea0..ee129155aa4 100644 --- a/chromium/services/resource_coordinator/coordination_unit/frame_coordination_unit_impl.cc +++ b/chromium/services/resource_coordinator/coordination_unit/frame_coordination_unit_impl.cc @@ -70,13 +70,22 @@ void FrameCoordinationUnitImpl::SetNetworkAlmostIdle(bool idle) { } void FrameCoordinationUnitImpl::SetLifecycleState(mojom::LifecycleState state) { - SetProperty(mojom::PropertyType::kLifecycleState, - static_cast<int64_t>(state)); - // The page will have the same lifecycle state as the main frame. - if (IsMainFrame() && GetPageCoordinationUnit()) { - GetPageCoordinationUnit()->SetProperty(mojom::PropertyType::kLifecycleState, - static_cast<int64_t>(state)); - } + if (state == lifecycle_state_) + return; + + mojom::LifecycleState old_state = lifecycle_state_; + lifecycle_state_ = state; + + // Notify parents of this change. + if (process_coordination_unit_) + process_coordination_unit_->OnFrameLifecycleStateChanged(this, old_state); + if (page_coordination_unit_) + page_coordination_unit_->OnFrameLifecycleStateChanged(this, old_state); +} + +void FrameCoordinationUnitImpl::SetHasNonEmptyBeforeUnload( + bool has_nonempty_beforeunload) { + has_nonempty_beforeunload_ = has_nonempty_beforeunload; } void FrameCoordinationUnitImpl::OnAlertFired() { diff --git a/chromium/services/resource_coordinator/coordination_unit/frame_coordination_unit_impl.h b/chromium/services/resource_coordinator/coordination_unit/frame_coordination_unit_impl.h index b37d8acecac..b3f43c6b65b 100644 --- a/chromium/services/resource_coordinator/coordination_unit/frame_coordination_unit_impl.h +++ b/chromium/services/resource_coordinator/coordination_unit/frame_coordination_unit_impl.h @@ -35,6 +35,7 @@ class FrameCoordinationUnitImpl void SetAudibility(bool audible) override; void SetNetworkAlmostIdle(bool idle) override; void SetLifecycleState(mojom::LifecycleState state) override; + void SetHasNonEmptyBeforeUnload(bool has_nonempty_beforeunload) override; void OnAlertFired() override; void OnNonPersistentNotificationCreated() override; @@ -43,7 +44,9 @@ class FrameCoordinationUnitImpl ProcessCoordinationUnitImpl* GetProcessCoordinationUnit() const; bool IsMainFrame() const; + mojom::LifecycleState lifecycle_state() const { return lifecycle_state_; } base::TimeTicks last_audible_time() const { return last_audible_time_; } + bool has_nonempty_beforeunload() const { return has_nonempty_beforeunload_; } const std::set<FrameCoordinationUnitImpl*>& child_frame_coordination_units_for_testing() const { @@ -81,6 +84,8 @@ class FrameCoordinationUnitImpl ProcessCoordinationUnitImpl* process_coordination_unit_; std::set<FrameCoordinationUnitImpl*> child_frame_coordination_units_; + mojom::LifecycleState lifecycle_state_ = mojom::LifecycleState::kRunning; + bool has_nonempty_beforeunload_ = false; base::TimeTicks last_audible_time_; DISALLOW_COPY_AND_ASSIGN(FrameCoordinationUnitImpl); diff --git a/chromium/services/resource_coordinator/coordination_unit/frame_coordination_unit_impl_unittest.cc b/chromium/services/resource_coordinator/coordination_unit/frame_coordination_unit_impl_unittest.cc index cb5a69de019..e03341d13ed 100644 --- a/chromium/services/resource_coordinator/coordination_unit/frame_coordination_unit_impl_unittest.cc +++ b/chromium/services/resource_coordinator/coordination_unit/frame_coordination_unit_impl_unittest.cc @@ -125,7 +125,7 @@ TEST_F(FrameCoordinationUnitImplTest, LastAudibleTime) { cu_graph.frame->last_audible_time()); } -int64_t GetLifecycleState(resource_coordinator::CoordinationUnitBase* cu) { +int64_t GetLifecycleState(PageCoordinationUnitImpl* cu) { int64_t value; if (cu->GetProperty(mojom::PropertyType::kLifecycleState, &value)) return value; @@ -155,31 +155,43 @@ TEST_F(FrameCoordinationUnitImplTest, LifecycleStatesTransitions) { // Freezing a child frame should not affect the page state. cu_graph.child_frame->SetLifecycleState(mojom::LifecycleState::kFrozen); - // Verify that the frame is frozen. - EXPECT_FROZEN(cu_graph.child_frame); - // But all pages remain active. EXPECT_RUNNING(cu_graph.page); EXPECT_RUNNING(cu_graph.other_page); - // Freezing a page main frame should freeze that page. + // Freezing the only frame in a page should freeze that page. cu_graph.frame->SetLifecycleState(mojom::LifecycleState::kFrozen); - EXPECT_FROZEN(cu_graph.frame); EXPECT_FROZEN(cu_graph.page); + EXPECT_RUNNING(cu_graph.other_page); + + // Unfreeze the child frame in the other page. + cu_graph.child_frame->SetLifecycleState(mojom::LifecycleState::kRunning); + EXPECT_FROZEN(cu_graph.page); + EXPECT_RUNNING(cu_graph.other_page); - // Freezing the other page main frame. + // Freezing the main frame in the other page should not alter that pages + // state, as there is still a child frame that is running. cu_graph.other_frame->SetLifecycleState(mojom::LifecycleState::kFrozen); - EXPECT_FROZEN(cu_graph.other_frame); + EXPECT_FROZEN(cu_graph.page); + EXPECT_RUNNING(cu_graph.other_page); + + // Refreezing the child frame should freeze the page. + cu_graph.child_frame->SetLifecycleState(mojom::LifecycleState::kFrozen); + EXPECT_FROZEN(cu_graph.page); EXPECT_FROZEN(cu_graph.other_page); - // Unfreezing subframe should have no effect. - cu_graph.child_frame->SetLifecycleState(mojom::LifecycleState::kRunning); - // Verify that the frame is unfrozen. - EXPECT_RUNNING(cu_graph.child_frame); - // But the page is still frozen + // Unfreezing a main frame should unfreeze the associated page. + cu_graph.frame->SetLifecycleState(mojom::LifecycleState::kRunning); + EXPECT_RUNNING(cu_graph.page); EXPECT_FROZEN(cu_graph.other_page); - // Unfreezing the main frame should unfreeze the page. + // Unfreezing the child frame should unfreeze the associated page. + cu_graph.child_frame->SetLifecycleState(mojom::LifecycleState::kRunning); + EXPECT_RUNNING(cu_graph.page); + EXPECT_RUNNING(cu_graph.other_page); + + // Unfreezing the main frame shouldn't change anything. cu_graph.other_frame->SetLifecycleState(mojom::LifecycleState::kRunning); + EXPECT_RUNNING(cu_graph.page); EXPECT_RUNNING(cu_graph.other_page); } diff --git a/chromium/services/resource_coordinator/coordination_unit/page_coordination_unit_impl.cc b/chromium/services/resource_coordinator/coordination_unit/page_coordination_unit_impl.cc index e8fd4c0ae32..268b9afe1f2 100644 --- a/chromium/services/resource_coordinator/coordination_unit/page_coordination_unit_impl.cc +++ b/chromium/services/resource_coordinator/coordination_unit/page_coordination_unit_impl.cc @@ -65,8 +65,10 @@ void PageCoordinationUnitImpl::OnTitleUpdated() { } void PageCoordinationUnitImpl::OnMainFrameNavigationCommitted( + base::TimeTicks navigation_committed_time, int64_t navigation_id, const std::string& url) { + navigation_committed_time_ = navigation_committed_time; main_frame_url_ = url; navigation_id_ = navigation_id; SendEvent(mojom::Event::kNavigationCommitted); @@ -143,10 +145,22 @@ PageCoordinationUnitImpl::GetMainFrameCoordinationUnit() const { return nullptr; } +void PageCoordinationUnitImpl::OnFrameLifecycleStateChanged( + FrameCoordinationUnitImpl* frame_cu, + mojom::LifecycleState old_state) { + DCHECK(base::ContainsKey(frame_coordination_units_, frame_cu)); + DCHECK_NE(old_state, frame_cu->lifecycle_state()); + + int delta = 0; + if (old_state == mojom::LifecycleState::kFrozen) + delta = -1; + else if (frame_cu->lifecycle_state() == mojom::LifecycleState::kFrozen) + delta = 1; + if (delta != 0) + OnNumFrozenFramesStateChange(delta); +} + void PageCoordinationUnitImpl::OnEventReceived(mojom::Event event) { - if (event == mojom::Event::kNavigationCommitted) { - navigation_committed_time_ = ResourceCoordinatorClock::NowTicks(); - } for (auto& observer : observers()) observer.OnPageEventReceived(this, event); } @@ -161,14 +175,62 @@ void PageCoordinationUnitImpl::OnPropertyChanged( } bool PageCoordinationUnitImpl::AddFrame(FrameCoordinationUnitImpl* frame_cu) { - return frame_coordination_units_.count(frame_cu) - ? false - : frame_coordination_units_.insert(frame_cu).second; + const bool inserted = frame_coordination_units_.insert(frame_cu).second; + if (inserted) { + OnNumFrozenFramesStateChange( + frame_cu->lifecycle_state() == mojom::LifecycleState::kFrozen ? 1 : 0); + } + return inserted; } bool PageCoordinationUnitImpl::RemoveFrame( FrameCoordinationUnitImpl* frame_cu) { - return frame_coordination_units_.erase(frame_cu) > 0; + bool removed = frame_coordination_units_.erase(frame_cu) > 0; + if (removed) { + OnNumFrozenFramesStateChange( + frame_cu->lifecycle_state() == mojom::LifecycleState::kFrozen ? -1 : 0); + } + return removed; +} + +void PageCoordinationUnitImpl::OnNumFrozenFramesStateChange( + int num_frozen_frames_delta) { + num_frozen_frames_ += num_frozen_frames_delta; + DCHECK_GE(num_frozen_frames_, 0u); + DCHECK_LE(num_frozen_frames_, frame_coordination_units_.size()); + + const int64_t kRunning = + static_cast<int64_t>(mojom::LifecycleState::kRunning); + const int64_t kFrozen = static_cast<int64_t>(mojom::LifecycleState::kFrozen); + + // We are interested in knowing when we have transitioned to or from + // "fully frozen". A page with no frames is considered to be running by + // default. + bool was_fully_frozen = + GetPropertyOrDefault(mojom::PropertyType::kLifecycleState, kRunning) == + kFrozen; + bool is_fully_frozen = frame_coordination_units_.size() > 0 && + num_frozen_frames_ == frame_coordination_units_.size(); + if (was_fully_frozen == is_fully_frozen) + return; + + if (is_fully_frozen) { + // Aggregate the beforeunload handler information from the entire frame + // tree. + bool has_nonempty_beforeunload = false; + for (auto* frame : frame_coordination_units_) { + if (frame->has_nonempty_beforeunload()) { + has_nonempty_beforeunload = true; + break; + } + } + set_has_nonempty_beforeunload(has_nonempty_beforeunload); + } + + // TODO(fdoray): Store the lifecycle state as a member on the + // PageCoordinationUnit rather than as a non-typed property. + SetProperty(mojom::PropertyType::kLifecycleState, + is_fully_frozen ? kFrozen : kRunning); } } // namespace resource_coordinator diff --git a/chromium/services/resource_coordinator/coordination_unit/page_coordination_unit_impl.h b/chromium/services/resource_coordinator/coordination_unit/page_coordination_unit_impl.h index b5de73ae457..db05abeff1a 100644 --- a/chromium/services/resource_coordinator/coordination_unit/page_coordination_unit_impl.h +++ b/chromium/services/resource_coordinator/coordination_unit/page_coordination_unit_impl.h @@ -35,7 +35,8 @@ class PageCoordinationUnitImpl void SetUKMSourceId(int64_t ukm_source_id) override; void OnFaviconUpdated() override; void OnTitleUpdated() override; - void OnMainFrameNavigationCommitted(int64_t navigation_id, + void OnMainFrameNavigationCommitted(base::TimeTicks navigation_committed_time, + int64_t navigation_id, const std::string& url) override; // There is no direct relationship between processes and pages. However, @@ -85,10 +86,17 @@ class PageCoordinationUnitImpl uint64_t private_footprint_kb_estimate) { private_footprint_kb_estimate_ = private_footprint_kb_estimate; } + void set_has_nonempty_beforeunload(bool has_nonempty_beforeunload) { + has_nonempty_beforeunload_ = has_nonempty_beforeunload; + } const std::string& main_frame_url() const { return main_frame_url_; } int64_t navigation_id() const { return navigation_id_; } + // Invoked when the state of a frame in this page changes. + void OnFrameLifecycleStateChanged(FrameCoordinationUnitImpl* frame_cu, + mojom::LifecycleState old_state); + private: friend class FrameCoordinationUnitImpl; @@ -100,6 +108,13 @@ class PageCoordinationUnitImpl bool AddFrame(FrameCoordinationUnitImpl* frame_cu); bool RemoveFrame(FrameCoordinationUnitImpl* frame_cu); + // This is called whenever |num_frozen_frames_| changes, or whenever + // |frame_coordination_units_.size()| changes. It is used to synthesize the + // value of |has_nonempty_beforeunload| and to update the LifecycleState of + // the page. Calling this with |num_frozen_frames_delta == 0| implies that the + // number of frames itself has changed. + void OnNumFrozenFramesStateChange(int num_frozen_frames_delta); + std::set<FrameCoordinationUnitImpl*> frame_coordination_units_; base::TimeTicks visibility_change_time_; @@ -119,6 +134,15 @@ class PageCoordinationUnitImpl // The most current memory footprint estimate. uint64_t private_footprint_kb_estimate_ = 0; + // Counts the number of frames in a page that are frozen. + size_t num_frozen_frames_ = 0; + + // Indicates whether or not this page has a non-empty beforeunload handler. + // This is an aggregation of the same value on each frame in the page's frame + // tree. The aggregation is made at the moment all frames associated with a + // page have transition to frozen. + bool has_nonempty_beforeunload_ = false; + // The URL the main frame last committed a navigation to and the unique ID of // the associated navigation handle. std::string main_frame_url_; diff --git a/chromium/services/resource_coordinator/coordination_unit/page_coordination_unit_impl_unittest.cc b/chromium/services/resource_coordinator/coordination_unit/page_coordination_unit_impl_unittest.cc index cf4075dc2ab..140d7a2c9a3 100644 --- a/chromium/services/resource_coordinator/coordination_unit/page_coordination_unit_impl_unittest.cc +++ b/chromium/services/resource_coordinator/coordination_unit/page_coordination_unit_impl_unittest.cc @@ -173,7 +173,8 @@ TEST_F(PageCoordinationUnitImplTest, TimeSinceLastNavigation) { EXPECT_TRUE(cu_graph.page->TimeSinceLastNavigation().is_zero()); // 1st navigation. - cu_graph.page->OnMainFrameNavigationCommitted(10u, "http://www.example.org"); + cu_graph.page->OnMainFrameNavigationCommitted( + ResourceCoordinatorClock::NowTicks(), 10u, "http://www.example.org"); EXPECT_EQ("http://www.example.org", cu_graph.page->main_frame_url()); EXPECT_EQ(10u, cu_graph.page->navigation_id()); AdvanceClock(base::TimeDelta::FromSeconds(11)); @@ -182,7 +183,8 @@ TEST_F(PageCoordinationUnitImplTest, TimeSinceLastNavigation) { // 2nd navigation. cu_graph.page->OnMainFrameNavigationCommitted( - 20u, "http://www.example.org/bobcat"); + ResourceCoordinatorClock::NowTicks(), 20u, + "http://www.example.org/bobcat"); EXPECT_EQ("http://www.example.org/bobcat", cu_graph.page->main_frame_url()); EXPECT_EQ(20u, cu_graph.page->navigation_id()); AdvanceClock(base::TimeDelta::FromSeconds(17)); @@ -215,4 +217,36 @@ TEST_F(PageCoordinationUnitImplTest, IsLoading) { EXPECT_EQ(0u, loading); } +TEST_F(PageCoordinationUnitImplTest, OnAllFramesInPageFrozen) { + const int64_t kRunning = + static_cast<int64_t>(mojom::LifecycleState::kRunning); + const int64_t kFrozen = static_cast<int64_t>(mojom::LifecycleState::kFrozen); + + MockSinglePageWithMultipleProcessesCoordinationUnitGraph cu_graph( + coordination_unit_graph()); + + EXPECT_EQ(kRunning, cu_graph.page->GetPropertyOrDefault( + mojom::PropertyType::kLifecycleState, kRunning)); + + // 1/2 frames in the page is frozen. Expect the page to still be running. + cu_graph.frame->SetLifecycleState(mojom::LifecycleState::kFrozen); + EXPECT_EQ(kRunning, cu_graph.page->GetPropertyOrDefault( + mojom::PropertyType::kLifecycleState, kRunning)); + + // 2/2 frames in the process are frozen. We expect the page to be frozen. + cu_graph.child_frame->SetLifecycleState(mojom::LifecycleState::kFrozen); + EXPECT_EQ(kFrozen, cu_graph.page->GetPropertyOrDefault( + mojom::PropertyType::kLifecycleState, kRunning)); + + // Unfreeze a frame and expect the page to be running again. + cu_graph.frame->SetLifecycleState(mojom::LifecycleState::kRunning); + EXPECT_EQ(kRunning, cu_graph.page->GetPropertyOrDefault( + mojom::PropertyType::kLifecycleState, kRunning)); + + // Refreeze that frame and expect the page to be frozen again. + cu_graph.frame->SetLifecycleState(mojom::LifecycleState::kFrozen); + EXPECT_EQ(kFrozen, cu_graph.page->GetPropertyOrDefault( + mojom::PropertyType::kLifecycleState, kRunning)); +} + } // namespace resource_coordinator diff --git a/chromium/services/resource_coordinator/coordination_unit/process_coordination_unit_impl.cc b/chromium/services/resource_coordinator/coordination_unit/process_coordination_unit_impl.cc index 18604283154..feb95bad870 100644 --- a/chromium/services/resource_coordinator/coordination_unit/process_coordination_unit_impl.cc +++ b/chromium/services/resource_coordinator/coordination_unit/process_coordination_unit_impl.cc @@ -30,22 +30,8 @@ void ProcessCoordinationUnitImpl::AddFrame(const CoordinationUnitID& cu_id) { DCHECK(cu_id.type == CoordinationUnitType::kFrame); auto* frame_cu = FrameCoordinationUnitImpl::GetCoordinationUnitByID(graph_, cu_id); - if (!frame_cu) - return; - if (AddFrame(frame_cu)) { - frame_cu->AddProcessCoordinationUnit(this); - } -} - -void ProcessCoordinationUnitImpl::RemoveFrame(const CoordinationUnitID& cu_id) { - DCHECK(cu_id != id()); - FrameCoordinationUnitImpl* frame_cu = - FrameCoordinationUnitImpl::GetCoordinationUnitByID(graph_, cu_id); - if (!frame_cu) - return; - if (RemoveFrame(frame_cu)) { - frame_cu->RemoveProcessCoordinationUnit(this); - } + if (frame_cu) + AddFrameImpl(frame_cu); } void ProcessCoordinationUnitImpl::SetCPUUsage(double cpu_usage) { @@ -59,7 +45,8 @@ void ProcessCoordinationUnitImpl::SetExpectedTaskQueueingDuration( } void ProcessCoordinationUnitImpl::SetLaunchTime(base::Time launch_time) { - SetProperty(mojom::PropertyType::kLaunchTime, launch_time.ToTimeT()); + DCHECK(launch_time_.is_null()); + launch_time_ = launch_time; } void ProcessCoordinationUnitImpl::SetMainThreadTaskLoadIsLow( @@ -100,6 +87,18 @@ ProcessCoordinationUnitImpl::GetAssociatedPageCoordinationUnits() const { return page_cus; } +void ProcessCoordinationUnitImpl::OnFrameLifecycleStateChanged( + FrameCoordinationUnitImpl* frame_cu, + mojom::LifecycleState old_state) { + DCHECK(base::ContainsKey(frame_coordination_units_, frame_cu)); + DCHECK_NE(old_state, frame_cu->lifecycle_state()); + + if (old_state == mojom::LifecycleState::kFrozen) + DecrementNumFrozenFrames(); + else if (frame_cu->lifecycle_state() == mojom::LifecycleState::kFrozen) + IncrementNumFrozenFrames(); +} + void ProcessCoordinationUnitImpl::OnEventReceived(mojom::Event event) { for (auto& observer : observers()) observer.OnProcessEventReceived(this, event); @@ -112,17 +111,40 @@ void ProcessCoordinationUnitImpl::OnPropertyChanged( observer.OnProcessPropertyChanged(this, property_type, value); } -bool ProcessCoordinationUnitImpl::AddFrame( +void ProcessCoordinationUnitImpl::AddFrameImpl( FrameCoordinationUnitImpl* frame_cu) { - bool success = frame_coordination_units_.count(frame_cu) - ? false - : frame_coordination_units_.insert(frame_cu).second; - return success; + const bool inserted = frame_coordination_units_.insert(frame_cu).second; + if (inserted) { + frame_cu->AddProcessCoordinationUnit(this); + if (frame_cu->lifecycle_state() == mojom::LifecycleState::kFrozen) + IncrementNumFrozenFrames(); + } } -bool ProcessCoordinationUnitImpl::RemoveFrame( +void ProcessCoordinationUnitImpl::RemoveFrame( FrameCoordinationUnitImpl* frame_cu) { - return frame_coordination_units_.erase(frame_cu) > 0; + DCHECK(base::ContainsKey(frame_coordination_units_, frame_cu)); + frame_coordination_units_.erase(frame_cu); + + if (frame_cu->lifecycle_state() == mojom::LifecycleState::kFrozen) + DecrementNumFrozenFrames(); +} + +void ProcessCoordinationUnitImpl::DecrementNumFrozenFrames() { + --num_frozen_frames_; + DCHECK_GE(num_frozen_frames_, 0); +} + +void ProcessCoordinationUnitImpl::IncrementNumFrozenFrames() { + ++num_frozen_frames_; + DCHECK_LE(num_frozen_frames_, + static_cast<int>(frame_coordination_units_.size())); + + if (num_frozen_frames_ == + static_cast<int>(frame_coordination_units_.size())) { + for (auto& observer : observers()) + observer.OnAllFramesInProcessFrozen(this); + } } } // namespace resource_coordinator diff --git a/chromium/services/resource_coordinator/coordination_unit/process_coordination_unit_impl.h b/chromium/services/resource_coordinator/coordination_unit/process_coordination_unit_impl.h index 97dcef6ab7c..d01b13a77bc 100644 --- a/chromium/services/resource_coordinator/coordination_unit/process_coordination_unit_impl.h +++ b/chromium/services/resource_coordinator/coordination_unit/process_coordination_unit_impl.h @@ -13,7 +13,6 @@ namespace resource_coordinator { class FrameCoordinationUnitImpl; -class ProcessCoordinationUnitImpl; class ProcessCoordinationUnitImpl : public CoordinationUnitInterface<ProcessCoordinationUnitImpl, @@ -30,7 +29,6 @@ class ProcessCoordinationUnitImpl // mojom::ProcessCoordinationUnit implementation. void AddFrame(const CoordinationUnitID& cu_id) override; - void RemoveFrame(const CoordinationUnitID& cu_id) override; void SetCPUUsage(double cpu_usage) override; void SetExpectedTaskQueueingDuration(base::TimeDelta duration) override; void SetLaunchTime(base::Time launch_time) override; @@ -53,25 +51,38 @@ class ProcessCoordinationUnitImpl const; base::ProcessId process_id() const { return process_id_; } + base::Time launch_time() const { return launch_time_; } - private: - friend class FrameCoordinationUnitImpl; + // Removes |frame_cu| from the set of frames hosted by this process. Invoked + // from the destructor of FrameCoordinationUnitImpl. + void RemoveFrame(FrameCoordinationUnitImpl* frame_cu); + + // Invoked when the state of a frame hosted by this process changes. + void OnFrameLifecycleStateChanged(FrameCoordinationUnitImpl* frame_cu, + mojom::LifecycleState old_state); + private: // CoordinationUnitInterface implementation. void OnEventReceived(mojom::Event event) override; void OnPropertyChanged(mojom::PropertyType property_type, int64_t value) override; - bool AddFrame(FrameCoordinationUnitImpl* frame_cu); - bool RemoveFrame(FrameCoordinationUnitImpl* frame_cu); + void AddFrameImpl(FrameCoordinationUnitImpl* frame_cu); + + void DecrementNumFrozenFrames(); + void IncrementNumFrozenFrames(); base::TimeDelta cumulative_cpu_usage_; uint64_t private_footprint_kb_ = 0u; base::ProcessId process_id_ = base::kNullProcessId; + base::Time launch_time_; std::set<FrameCoordinationUnitImpl*> frame_coordination_units_; + // The number of frames hosted by this process that are frozen. + int num_frozen_frames_ = 0; + DISALLOW_COPY_AND_ASSIGN(ProcessCoordinationUnitImpl); }; diff --git a/chromium/services/resource_coordinator/coordination_unit/process_coordination_unit_impl_unittest.cc b/chromium/services/resource_coordinator/coordination_unit/process_coordination_unit_impl_unittest.cc index 6524cb5db66..b2ceb510698 100644 --- a/chromium/services/resource_coordinator/coordination_unit/process_coordination_unit_impl_unittest.cc +++ b/chromium/services/resource_coordinator/coordination_unit/process_coordination_unit_impl_unittest.cc @@ -3,7 +3,11 @@ // found in the LICENSE file. #include "services/resource_coordinator/coordination_unit/process_coordination_unit_impl.h" + #include "services/resource_coordinator/coordination_unit/coordination_unit_test_harness.h" +#include "services/resource_coordinator/coordination_unit/frame_coordination_unit_impl.h" +#include "services/resource_coordinator/coordination_unit/mock_coordination_unit_graphs.h" +#include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" namespace resource_coordinator { @@ -12,6 +16,22 @@ namespace { class ProcessCoordinationUnitImplTest : public CoordinationUnitTestHarness {}; +class MockCoordinationUnitGraphObserver : public CoordinationUnitGraphObserver { + public: + MockCoordinationUnitGraphObserver() = default; + virtual ~MockCoordinationUnitGraphObserver() = default; + + bool ShouldObserve(const CoordinationUnitBase* coordination_unit) override { + return true; + } + + MOCK_METHOD1(OnAllFramesInProcessFrozen, + void(const ProcessCoordinationUnitImpl*)); + + private: + DISALLOW_COPY_AND_ASSIGN(MockCoordinationUnitGraphObserver); +}; + } // namespace TEST_F(ProcessCoordinationUnitImplTest, MeasureCPUUsage) { @@ -23,4 +43,32 @@ TEST_F(ProcessCoordinationUnitImplTest, MeasureCPUUsage) { EXPECT_EQ(1, cpu_usage / 1000.0); } +TEST_F(ProcessCoordinationUnitImplTest, OnAllFramesInProcessFrozen) { + auto owned_observer = std::make_unique< + testing::StrictMock<MockCoordinationUnitGraphObserver>>(); + auto* observer = owned_observer.get(); + coordination_unit_graph()->RegisterObserver(std::move(owned_observer)); + MockMultiplePagesInSingleProcessCoordinationUnitGraph cu_graph( + coordination_unit_graph()); + + // 1/2 frame in the process is frozen. + // No call to OnAllFramesInProcessFrozen() is expected. + cu_graph.frame->SetLifecycleState(mojom::LifecycleState::kFrozen); + + // 2/2 frames in the process are frozen. + EXPECT_CALL(*observer, OnAllFramesInProcessFrozen(cu_graph.process.get())); + cu_graph.other_frame->SetLifecycleState(mojom::LifecycleState::kFrozen); + testing::Mock::VerifyAndClear(observer); + + // A frame is unfrozen and frozen. + cu_graph.frame->SetLifecycleState(mojom::LifecycleState::kRunning); + EXPECT_CALL(*observer, OnAllFramesInProcessFrozen(cu_graph.process.get())); + cu_graph.frame->SetLifecycleState(mojom::LifecycleState::kFrozen); + testing::Mock::VerifyAndClear(observer); + + // A frozen frame is frozen again. + // No call to OnAllFramesInProcessFrozen() is expected. + cu_graph.frame->SetLifecycleState(mojom::LifecycleState::kFrozen); +} + } // namespace resource_coordinator diff --git a/chromium/services/resource_coordinator/manifest.json b/chromium/services/resource_coordinator/manifest.json index 75be968e1e4..435040977ee 100644 --- a/chromium/services/resource_coordinator/manifest.json +++ b/chromium/services/resource_coordinator/manifest.json @@ -16,6 +16,7 @@ "coordination_unit": [ "resource_coordinator.mojom.CoordinationUnitProvider" ], "heap_profiler_helper": [ "memory_instrumentation.mojom.HeapProfilerHelper" ], "page_signal": [ "resource_coordinator.mojom.PageSignalGenerator" ], + "webui_graph_dump": [ "resource_coordinator.mojom.WebUIGraphDump" ], "tests": [ "*" ] }, "requires": { diff --git a/chromium/services/resource_coordinator/memory_instrumentation/process_map.cc b/chromium/services/resource_coordinator/memory_instrumentation/process_map.cc index 0fa7f97c0d8..a6d32ea6201 100644 --- a/chromium/services/resource_coordinator/memory_instrumentation/process_map.cc +++ b/chromium/services/resource_coordinator/memory_instrumentation/process_map.cc @@ -5,6 +5,7 @@ #include "services/resource_coordinator/memory_instrumentation/process_map.h" #include "base/process/process_handle.h" +#include "base/stl_util.h" #include "mojo/public/cpp/bindings/binding.h" #include "services/resource_coordinator/public/mojom/memory_instrumentation/memory_instrumentation.mojom.h" #include "services/service_manager/public/cpp/connector.h" @@ -33,7 +34,15 @@ void ProcessMap::OnInit(std::vector<RunningServiceInfoPtr> instances) { for (const RunningServiceInfoPtr& instance : instances) { if (instance->pid == base::kNullProcessId) continue; + const service_manager::Identity& identity = instance->identity; + + // TODO(https://crbug.com/818593): The listener interface is racy, so the + // map may contain spurious entries. If so, remove the existing entry before + // adding a new one. + if (base::ContainsKey(instances_, identity)) + OnServiceStopped(identity); + auto it_and_inserted = instances_.emplace(identity, instance->pid); DCHECK(it_and_inserted.second); } diff --git a/chromium/services/resource_coordinator/observers/coordination_unit_graph_observer.h b/chromium/services/resource_coordinator/observers/coordination_unit_graph_observer.h index 6d5efaae6a1..078c33f4a0b 100644 --- a/chromium/services/resource_coordinator/observers/coordination_unit_graph_observer.h +++ b/chromium/services/resource_coordinator/observers/coordination_unit_graph_observer.h @@ -91,6 +91,10 @@ class CoordinationUnitGraphObserver { const SystemCoordinationUnitImpl* system_cu, const mojom::Event event) {} + // Called when all the frames in a process become frozen. + virtual void OnAllFramesInProcessFrozen( + const ProcessCoordinationUnitImpl* process_cu) {} + void set_coordination_unit_graph( CoordinationUnitGraph* coordination_unit_graph) { coordination_unit_graph_ = coordination_unit_graph; diff --git a/chromium/services/resource_coordinator/observers/ipc_volume_reporter_unittest.cc b/chromium/services/resource_coordinator/observers/ipc_volume_reporter_unittest.cc index 8d355678bff..9cf745b6697 100644 --- a/chromium/services/resource_coordinator/observers/ipc_volume_reporter_unittest.cc +++ b/chromium/services/resource_coordinator/observers/ipc_volume_reporter_unittest.cc @@ -55,7 +55,8 @@ TEST_F(IPCVolumeReporterTest, Basic) { cu_graph.page->SetUKMSourceId(1); cu_graph.page->OnFaviconUpdated(); cu_graph.page->OnTitleUpdated(); - cu_graph.page->OnMainFrameNavigationCommitted(1u, "http://example.org"); + cu_graph.page->OnMainFrameNavigationCommitted( + ResourceCoordinatorClock::NowTicks(), 1u, "http://example.org"); cu_graph.process->SetCPUUsage(1.0); cu_graph.process->SetExpectedTaskQueueingDuration( @@ -76,7 +77,7 @@ TEST_F(IPCVolumeReporterTest, Basic) { histogram_tester_.ExpectTotalCount("ResourceCoordinator.IPCPerMinute.Process", 1); histogram_tester_.ExpectUniqueSample( - "ResourceCoordinator.IPCPerMinute.Process", 5, 1); + "ResourceCoordinator.IPCPerMinute.Process", 4, 1); EXPECT_TRUE(reporter_->mock_timer()->IsRunning()); }; diff --git a/chromium/services/resource_coordinator/observers/metrics_collector_unittest.cc b/chromium/services/resource_coordinator/observers/metrics_collector_unittest.cc index 861d14c8d6d..38eae02eec1 100644 --- a/chromium/services/resource_coordinator/observers/metrics_collector_unittest.cc +++ b/chromium/services/resource_coordinator/observers/metrics_collector_unittest.cc @@ -67,7 +67,8 @@ TEST_F(MAYBE_MetricsCollectorTest, FromBackgroundedToFirstAudioStartsUMA) { auto frame_cu = CreateCoordinationUnit<FrameCoordinationUnitImpl>(); page_cu->AddFrame(frame_cu->id()); - page_cu->OnMainFrameNavigationCommitted(kDummyID, kDummyUrl); + page_cu->OnMainFrameNavigationCommitted(ResourceCoordinatorClock::NowTicks(), + kDummyID, kDummyUrl); AdvanceClock(kTestMetricsReportDelayTimeout); page_cu->SetVisibility(true); @@ -119,7 +120,8 @@ TEST_F(MAYBE_MetricsCollectorTest, page_cu->AddFrame(frame_cu->id()); page_cu->SetVisibility(false); - page_cu->OnMainFrameNavigationCommitted(kDummyID, kDummyUrl); + page_cu->OnMainFrameNavigationCommitted(ResourceCoordinatorClock::NowTicks(), + kDummyID, kDummyUrl); frame_cu->SetAudibility(true); // The page is within 5 minutes after main frame navigation was committed, // thus no metrics recorded. @@ -135,7 +137,8 @@ TEST_F(MAYBE_MetricsCollectorTest, TEST_F(MAYBE_MetricsCollectorTest, FromBackgroundedToFirstTitleUpdatedUMA) { auto page_cu = CreateCoordinationUnit<PageCoordinationUnitImpl>(); - page_cu->OnMainFrameNavigationCommitted(kDummyID, kDummyUrl); + page_cu->OnMainFrameNavigationCommitted(ResourceCoordinatorClock::NowTicks(), + kDummyID, kDummyUrl); AdvanceClock(kTestMetricsReportDelayTimeout); page_cu->SetVisibility(true); @@ -167,7 +170,8 @@ TEST_F(MAYBE_MetricsCollectorTest, FromBackgroundedToFirstTitleUpdatedUMA5MinutesTimeout) { auto page_cu = CreateCoordinationUnit<PageCoordinationUnitImpl>(); - page_cu->OnMainFrameNavigationCommitted(kDummyID, kDummyUrl); + page_cu->OnMainFrameNavigationCommitted(ResourceCoordinatorClock::NowTicks(), + kDummyID, kDummyUrl); page_cu->SetVisibility(false); page_cu->OnTitleUpdated(); // The page is within 5 minutes after main frame navigation was committed, @@ -185,7 +189,8 @@ TEST_F(MAYBE_MetricsCollectorTest, FromBackgroundedToFirstAlertFiredUMA) { auto frame_cu = CreateCoordinationUnit<FrameCoordinationUnitImpl>(); page_cu->AddFrame(frame_cu->id()); - page_cu->OnMainFrameNavigationCommitted(kDummyID, kDummyUrl); + page_cu->OnMainFrameNavigationCommitted(ResourceCoordinatorClock::NowTicks(), + kDummyID, kDummyUrl); AdvanceClock(kTestMetricsReportDelayTimeout); page_cu->SetVisibility(true); @@ -219,7 +224,8 @@ TEST_F(MAYBE_MetricsCollectorTest, auto frame_cu = CreateCoordinationUnit<FrameCoordinationUnitImpl>(); page_cu->AddFrame(frame_cu->id()); - page_cu->OnMainFrameNavigationCommitted(kDummyID, kDummyUrl); + page_cu->OnMainFrameNavigationCommitted(ResourceCoordinatorClock::NowTicks(), + kDummyID, kDummyUrl); page_cu->SetVisibility(false); frame_cu->OnAlertFired(); // The page is within 5 minutes after main frame navigation was committed, @@ -238,7 +244,8 @@ TEST_F(MAYBE_MetricsCollectorTest, auto frame_cu = CreateCoordinationUnit<FrameCoordinationUnitImpl>(); page_cu->AddFrame(frame_cu->id()); - page_cu->OnMainFrameNavigationCommitted(kDummyID, kDummyUrl); + page_cu->OnMainFrameNavigationCommitted(ResourceCoordinatorClock::NowTicks(), + kDummyID, kDummyUrl); AdvanceClock(kTestMetricsReportDelayTimeout); page_cu->SetVisibility(true); @@ -273,7 +280,8 @@ TEST_F( auto frame_cu = CreateCoordinationUnit<FrameCoordinationUnitImpl>(); page_cu->AddFrame(frame_cu->id()); - page_cu->OnMainFrameNavigationCommitted(kDummyID, kDummyUrl); + page_cu->OnMainFrameNavigationCommitted(ResourceCoordinatorClock::NowTicks(), + kDummyID, kDummyUrl); page_cu->SetVisibility(false); frame_cu->OnNonPersistentNotificationCreated(); // The page is within 5 minutes after main frame navigation was committed, @@ -289,7 +297,8 @@ TEST_F( TEST_F(MAYBE_MetricsCollectorTest, FromBackgroundedToFirstFaviconUpdatedUMA) { auto page_cu = CreateCoordinationUnit<PageCoordinationUnitImpl>(); - page_cu->OnMainFrameNavigationCommitted(kDummyID, kDummyUrl); + page_cu->OnMainFrameNavigationCommitted(ResourceCoordinatorClock::NowTicks(), + kDummyID, kDummyUrl); AdvanceClock(kTestMetricsReportDelayTimeout); page_cu->SetVisibility(true); @@ -321,7 +330,8 @@ TEST_F(MAYBE_MetricsCollectorTest, FromBackgroundedToFirstFaviconUpdatedUMA5MinutesTimeout) { auto page_cu = CreateCoordinationUnit<PageCoordinationUnitImpl>(); - page_cu->OnMainFrameNavigationCommitted(kDummyID, kDummyUrl); + page_cu->OnMainFrameNavigationCommitted(ResourceCoordinatorClock::NowTicks(), + kDummyID, kDummyUrl); page_cu->SetVisibility(false); page_cu->OnFaviconUpdated(); // The page is within 5 minutes after main frame navigation was committed, @@ -350,7 +360,8 @@ TEST_F(MAYBE_MetricsCollectorTest, ResponsivenessMetric) { GURL url = GURL("https://google.com/foobar"); ukm_recorder.UpdateSourceURL(id, url); page_cu->SetUKMSourceId(id); - page_cu->OnMainFrameNavigationCommitted(kDummyID, kDummyUrl); + page_cu->OnMainFrameNavigationCommitted(ResourceCoordinatorClock::NowTicks(), + kDummyID, kDummyUrl); for (int count = 1; count < kDefaultFrequencyUkmEQTReported; ++count) { process_cu->SetExpectedTaskQueueingDuration( diff --git a/chromium/services/resource_coordinator/observers/page_signal_generator_impl.cc b/chromium/services/resource_coordinator/observers/page_signal_generator_impl.cc index ea9c4bdb717..e6d19a2bf9f 100644 --- a/chromium/services/resource_coordinator/observers/page_signal_generator_impl.cc +++ b/chromium/services/resource_coordinator/observers/page_signal_generator_impl.cc @@ -242,6 +242,7 @@ void PageSignalGeneratorImpl::OnSystemEventReceived( data->last_state_change < measurement_start) { DispatchPageSignal( page, &mojom::PageSignalReceiver::OnLoadTimePerformanceEstimate, + page->TimeSinceLastNavigation(), page->cumulative_cpu_usage_estimate(), page->private_footprint_kb_estimate()); data->performance_estimate_issued = true; diff --git a/chromium/services/resource_coordinator/observers/page_signal_generator_impl_unittest.cc b/chromium/services/resource_coordinator/observers/page_signal_generator_impl_unittest.cc index d11af0ca626..1ba999fe308 100644 --- a/chromium/services/resource_coordinator/observers/page_signal_generator_impl_unittest.cc +++ b/chromium/services/resource_coordinator/observers/page_signal_generator_impl_unittest.cc @@ -63,8 +63,9 @@ class MockPageSignalReceiverImpl : public mojom::PageSignalReceiver { void(const PageNavigationIdentity& page_navigation_id)); MOCK_METHOD1(NotifyRendererIsBloated, void(const PageNavigationIdentity& page_navigation_id)); - MOCK_METHOD3(OnLoadTimePerformanceEstimate, + MOCK_METHOD4(OnLoadTimePerformanceEstimate, void(const PageNavigationIdentity& page_navigation_id, + base::TimeDelta load_duration, base::TimeDelta cpu_usage_estimate, uint64_t private_footprint_kb_estimate)); @@ -308,7 +309,8 @@ void PageSignalGeneratorImplTest::TestPageAlmostIdleTransitions(bool timeout) { EXPECT_FALSE(page_data->idling_timer.IsRunning()); // Post a navigation. The state should reset. - page_cu->OnMainFrameNavigationCommitted(1, "https://www.example.org"); + page_cu->OnMainFrameNavigationCommitted(ResourceCoordinatorClock::NowTicks(), + 1, "https://www.example.org"); EXPECT_EQ(LIS::kLoadingNotStarted, page_data->GetLoadIdleState()); EXPECT_FALSE(page_data->idling_timer.IsRunning()); } @@ -426,7 +428,10 @@ TEST_F(PageSignalGeneratorImplTest, OnLoadTimePerformanceEstimate) { PageSignalGeneratorImpl::PageData* page_data = psg->GetPageData(page_cu); page_data->idling_timer.SetTaskRunner(task_env().GetMainThreadTaskRunner()); - page_cu->OnMainFrameNavigationCommitted(1, "https://www.google.com/"); + base::TimeTicks navigation_committed_time = + ResourceCoordinatorClock::NowTicks(); + page_cu->OnMainFrameNavigationCommitted(navigation_committed_time, 1, + "https://www.google.com/"); DrivePageToLoadedAndIdle(&cu_graph); base::TimeTicks event_time = ResourceCoordinatorClock::NowTicks(); @@ -449,6 +454,7 @@ TEST_F(PageSignalGeneratorImplTest, OnLoadTimePerformanceEstimate) { EXPECT_CALL(mock_receiver, OnLoadTimePerformanceEstimate( IdentityMatches(cu_graph.page->id(), 1u, "https://www.google.com/"), + event_time - navigation_committed_time, base::TimeDelta::FromMicroseconds(15), 150)) .WillOnce( ::testing::InvokeWithoutArgs(&run_loop, &base::RunLoop::Quit)); @@ -459,7 +465,9 @@ TEST_F(PageSignalGeneratorImplTest, OnLoadTimePerformanceEstimate) { ::testing::Mock::VerifyAndClear(&mock_receiver); // Make sure a second run around the state machine generates a second event. - page_cu->OnMainFrameNavigationCommitted(2, "https://example.org/bobcat"); + navigation_committed_time = ResourceCoordinatorClock::NowTicks(); + page_cu->OnMainFrameNavigationCommitted(navigation_committed_time, 2, + "https://example.org/bobcat"); task_env().FastForwardUntilNoTasksRemain(); EXPECT_NE(LIS::kLoadedAndIdle, page_data->GetLoadIdleState()); @@ -477,6 +485,7 @@ TEST_F(PageSignalGeneratorImplTest, OnLoadTimePerformanceEstimate) { OnLoadTimePerformanceEstimate( IdentityMatches(cu_graph.page->id(), 2u, "https://example.org/bobcat"), + event_time - navigation_committed_time, base::TimeDelta::FromMicroseconds(25), 250)) .WillOnce( ::testing::InvokeWithoutArgs(&run_loop, &base::RunLoop::Quit)); diff --git a/chromium/services/resource_coordinator/public/cpp/BUILD.gn b/chromium/services/resource_coordinator/public/cpp/BUILD.gn index 5535fec2e65..3fec0f1625a 100644 --- a/chromium/services/resource_coordinator/public/cpp/BUILD.gn +++ b/chromium/services/resource_coordinator/public/cpp/BUILD.gn @@ -9,7 +9,9 @@ component("resource_coordinator_cpp_base") { "coordination_unit_types.h", ] - defines = [ "SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_BASE_IMPLEMENTATION" ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + + defines = [ "IS_SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_BASE_IMPL" ] deps = [ "//base", @@ -17,56 +19,42 @@ component("resource_coordinator_cpp_base") { ] } +component("resource_coordinator_cpp_features") { + sources = [ + "resource_coordinator_features.cc", + "resource_coordinator_features.h", + ] + + defines = [ "IS_SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_FEATURES_IMPL" ] + + deps = [ + "//base", + ] +} + component("resource_coordinator_cpp") { sources = [ "frame_resource_coordinator.cc", "frame_resource_coordinator.h", - "memory_instrumentation/client_process_impl.cc", - "memory_instrumentation/client_process_impl.h", - "memory_instrumentation/coordinator.h", - "memory_instrumentation/global_memory_dump.cc", - "memory_instrumentation/global_memory_dump.h", - "memory_instrumentation/memory_instrumentation.cc", - "memory_instrumentation/memory_instrumentation.h", - "memory_instrumentation/os_metrics.cc", - "memory_instrumentation/os_metrics.h", - "memory_instrumentation/os_metrics_linux.cc", - "memory_instrumentation/os_metrics_mac.cc", - "memory_instrumentation/os_metrics_win.cc", - "memory_instrumentation/tracing_observer.cc", - "memory_instrumentation/tracing_observer.h", "page_resource_coordinator.cc", "page_resource_coordinator.h", "process_resource_coordinator.cc", "process_resource_coordinator.h", - "resource_coordinator_features.cc", - "resource_coordinator_features.h", "resource_coordinator_interface.h", "system_resource_coordinator.cc", "system_resource_coordinator.h", ] - if (is_android) { - set_sources_assignment_filter([]) - sources += [ "memory_instrumentation/os_metrics_linux.cc" ] - set_sources_assignment_filter(sources_assignment_filter) - } - - if (is_fuchsia) { - sources += [ "memory_instrumentation/os_metrics_fuchsia.cc" ] - } - - defines = [ "SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_IMPLEMENTATION" ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] - deps = [] - if (is_win) { - deps += [ "//base/win:pe_image" ] - } + defines = [ "IS_SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_IMPL" ] public_deps = [ ":resource_coordinator_cpp_base", + ":resource_coordinator_cpp_features", "//base", "//mojo/public/cpp/bindings", + "//services/resource_coordinator/public/cpp/memory_instrumentation", "//services/resource_coordinator/public/mojom", "//services/service_manager/public/cpp", ] diff --git a/chromium/services/resource_coordinator/public/cpp/coordination_unit_id.h b/chromium/services/resource_coordinator/public/cpp/coordination_unit_id.h index e36f3935bd0..a94a2a1531d 100644 --- a/chromium/services/resource_coordinator/public/cpp/coordination_unit_id.h +++ b/chromium/services/resource_coordinator/public/cpp/coordination_unit_id.h @@ -7,8 +7,8 @@ #include <string> #include <tuple> +#include "base/component_export.h" #include "services/resource_coordinator/public/cpp/coordination_unit_types.h" -#include "services/resource_coordinator/public/cpp/resource_coordinator_base_export.h" namespace resource_coordinator { @@ -17,7 +17,8 @@ namespace resource_coordinator { // would like to move it to base/ as easily as possible at that point. // TODO(oysteine): Rename to CoordinationUnitGUID to better differentiate the // class from the internal id -struct SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_BASE_EXPORT CoordinationUnitID { +struct COMPONENT_EXPORT(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_BASE) + CoordinationUnitID { typedef uint64_t CoordinationUnitTypeId; enum RandomID { RANDOM_ID }; diff --git a/chromium/services/resource_coordinator/public/cpp/frame_resource_coordinator.h b/chromium/services/resource_coordinator/public/cpp/frame_resource_coordinator.h index 0921fafff6c..c452abdbf80 100644 --- a/chromium/services/resource_coordinator/public/cpp/frame_resource_coordinator.h +++ b/chromium/services/resource_coordinator/public/cpp/frame_resource_coordinator.h @@ -12,7 +12,8 @@ namespace resource_coordinator { -class SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT FrameResourceCoordinator +class COMPONENT_EXPORT(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP) + FrameResourceCoordinator : public ResourceCoordinatorInterface<mojom::FrameCoordinationUnitPtr, mojom::FrameCoordinationUnitRequest> { public: diff --git a/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/BUILD.gn b/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/BUILD.gn new file mode 100644 index 00000000000..1a655864a1c --- /dev/null +++ b/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/BUILD.gn @@ -0,0 +1,47 @@ +# Copyright 2018 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +component("memory_instrumentation") { + sources = [ + "client_process_impl.cc", + "client_process_impl.h", + "coordinator.h", + "global_memory_dump.cc", + "global_memory_dump.h", + "memory_instrumentation.cc", + "memory_instrumentation.h", + "os_metrics.cc", + "os_metrics.h", + "os_metrics_linux.cc", + "os_metrics_mac.cc", + "os_metrics_win.cc", + "tracing_observer.cc", + "tracing_observer.h", + ] + + if (is_android) { + # Disable the rule that excludes _linux.cc files from Android builds. + set_sources_assignment_filter([]) + sources += [ "os_metrics_linux.cc" ] + set_sources_assignment_filter(sources_assignment_filter) + } + + if (is_fuchsia) { + sources += [ "os_metrics_fuchsia.cc" ] + } + + defines = [ "IS_RESOURCE_COORDINATOR_PUBLIC_MEMORY_INSTRUMENTATION_IMPL" ] + + deps = [] + if (is_win) { + deps += [ "//base/win:pe_image" ] + } + + public_deps = [ + "//base", + "//mojo/public/cpp/bindings", + "//services/resource_coordinator/public/mojom", + "//services/service_manager/public/cpp", + ] +} diff --git a/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/client_process_impl.h b/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/client_process_impl.h index 4a311711e70..d672811a45a 100644 --- a/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/client_process_impl.h +++ b/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/client_process_impl.h @@ -6,13 +6,13 @@ #define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_MEMORY_INSTRUMENTATION_CLIENT_PROCESS_IMPL_H_ #include "base/compiler_specific.h" +#include "base/component_export.h" #include "base/single_thread_task_runner.h" #include "base/synchronization/lock.h" #include "base/trace_event/memory_dump_manager.h" #include "base/trace_event/memory_dump_request_args.h" #include "mojo/public/cpp/bindings/binding.h" #include "services/resource_coordinator/public/cpp/memory_instrumentation/coordinator.h" -#include "services/resource_coordinator/public/cpp/resource_coordinator_export.h" #include "services/resource_coordinator/public/mojom/memory_instrumentation/memory_instrumentation.mojom.h" #include "services/service_manager/public/cpp/connector.h" @@ -28,10 +28,11 @@ class TracingObserver; // no Coordinator service in child processes. So, in a child process, the // local dump manager remotely connects to the Coordinator service. In the // browser process, it locally connects to the Coordinator service. -class SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT ClientProcessImpl - : public mojom::ClientProcess { +class COMPONENT_EXPORT(RESOURCE_COORDINATOR_PUBLIC_MEMORY_INSTRUMENTATION) + ClientProcessImpl : public mojom::ClientProcess { public: - struct SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT Config { + struct COMPONENT_EXPORT( + RESOURCE_COORDINATOR_PUBLIC_MEMORY_INSTRUMENTATION) Config { public: Config(service_manager::Connector* connector, const std::string& service_name, diff --git a/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/global_memory_dump.h b/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/global_memory_dump.h index ad996fd17a5..6f5cfdc6ab1 100644 --- a/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/global_memory_dump.h +++ b/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/global_memory_dump.h @@ -5,17 +5,19 @@ #ifndef SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_MEMORY_INSTRUMENTATION_GLOBAL_MEMORY_DUMP_H_ #define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_MEMORY_INSTRUMENTATION_GLOBAL_MEMORY_DUMP_H_ +#include "base/component_export.h" #include "base/optional.h" -#include "services/resource_coordinator/public/cpp/resource_coordinator_export.h" #include "services/resource_coordinator/public/mojom/memory_instrumentation/memory_instrumentation.mojom.h" namespace memory_instrumentation { // The returned data structure to consumers of the memory_instrumentation // service containing dumps for each process. -class SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT GlobalMemoryDump { +class COMPONENT_EXPORT(RESOURCE_COORDINATOR_PUBLIC_MEMORY_INSTRUMENTATION) + GlobalMemoryDump { public: - class SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT ProcessDump { + class COMPONENT_EXPORT(RESOURCE_COORDINATOR_PUBLIC_MEMORY_INSTRUMENTATION) + ProcessDump { public: ProcessDump(mojom::ProcessMemoryDumpPtr process_memory_dump); ~ProcessDump(); diff --git a/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/memory_instrumentation.cc b/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/memory_instrumentation.cc index c5f114edde9..8ba1e3b2f7e 100644 --- a/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/memory_instrumentation.cc +++ b/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/memory_instrumentation.cc @@ -20,7 +20,7 @@ void WrapGlobalMemoryDump( MemoryInstrumentation::RequestGlobalDumpCallback callback, bool success, mojom::GlobalMemoryDumpPtr dump) { - callback.Run(success, GlobalMemoryDump::MoveFrom(std::move(dump))); + std::move(callback).Run(success, GlobalMemoryDump::MoveFrom(std::move(dump))); } } // namespace @@ -58,7 +58,7 @@ void MemoryInstrumentation::RequestGlobalDump( coordinator->RequestGlobalMemoryDump( MemoryDumpType::SUMMARY_ONLY, MemoryDumpLevelOfDetail::BACKGROUND, allocator_dump_names, - base::BindRepeating(&WrapGlobalMemoryDump, callback)); + base::BindOnce(&WrapGlobalMemoryDump, std::move(callback))); } void MemoryInstrumentation::RequestPrivateMemoryFootprint( @@ -66,7 +66,7 @@ void MemoryInstrumentation::RequestPrivateMemoryFootprint( RequestGlobalDumpCallback callback) { const auto& coordinator = GetCoordinatorBindingForCurrentThread(); coordinator->RequestPrivateMemoryFootprint( - pid, base::BindRepeating(&WrapGlobalMemoryDump, callback)); + pid, base::BindOnce(&WrapGlobalMemoryDump, std::move(callback))); } void MemoryInstrumentation::RequestGlobalDumpForPid( @@ -76,7 +76,7 @@ void MemoryInstrumentation::RequestGlobalDumpForPid( const auto& coordinator = GetCoordinatorBindingForCurrentThread(); coordinator->RequestGlobalMemoryDumpForPid( pid, allocator_dump_names, - base::BindRepeating(&WrapGlobalMemoryDump, callback)); + base::BindOnce(&WrapGlobalMemoryDump, std::move(callback))); } void MemoryInstrumentation::RequestGlobalDumpAndAppendToTrace( @@ -85,7 +85,7 @@ void MemoryInstrumentation::RequestGlobalDumpAndAppendToTrace( RequestGlobalMemoryDumpAndAppendToTraceCallback callback) { const auto& coordinator = GetCoordinatorBindingForCurrentThread(); coordinator->RequestGlobalMemoryDumpAndAppendToTrace( - dump_type, level_of_detail, callback); + dump_type, level_of_detail, std::move(callback)); } const mojom::CoordinatorPtr& diff --git a/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/memory_instrumentation.h b/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/memory_instrumentation.h index 5ba364bd1db..d449685cb55 100644 --- a/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/memory_instrumentation.h +++ b/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/memory_instrumentation.h @@ -6,12 +6,12 @@ #define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_MEMORY_INSTRUMENTATION_MEMORY_INSTRUMENTATION_H_ #include "base/callback_forward.h" +#include "base/component_export.h" #include "base/memory/ref_counted.h" #include "base/threading/thread_local_storage.h" #include "base/trace_event/memory_dump_request_args.h" #include "services/resource_coordinator/public/cpp/memory_instrumentation/coordinator.h" #include "services/resource_coordinator/public/cpp/memory_instrumentation/global_memory_dump.h" -#include "services/resource_coordinator/public/cpp/resource_coordinator_export.h" #include "services/resource_coordinator/public/mojom/memory_instrumentation/memory_instrumentation.mojom.h" #include "services/service_manager/public/cpp/connector.h" @@ -26,15 +26,16 @@ namespace memory_instrumentation { // memory_instrumentation service and hides away the complexity associated with // having to deal with it (e.g., maintaining service connections, bindings, // handling timeouts). -class SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT MemoryInstrumentation { +class COMPONENT_EXPORT(RESOURCE_COORDINATOR_PUBLIC_MEMORY_INSTRUMENTATION) + MemoryInstrumentation { public: using MemoryDumpType = base::trace_event::MemoryDumpType; using MemoryDumpLevelOfDetail = base::trace_event::MemoryDumpLevelOfDetail; using RequestGlobalDumpCallback = - base::RepeatingCallback<void(bool success, - std::unique_ptr<GlobalMemoryDump> dump)>; + base::OnceCallback<void(bool success, + std::unique_ptr<GlobalMemoryDump> dump)>; using RequestGlobalMemoryDumpAndAppendToTraceCallback = - base::RepeatingCallback<void(bool success, uint64_t dump_id)>; + base::OnceCallback<void(bool success, uint64_t dump_id)>; static void CreateInstance(service_manager::Connector*, const std::string& service_name); diff --git a/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/os_metrics.h b/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/os_metrics.h index 7eaec00f28a..53496d5e986 100644 --- a/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/os_metrics.h +++ b/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/os_metrics.h @@ -4,11 +4,11 @@ #ifndef SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_MEMORY_INSTRUMENTATION_OS_METRICS_H_ #define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_MEMORY_INSTRUMENTATION_OS_METRICS_H_ +#include "base/component_export.h" #include "base/gtest_prod_util.h" #include "base/process/process_handle.h" #include "base/trace_event/process_memory_dump.h" #include "build/build_config.h" -#include "services/resource_coordinator/public/cpp/resource_coordinator_export.h" #include "services/resource_coordinator/public/mojom/memory_instrumentation/memory_instrumentation.mojom.h" namespace heap_profiling { @@ -17,7 +17,8 @@ FORWARD_DECLARE_TEST(ProfilingJsonExporterTest, MemoryMaps); namespace memory_instrumentation { -class SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT OSMetrics { +class COMPONENT_EXPORT( + RESOURCE_COORDINATOR_PUBLIC_MEMORY_INSTRUMENTATION) OSMetrics { public: static bool FillOSMemoryDump(base::ProcessId pid, mojom::RawOSMemDump* dump); static bool FillProcessMemoryMaps(base::ProcessId, diff --git a/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/tracing_observer.h b/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/tracing_observer.h index 6950341837b..dca47ff417e 100644 --- a/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/tracing_observer.h +++ b/chromium/services/resource_coordinator/public/cpp/memory_instrumentation/tracing_observer.h @@ -5,10 +5,10 @@ #ifndef SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_MEMORY_INSTRUMENTATION_TRACING_OBSERVER_H #define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_MEMORY_INSTRUMENTATION_TRACING_OBSERVER_H +#include "base/component_export.h" #include "base/macros.h" #include "base/trace_event/memory_dump_manager.h" #include "base/trace_event/trace_event.h" -#include "services/resource_coordinator/public/cpp/resource_coordinator_export.h" #include "services/resource_coordinator/public/mojom/memory_instrumentation/memory_instrumentation.mojom.h" namespace memory_instrumentation { @@ -16,8 +16,8 @@ namespace memory_instrumentation { // Observes TraceLog for Enable/Disable events and when they occur Enables and // Disables the MemoryDumpManager with the correct state based on reading the // trace log. Also provides a method for adding a dump to the trace. -class SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT TracingObserver - : public base::trace_event::TraceLog::EnabledStateObserver { +class COMPONENT_EXPORT(RESOURCE_COORDINATOR_PUBLIC_MEMORY_INSTRUMENTATION) + TracingObserver : public base::trace_event::TraceLog::EnabledStateObserver { public: TracingObserver(base::trace_event::TraceLog*, base::trace_event::MemoryDumpManager*); diff --git a/chromium/services/resource_coordinator/public/cpp/page_resource_coordinator.cc b/chromium/services/resource_coordinator/public/cpp/page_resource_coordinator.cc index 911be476ede..ddf98d43329 100644 --- a/chromium/services/resource_coordinator/public/cpp/page_resource_coordinator.cc +++ b/chromium/services/resource_coordinator/public/cpp/page_resource_coordinator.cc @@ -47,11 +47,13 @@ void PageResourceCoordinator::OnTitleUpdated() { } void PageResourceCoordinator::OnMainFrameNavigationCommitted( + base::TimeTicks navigation_committed_time, uint64_t navigation_id, const std::string& url) { if (!service_) return; - service_->OnMainFrameNavigationCommitted(navigation_id, url); + service_->OnMainFrameNavigationCommitted(navigation_committed_time, + navigation_id, url); } void PageResourceCoordinator::AddFrame(const FrameResourceCoordinator& frame) { diff --git a/chromium/services/resource_coordinator/public/cpp/page_resource_coordinator.h b/chromium/services/resource_coordinator/public/cpp/page_resource_coordinator.h index fad88314f21..d4bc02942ed 100644 --- a/chromium/services/resource_coordinator/public/cpp/page_resource_coordinator.h +++ b/chromium/services/resource_coordinator/public/cpp/page_resource_coordinator.h @@ -7,13 +7,15 @@ #include "base/memory/weak_ptr.h" #include "base/threading/thread_checker.h" +#include "base/time/time.h" #include "services/resource_coordinator/public/cpp/frame_resource_coordinator.h" #include "services/resource_coordinator/public/cpp/resource_coordinator_interface.h" #include "services/resource_coordinator/public/mojom/coordination_unit.mojom.h" namespace resource_coordinator { -class SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT PageResourceCoordinator +class COMPONENT_EXPORT(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP) + PageResourceCoordinator : public ResourceCoordinatorInterface<mojom::PageCoordinationUnitPtr, mojom::PageCoordinationUnitRequest> { public: @@ -25,7 +27,8 @@ class SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT PageResourceCoordinator void SetUKMSourceId(int64_t ukm_source_id); void OnFaviconUpdated(); void OnTitleUpdated(); - void OnMainFrameNavigationCommitted(uint64_t navigation_id, + void OnMainFrameNavigationCommitted(base::TimeTicks navigation_committed_time, + uint64_t navigation_id, const std::string& url); void AddFrame(const FrameResourceCoordinator& frame); diff --git a/chromium/services/resource_coordinator/public/cpp/process_resource_coordinator.cc b/chromium/services/resource_coordinator/public/cpp/process_resource_coordinator.cc index 139a2c763d0..725568174cb 100644 --- a/chromium/services/resource_coordinator/public/cpp/process_resource_coordinator.cc +++ b/chromium/services/resource_coordinator/public/cpp/process_resource_coordinator.cc @@ -46,16 +46,6 @@ void ProcessResourceCoordinator::AddFrame( weak_ptr_factory_.GetWeakPtr())); } -void ProcessResourceCoordinator::RemoveFrame( - const FrameResourceCoordinator& frame) { - DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); - if (!service_) - return; - frame.service()->GetID( - base::BindOnce(&ProcessResourceCoordinator::RemoveFrameByID, - weak_ptr_factory_.GetWeakPtr())); -} - void ProcessResourceCoordinator::ConnectToService( mojom::CoordinationUnitProviderPtr& provider, const CoordinationUnitID& cu_id) { @@ -67,10 +57,4 @@ void ProcessResourceCoordinator::AddFrameByID(const CoordinationUnitID& cu_id) { service_->AddFrame(cu_id); } -void ProcessResourceCoordinator::RemoveFrameByID( - const CoordinationUnitID& cu_id) { - DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); - service_->RemoveFrame(cu_id); -} - } // namespace resource_coordinator diff --git a/chromium/services/resource_coordinator/public/cpp/process_resource_coordinator.h b/chromium/services/resource_coordinator/public/cpp/process_resource_coordinator.h index 8bd32642109..383176339a4 100644 --- a/chromium/services/resource_coordinator/public/cpp/process_resource_coordinator.h +++ b/chromium/services/resource_coordinator/public/cpp/process_resource_coordinator.h @@ -14,10 +14,10 @@ namespace resource_coordinator { -class SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT ProcessResourceCoordinator - : public ResourceCoordinatorInterface< - mojom::ProcessCoordinationUnitPtr, - mojom::ProcessCoordinationUnitRequest> { +class COMPONENT_EXPORT(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP) + ProcessResourceCoordinator : public ResourceCoordinatorInterface< + mojom::ProcessCoordinationUnitPtr, + mojom::ProcessCoordinationUnitRequest> { public: ProcessResourceCoordinator(service_manager::Connector* connector); ~ProcessResourceCoordinator() override; @@ -27,14 +27,12 @@ class SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT ProcessResourceCoordinator void SetPID(base::ProcessId pid); void AddFrame(const FrameResourceCoordinator& frame); - void RemoveFrame(const FrameResourceCoordinator& frame); private: void ConnectToService(mojom::CoordinationUnitProviderPtr& provider, const CoordinationUnitID& cu_id) override; void AddFrameByID(const CoordinationUnitID& cu_id); - void RemoveFrameByID(const CoordinationUnitID& cu_id); THREAD_CHECKER(thread_checker_); diff --git a/chromium/services/resource_coordinator/public/cpp/resource_coordinator_base_export.h b/chromium/services/resource_coordinator/public/cpp/resource_coordinator_base_export.h deleted file mode 100644 index 1dd7f9052e2..00000000000 --- a/chromium/services/resource_coordinator/public/cpp/resource_coordinator_base_export.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_RESOURCE_COORDINATOR_BASE_EXPORT_H_ -#define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_RESOURCE_COORDINATOR_BASE_EXPORT_H_ - -#if defined(COMPONENT_BUILD) -#if defined(WIN32) - -#if defined(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_BASE_IMPLEMENTATION) -#define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_BASE_EXPORT \ - __declspec(dllexport) -#else -#define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_BASE_EXPORT \ - __declspec(dllimport) -#endif // defined(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_BASE_IMPLEMENTATION) - -#else // defined(WIN32) -#if defined(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_BASE_IMPLEMENTATION) -#define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_BASE_EXPORT \ - __attribute__((visibility("default"))) -#else -#define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_BASE_EXPORT -#endif -#endif - -#else // defined(COMPONENT_BUILD) -#define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_BASE_EXPORT -#endif - -#endif // SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_RESOURCE_COORDINATOR_BASE_EXPORT_H_ diff --git a/chromium/services/resource_coordinator/public/cpp/resource_coordinator_export.h b/chromium/services/resource_coordinator/public/cpp/resource_coordinator_export.h deleted file mode 100644 index c271f9788f8..00000000000 --- a/chromium/services/resource_coordinator/public/cpp/resource_coordinator_export.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2017 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_RESOURCE_COORDINATOR_EXPORT_H_ -#define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_RESOURCE_COORDINATOR_EXPORT_H_ - -#if defined(COMPONENT_BUILD) -#if defined(WIN32) - -#if defined(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_IMPLEMENTATION) -#define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT __declspec(dllexport) -#else -#define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT __declspec(dllimport) -#endif // defined(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_IMPLEMENTATION) - -#else // defined(WIN32) -#if defined(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_IMPLEMENTATION) -#define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT \ - __attribute__((visibility("default"))) -#else -#define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT -#endif -#endif - -#else // defined(COMPONENT_BUILD) -#define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT -#endif - -#endif // SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_RESOURCE_COORDINATOR_EXPORT_H_ diff --git a/chromium/services/resource_coordinator/public/cpp/resource_coordinator_features.cc b/chromium/services/resource_coordinator/public/cpp/resource_coordinator_features.cc index 9b0ffa0317c..ab457541797 100644 --- a/chromium/services/resource_coordinator/public/cpp/resource_coordinator_features.cc +++ b/chromium/services/resource_coordinator/public/cpp/resource_coordinator_features.cc @@ -18,10 +18,6 @@ constexpr char kMainThreadTaskLoadLowThresholdParameterName[] = namespace features { -// Globally enable the GRC. -const base::Feature kGlobalResourceCoordinator{ - "GlobalResourceCoordinator", base::FEATURE_ENABLED_BY_DEFAULT}; - const base::Feature kPageAlmostIdle{"PageAlmostIdle", base::FEATURE_DISABLED_BY_DEFAULT}; @@ -33,10 +29,6 @@ const base::Feature kPerformanceMeasurement{"PerformanceMeasurement", namespace resource_coordinator { -bool IsResourceCoordinatorEnabled() { - return base::FeatureList::IsEnabled(features::kGlobalResourceCoordinator); -} - bool IsPageAlmostIdleSignalEnabled() { return base::FeatureList::IsEnabled(features::kPageAlmostIdle); } diff --git a/chromium/services/resource_coordinator/public/cpp/resource_coordinator_features.h b/chromium/services/resource_coordinator/public/cpp/resource_coordinator_features.h index 14f1fba6e5a..297546408d1 100644 --- a/chromium/services/resource_coordinator/public/cpp/resource_coordinator_features.h +++ b/chromium/services/resource_coordinator/public/cpp/resource_coordinator_features.h @@ -8,32 +8,29 @@ #ifndef SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_RESOURCE_COORDINATOR_FEATURES_H_ #define SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_RESOURCE_COORDINATOR_FEATURES_H_ +#include "base/component_export.h" #include "base/feature_list.h" -#include "services/resource_coordinator/public/cpp/resource_coordinator_export.h" namespace features { // The features should be documented alongside the definition of their values // in the .cc file. -extern const SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT base::Feature - kGlobalResourceCoordinator; -extern const SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT base::Feature - kPageAlmostIdle; -extern const SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT base::Feature - kPerformanceMeasurement; +extern const COMPONENT_EXPORT(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_FEATURES) + base::Feature kGlobalResourceCoordinator; +extern const COMPONENT_EXPORT(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_FEATURES) + base::Feature kPageAlmostIdle; +extern const COMPONENT_EXPORT(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_FEATURES) + base::Feature kPerformanceMeasurement; } // namespace features namespace resource_coordinator { -bool SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT -IsResourceCoordinatorEnabled(); +bool COMPONENT_EXPORT(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_FEATURES) + IsPageAlmostIdleSignalEnabled(); -bool SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT -IsPageAlmostIdleSignalEnabled(); - -int SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT -GetMainThreadTaskLoadLowThreshold(); +int COMPONENT_EXPORT(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_FEATURES) + GetMainThreadTaskLoadLowThreshold(); } // namespace resource_coordinator diff --git a/chromium/services/resource_coordinator/public/cpp/resource_coordinator_interface.h b/chromium/services/resource_coordinator/public/cpp/resource_coordinator_interface.h index 3c7a7ec0e4f..c949abee4f1 100644 --- a/chromium/services/resource_coordinator/public/cpp/resource_coordinator_interface.h +++ b/chromium/services/resource_coordinator/public/cpp/resource_coordinator_interface.h @@ -7,9 +7,9 @@ #include <stdint.h> +#include "base/component_export.h" #include "base/macros.h" #include "services/resource_coordinator/public/cpp/coordination_unit_id.h" -#include "services/resource_coordinator/public/cpp/resource_coordinator_export.h" #include "services/resource_coordinator/public/mojom/coordination_unit_provider.mojom.h" #include "services/resource_coordinator/public/mojom/service_constants.mojom.h" #include "services/service_manager/public/cpp/connector.h" diff --git a/chromium/services/resource_coordinator/public/cpp/system_resource_coordinator.h b/chromium/services/resource_coordinator/public/cpp/system_resource_coordinator.h index 3f205ffed8e..2ce3258f0dd 100644 --- a/chromium/services/resource_coordinator/public/cpp/system_resource_coordinator.h +++ b/chromium/services/resource_coordinator/public/cpp/system_resource_coordinator.h @@ -10,10 +10,10 @@ namespace resource_coordinator { -class SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP_EXPORT SystemResourceCoordinator - : public ResourceCoordinatorInterface< - mojom::SystemCoordinationUnitPtr, - mojom::SystemCoordinationUnitRequest> { +class COMPONENT_EXPORT(SERVICES_RESOURCE_COORDINATOR_PUBLIC_CPP) + SystemResourceCoordinator : public ResourceCoordinatorInterface< + mojom::SystemCoordinationUnitPtr, + mojom::SystemCoordinationUnitRequest> { public: SystemResourceCoordinator(service_manager::Connector* connector); ~SystemResourceCoordinator() override; diff --git a/chromium/services/resource_coordinator/public/mojom/BUILD.gn b/chromium/services/resource_coordinator/public/mojom/BUILD.gn index 929c52ff6ba..8211d8d8e2f 100644 --- a/chromium/services/resource_coordinator/public/mojom/BUILD.gn +++ b/chromium/services/resource_coordinator/public/mojom/BUILD.gn @@ -18,6 +18,7 @@ mojom_component("mojom") { "page_signal.mojom", "service_constants.mojom", "signals.mojom", + "webui_graph_dump.mojom", ] public_deps = [ diff --git a/chromium/services/resource_coordinator/public/mojom/coordination_unit.mojom b/chromium/services/resource_coordinator/public/mojom/coordination_unit.mojom index a07ee41bb47..e43d364155a 100644 --- a/chromium/services/resource_coordinator/public/mojom/coordination_unit.mojom +++ b/chromium/services/resource_coordinator/public/mojom/coordination_unit.mojom @@ -63,6 +63,7 @@ interface FrameCoordinationUnit { SetAudibility(bool audible); SetNetworkAlmostIdle(bool idle); SetLifecycleState(LifecycleState state); + SetHasNonEmptyBeforeUnload(bool has_nonempty_beforeunload); // Event signals. OnAlertFired(); @@ -89,9 +90,12 @@ interface PageCoordinationUnit { OnFaviconUpdated(); OnTitleUpdated(); + // |navigation_committed_time| is the time when the commit occurred. // |navigation_id| is the unique ID of the navigation handle that was // committed. - OnMainFrameNavigationCommitted(int64 navigation_id, string url); + OnMainFrameNavigationCommitted( + mojo_base.mojom.TimeTicks navigation_committed_time, + int64 navigation_id, string url); }; interface ProcessCoordinationUnit { @@ -103,7 +107,6 @@ interface ProcessCoordinationUnit { // Add a new binding to an existing ProcessCoordinationUnit. AddBinding(ProcessCoordinationUnit& request); AddFrame(CoordinationUnitID cu_id); - RemoveFrame(CoordinationUnitID cu_id); // Property signals. SetCPUUsage(double cpu_usage); diff --git a/chromium/services/resource_coordinator/public/mojom/page_signal.mojom b/chromium/services/resource_coordinator/public/mojom/page_signal.mojom index 584dfc126c5..ecb2c2c29b7 100644 --- a/chromium/services/resource_coordinator/public/mojom/page_signal.mojom +++ b/chromium/services/resource_coordinator/public/mojom/page_signal.mojom @@ -53,7 +53,17 @@ interface PageSignalReceiver { // This notification needs the url that was loaded, as by the time the // notification comes back around, the WebContents may have navigated to // another site altogether. + // |load_duration| is the wall-clock duration from navigation commit, until + // the page is considered loaded (currently almost idle). + // |cpu_usage_estimate| is an estimate of how much CPU time was consumed by + // this page load across all the the processes involved. This is approximate + // primarily because it's impossible to accurately approportion the cost + // of shared processes to individual pages. Other considerations involve + // the timing of measurement, as well as the fact that there's no accounting + // for processes that contributed to a page in the past, but no longer do + // so at the time of measurement. OnLoadTimePerformanceEstimate(PageNavigationIdentity page_navigation_id, + mojo_base.mojom.TimeDelta load_duration, mojo_base.mojom.TimeDelta cpu_usage_estimate, uint64 private_footprint_kb_estimate); }; diff --git a/chromium/services/resource_coordinator/public/mojom/signals.mojom b/chromium/services/resource_coordinator/public/mojom/signals.mojom index f7a1baa8234..13f7196a30f 100644 --- a/chromium/services/resource_coordinator/public/mojom/signals.mojom +++ b/chromium/services/resource_coordinator/public/mojom/signals.mojom @@ -28,7 +28,6 @@ enum PropertyType { kAudible, kCPUUsage, kExpectedTaskQueueingDuration, - kLaunchTime, kMainThreadTaskLoadIsLow, // Network is considered almost idle when there's no more than 2 network // connections. diff --git a/chromium/services/resource_coordinator/public/mojom/webui_graph_dump.mojom b/chromium/services/resource_coordinator/public/mojom/webui_graph_dump.mojom new file mode 100644 index 00000000000..0a320743c8b --- /dev/null +++ b/chromium/services/resource_coordinator/public/mojom/webui_graph_dump.mojom @@ -0,0 +1,53 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file exposes an interface to the chrome://discards Web UI to allow +// viewing and exploring the CU graph. + +module resource_coordinator.mojom; + +import "mojo/public/mojom/base/process_id.mojom"; +import "mojo/public/mojom/base/time.mojom"; + +// Represents the momentary state of a Page CU. +struct WebUIPageInfo { + int64 id; + + int64 main_frame_id; + + string main_frame_url; + + // TODO(siggi): Estimate data. +}; + +// Represents the momentary state of a Frame CU. +struct WebUIFrameInfo { + int64 id; + + int64 parent_frame_id; + int64 process_id; +}; + +// Represents the momentary state of a Process CU. +struct WebUIProcessInfo { + int64 id; + + mojo_base.mojom.ProcessId pid; + mojo_base.mojom.TimeDelta cumulative_cpu_usage; + uint64 private_footprint_kb; +}; + +// Represents the momentary state of an entire RC graph. +struct WebUIGraph { + array<WebUIPageInfo> pages; + array<WebUIFrameInfo> frames; + array<WebUIProcessInfo> processes; +}; + +// This interface allows grabbing the momentary state of the RC graph for +// visualization or inspection. This is exposed on the RC service, and used +// from the chrome://discards WebUI graph view tab. +interface WebUIGraphDump { + GetCurrentGraph() => (WebUIGraph graph); +}; diff --git a/chromium/services/resource_coordinator/resource_coordinator_service.cc b/chromium/services/resource_coordinator/resource_coordinator_service.cc index a9557ae252c..57fc53adeee 100644 --- a/chromium/services/resource_coordinator/resource_coordinator_service.cc +++ b/chromium/services/resource_coordinator/resource_coordinator_service.cc @@ -69,6 +69,8 @@ void ResourceCoordinatorService::OnStart() { registry_.AddInterface(base::BindRepeating( &memory_instrumentation::CoordinatorImpl::BindHeapProfilerHelperRequest, base::Unretained(memory_instrumentation_coordinator_.get()))); + registry_.AddInterface(base::BindRepeating( + &ResourceCoordinatorService::BindWebUIGraphDump, base::Unretained(this))); } void ResourceCoordinatorService::OnBindInterface( @@ -79,4 +81,31 @@ void ResourceCoordinatorService::OnBindInterface( source_info); } +void ResourceCoordinatorService::BindWebUIGraphDump( + mojom::WebUIGraphDumpRequest request, + const service_manager::BindSourceInfo& source_info) { + std::unique_ptr<WebUIGraphDumpImpl> graph_dump = + std::make_unique<WebUIGraphDumpImpl>(&coordination_unit_graph_); + + auto error_callback = + base::BindOnce(&ResourceCoordinatorService::OnGraphDumpConnectionError, + base::Unretained(this), graph_dump.get()); + graph_dump->Bind(std::move(request), std::move(error_callback)); + + graph_dumps_.push_back(std::move(graph_dump)); +} + +void ResourceCoordinatorService::OnGraphDumpConnectionError( + WebUIGraphDumpImpl* graph_dump) { + const auto it = std::find_if( + graph_dumps_.begin(), graph_dumps_.end(), + [graph_dump](const std::unique_ptr<WebUIGraphDumpImpl>& graph_dump_ptr) { + return graph_dump_ptr.get() == graph_dump; + }); + + DCHECK(it != graph_dumps_.end()); + + graph_dumps_.erase(it); +} + } // namespace resource_coordinator diff --git a/chromium/services/resource_coordinator/resource_coordinator_service.h b/chromium/services/resource_coordinator/resource_coordinator_service.h index 2e207825e16..e02532f4594 100644 --- a/chromium/services/resource_coordinator/resource_coordinator_service.h +++ b/chromium/services/resource_coordinator/resource_coordinator_service.h @@ -7,6 +7,7 @@ #include <memory> #include <string> +#include <vector> #include "base/callback.h" #include "base/macros.h" @@ -15,6 +16,7 @@ #include "services/resource_coordinator/coordination_unit/coordination_unit_graph.h" #include "services/resource_coordinator/coordination_unit/coordination_unit_introspector_impl.h" #include "services/resource_coordinator/memory_instrumentation/coordinator_impl.h" +#include "services/resource_coordinator/webui_graph_dump_impl.h" #include "services/service_manager/public/cpp/binder_registry.h" #include "services/service_manager/public/cpp/service.h" #include "services/service_manager/public/cpp/service_context_ref.h" @@ -45,6 +47,10 @@ class ResourceCoordinatorService : public service_manager::Service { } private: + void BindWebUIGraphDump(mojom::WebUIGraphDumpRequest request, + const service_manager::BindSourceInfo& source_info); + void OnGraphDumpConnectionError(WebUIGraphDumpImpl* graph_dump); + service_manager::BinderRegistryWithArgs< const service_manager::BindSourceInfo&> registry_; @@ -55,6 +61,9 @@ class ResourceCoordinatorService : public service_manager::Service { memory_instrumentation_coordinator_; std::unique_ptr<service_manager::ServiceContextRefFactory> ref_factory_; + // Current graph dump instances. + std::vector<std::unique_ptr<WebUIGraphDumpImpl>> graph_dumps_; + // WeakPtrFactory members should always come last so WeakPtrs are destructed // before other members. base::WeakPtrFactory<ResourceCoordinatorService> weak_factory_; diff --git a/chromium/services/resource_coordinator/webui_graph_dump_impl.cc b/chromium/services/resource_coordinator/webui_graph_dump_impl.cc new file mode 100644 index 00000000000..42ef9c9d560 --- /dev/null +++ b/chromium/services/resource_coordinator/webui_graph_dump_impl.cc @@ -0,0 +1,82 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/resource_coordinator/webui_graph_dump_impl.h" + +#include "base/macros.h" +#include "services/resource_coordinator/coordination_unit/coordination_unit_graph.h" +#include "services/resource_coordinator/coordination_unit/frame_coordination_unit_impl.h" +#include "services/resource_coordinator/coordination_unit/page_coordination_unit_impl.h" +#include "services/resource_coordinator/coordination_unit/process_coordination_unit_impl.h" + +namespace resource_coordinator { + +WebUIGraphDumpImpl::WebUIGraphDumpImpl(CoordinationUnitGraph* graph) + : graph_(graph), binding_(this) { + DCHECK(graph); +} + +WebUIGraphDumpImpl::~WebUIGraphDumpImpl() {} + +void WebUIGraphDumpImpl::Bind(mojom::WebUIGraphDumpRequest request, + base::OnceClosure error_handler) { + binding_.Bind(std::move(request)); + binding_.set_connection_error_handler(std::move(error_handler)); +} + +void WebUIGraphDumpImpl::GetCurrentGraph(GetCurrentGraphCallback callback) { + mojom::WebUIGraphPtr graph = mojom::WebUIGraph::New(); + + { + auto processes = graph_->GetAllProcessCoordinationUnits(); + graph->processes.reserve(processes.size()); + for (auto* process : processes) { + mojom::WebUIProcessInfoPtr process_info = mojom::WebUIProcessInfo::New(); + + process_info->id = process->id().id; + process_info->pid = process->process_id(); + process_info->cumulative_cpu_usage = process->cumulative_cpu_usage(); + process_info->private_footprint_kb = process->private_footprint_kb(); + + graph->processes.push_back(std::move(process_info)); + } + } + + { + auto frames = graph_->GetAllFrameCoordinationUnits(); + graph->frames.reserve(frames.size()); + for (auto* frame : frames) { + mojom::WebUIFrameInfoPtr frame_info = mojom::WebUIFrameInfo::New(); + + frame_info->id = frame->id().id; + + auto* parent_frame = frame->GetParentFrameCoordinationUnit(); + frame_info->parent_frame_id = parent_frame ? parent_frame->id().id : 0; + + auto* process = frame->GetProcessCoordinationUnit(); + frame_info->process_id = process ? process->id().id : 0; + + graph->frames.push_back(std::move(frame_info)); + } + } + + { + auto pages = graph_->GetAllPageCoordinationUnits(); + graph->pages.reserve(pages.size()); + for (auto* page : pages) { + mojom::WebUIPageInfoPtr page_info = mojom::WebUIPageInfo::New(); + + page_info->id = page->id().id; + page_info->main_frame_url = page->main_frame_url(); + + auto* main_frame = page->GetMainFrameCoordinationUnit(); + page_info->main_frame_id = main_frame ? main_frame->id().id : 0; + + graph->pages.push_back(std::move(page_info)); + } + } + std::move(callback).Run(std::move(graph)); +} + +} // namespace resource_coordinator diff --git a/chromium/services/resource_coordinator/webui_graph_dump_impl.h b/chromium/services/resource_coordinator/webui_graph_dump_impl.h new file mode 100644 index 00000000000..9b75d0f8905 --- /dev/null +++ b/chromium/services/resource_coordinator/webui_graph_dump_impl.h @@ -0,0 +1,36 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_RESOURCE_COORDINATOR_WEBUI_GRAPH_DUMP_IMPL_H_ +#define SERVICES_RESOURCE_COORDINATOR_WEBUI_GRAPH_DUMP_IMPL_H_ + +#include "mojo/public/cpp/bindings/binding.h" +#include "services/resource_coordinator/public/mojom/webui_graph_dump.mojom.h" + +namespace resource_coordinator { + +class CoordinationUnitGraph; + +class WebUIGraphDumpImpl : public mojom::WebUIGraphDump { + public: + explicit WebUIGraphDumpImpl(CoordinationUnitGraph* graph); + ~WebUIGraphDumpImpl() override; + + // WebUIGraphDump implementation. + void GetCurrentGraph(GetCurrentGraphCallback callback) override; + + // Bind this instance to |request| with the |error_handler|. + void Bind(mojom::WebUIGraphDumpRequest request, + base::OnceClosure error_handler); + + private: + CoordinationUnitGraph* graph_; + mojo::Binding<mojom::WebUIGraphDump> binding_; + + DISALLOW_COPY_AND_ASSIGN(WebUIGraphDumpImpl); +}; + +} // namespace resource_coordinator + +#endif // SERVICES_RESOURCE_COORDINATOR_WEBUI_GRAPH_DUMP_IMPL_H_ diff --git a/chromium/services/resource_coordinator/webui_graph_dump_impl_unittest.cc b/chromium/services/resource_coordinator/webui_graph_dump_impl_unittest.cc new file mode 100644 index 00000000000..be205936776 --- /dev/null +++ b/chromium/services/resource_coordinator/webui_graph_dump_impl_unittest.cc @@ -0,0 +1,67 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/resource_coordinator/webui_graph_dump_impl.h" + +#include "base/test/bind_test_util.h" +#include "base/time/time.h" +#include "services/resource_coordinator/coordination_unit/coordination_unit_test_harness.h" +#include "services/resource_coordinator/coordination_unit/mock_coordination_unit_graphs.h" +#include "services/resource_coordinator/coordination_unit/page_coordination_unit_impl.h" +#include "services/resource_coordinator/resource_coordinator_clock.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace resource_coordinator { + +class WebUIGraphDumpImplTest : public CoordinationUnitTestHarness {}; + +TEST_F(WebUIGraphDumpImplTest, Create) { + CoordinationUnitGraph graph; + MockMultiplePagesWithMultipleProcessesCoordinationUnitGraph cu_graph(&graph); + + base::TimeTicks now = ResourceCoordinatorClock::NowTicks(); + + constexpr char kExampleUrl[] = "http://www.example.org"; + cu_graph.page->OnMainFrameNavigationCommitted(now, 1, kExampleUrl); + cu_graph.other_page->OnMainFrameNavigationCommitted(now, 2, kExampleUrl); + + WebUIGraphDumpImpl impl(&graph); + + mojom::WebUIGraphPtr returned_graph; + WebUIGraphDumpImpl::GetCurrentGraphCallback callback = + base::BindLambdaForTesting([&returned_graph](mojom::WebUIGraphPtr graph) { + returned_graph = std::move(graph); + }); + impl.GetCurrentGraph(std::move(callback)); + + task_env().RunUntilIdle(); + + ASSERT_NE(nullptr, returned_graph.get()); + EXPECT_EQ(2u, returned_graph->pages.size()); + for (const auto& page : returned_graph->pages) { + EXPECT_NE(0u, page->id); + EXPECT_NE(0u, page->main_frame_id); + } + + EXPECT_EQ(3u, returned_graph->frames.size()); + // Count the top-level frames as we go. + size_t top_level_frames = 0; + for (const auto& frame : returned_graph->frames) { + if (frame->parent_frame_id == 0) + ++top_level_frames; + EXPECT_NE(0u, frame->id); + EXPECT_NE(0u, frame->process_id); + } + // Make sure we have one top-level frame per page. + EXPECT_EQ(returned_graph->pages.size(), top_level_frames); + + EXPECT_EQ(2u, returned_graph->processes.size()); + for (const auto& page : returned_graph->pages) { + EXPECT_NE(0u, page->id); + EXPECT_NE(0u, page->main_frame_id); + EXPECT_EQ(kExampleUrl, page->main_frame_url); + } +} + +} // namespace resource_coordinator diff --git a/chromium/services/service_manager/BUILD.gn b/chromium/services/service_manager/BUILD.gn index 55cdc1339c4..d07adc739c7 100644 --- a/chromium/services/service_manager/BUILD.gn +++ b/chromium/services/service_manager/BUILD.gn @@ -37,6 +37,8 @@ source_set("service_manager") { "switches.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ "//base/third_party/dynamic_annotations", ] @@ -52,8 +54,3 @@ source_set("service_manager") { "//services/service_manager/sandbox", ] } - -service_manifest("manifest") { - name = "service_manager" - source = "manifest.json" -} diff --git a/chromium/services/service_manager/manifest.json b/chromium/services/service_manager/manifest.json deleted file mode 100644 index 368a0a690b0..00000000000 --- a/chromium/services/service_manager/manifest.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "manifest_version": 1, - "name": "service_manager", - "display_name": "Service Manager", - "interface_provider_specs": { - "service_manager:connector": { - // NOTE: This manifest is for documentation purposes only. Relevant - // capability spec is defined inline in the ServiceManager implementation. - // - // TODO(rockot): Fix this. We can bake this file into ServiceManager at - // build time or something. Same with service:catalog. - "provides": { - // Clients requesting this class are able to connect to other clients as - // specific users other than their own. - "service_manager:user_id": [ ], - // Clients requesting this class are allowed to register clients for - // processes they launch themselves. - "service_manager:client_process": [ ], - // Clients requesting this class are allowed to connect to other clients - // in specific process instance groups. - "service_manager:instance_name": [ ], - "service_manager:block_wildcard": [ ], - - "service_manager:service_manager": [ - "service_manager.mojom.ServiceManager" - ] - }, - "requires": { - "*": [ "service_manager:service_factory" ], - "tracing": [ "app" ] - } - } - } -} diff --git a/chromium/services/service_manager/public/cpp/BUILD.gn b/chromium/services/service_manager/public/cpp/BUILD.gn index 40125676ebf..311f3afaa58 100644 --- a/chromium/services/service_manager/public/cpp/BUILD.gn +++ b/chromium/services/service_manager/public/cpp/BUILD.gn @@ -17,6 +17,8 @@ component("cpp") { "local_interface_provider.h", "service.cc", "service.h", + "service_binding.cc", + "service_binding.h", "service_context.cc", "service_context.h", "service_context_ref.cc", @@ -27,6 +29,8 @@ component("cpp") { "service_runner.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + public_deps = [ ":cpp_types", "//base", @@ -37,7 +41,12 @@ component("cpp") { "//url", ] - defines = [ "SERVICE_MANAGER_PUBLIC_CPP_IMPL" ] + defines = [ + "IS_SERVICE_MANAGER_CPP_IMPL", + + # TODO: Use COMPONENT_EXPORT everywhere here and remove this. + "SERVICE_MANAGER_PUBLIC_CPP_IMPL", + ] } # A component for types which the public interfaces depend on for typemapping. @@ -56,6 +65,8 @@ component("cpp_types") { "types_export.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ "//services/service_manager/public/mojom:constants", ] diff --git a/chromium/services/service_manager/public/cpp/service.cc b/chromium/services/service_manager/public/cpp/service.cc index 2030d0364a2..11e6996b89e 100644 --- a/chromium/services/service_manager/public/cpp/service.cc +++ b/chromium/services/service_manager/public/cpp/service.cc @@ -19,6 +19,8 @@ void Service::OnBindInterface(const BindSourceInfo& source, const std::string& interface_name, mojo::ScopedMessagePipeHandle interface_pipe) {} +void Service::OnDisconnected() {} + bool Service::OnServiceManagerConnectionLost() { return true; } diff --git a/chromium/services/service_manager/public/cpp/service.h b/chromium/services/service_manager/public/cpp/service.h index b0242905342..cf17566f965 100644 --- a/chromium/services/service_manager/public/cpp/service.h +++ b/chromium/services/service_manager/public/cpp/service.h @@ -7,9 +7,9 @@ #include <string> +#include "base/component_export.h" #include "base/macros.h" #include "mojo/public/cpp/system/message_pipe.h" -#include "services/service_manager/public/cpp/export.h" namespace service_manager { @@ -18,7 +18,7 @@ struct BindSourceInfo; // The primary contract between a Service and the Service Manager, receiving // lifecycle notifications and connection requests. -class SERVICE_MANAGER_PUBLIC_CPP_EXPORT Service { +class COMPONENT_EXPORT(SERVICE_MANAGER_CPP) Service { public: Service(); virtual ~Service(); @@ -37,6 +37,14 @@ class SERVICE_MANAGER_PUBLIC_CPP_EXPORT Service { const std::string& interface_name, mojo::ScopedMessagePipeHandle interface_pipe); + // Called when the Service Manager has stopped tracking this instance. Once + // invoked, no further Service interface methods will be called on this + // Service, and no further communication with the Service Manager is possible. + // + // The Service may continue to operate and service existing client connections + // as it deems appropriate. + virtual void OnDisconnected(); + // Called when the Service Manager has stopped tracking this instance. The // service should use this as a signal to shut down, and in fact its process // may be reaped shortly afterward if applicable. @@ -49,6 +57,9 @@ class SERVICE_MANAGER_PUBLIC_CPP_EXPORT Service { // // NOTE: This may be called at any time, and once it's been called, none of // the other public Service methods will be invoked by the ServiceContext. + // + // This is ONLY invoked when using a ServiceContext and is therefore + // deprecated. virtual bool OnServiceManagerConnectionLost(); protected: @@ -72,7 +83,7 @@ class SERVICE_MANAGER_PUBLIC_CPP_EXPORT Service { // TODO(rockot): Remove this. It's here to satisfy a few remaining use cases // where a Service impl is owned by something other than its ServiceContext. -class SERVICE_MANAGER_PUBLIC_CPP_EXPORT ForwardingService : public Service { +class COMPONENT_EXPORT(SERVICE_MANAGER_CPP) ForwardingService : public Service { public: // |target| must outlive this object. explicit ForwardingService(Service* target); diff --git a/chromium/services/service_manager/public/cpp/service_binding.cc b/chromium/services/service_manager/public/cpp/service_binding.cc new file mode 100644 index 00000000000..cac9eef3706 --- /dev/null +++ b/chromium/services/service_manager/public/cpp/service_binding.cc @@ -0,0 +1,94 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/service_manager/public/cpp/service_binding.h" + +#include <utility> + +#include "base/bind.h" +#include "services/service_manager/public/cpp/service.h" + +#include "base/debug/stack_trace.h" + +namespace service_manager { + +ServiceBinding::ServiceBinding(service_manager::Service* service) + : service_(service), binding_(this) { + DCHECK(service_); +} + +ServiceBinding::ServiceBinding(service_manager::Service* service, + mojom::ServiceRequest request) + : ServiceBinding(service) { + if (request.is_pending()) + Bind(std::move(request)); +} + +ServiceBinding::~ServiceBinding() = default; + +Connector* ServiceBinding::GetConnector() { + if (!connector_) + connector_ = Connector::Create(&pending_connector_request_); + return connector_.get(); +} + +void ServiceBinding::Bind(mojom::ServiceRequest request) { + DCHECK(!is_bound()); + binding_.Bind(std::move(request)); + binding_.set_connection_error_handler(base::BindOnce( + &ServiceBinding::OnConnectionError, base::Unretained(this))); +} + +void ServiceBinding::RequestClose() { + DCHECK(is_bound()); + if (service_control_.is_bound()) { + service_control_->RequestQuit(); + } else { + // It's possible that the service may request closure before receiving the + // initial |OnStart()| event, in which case there is not yet a control + // interface on which to request closure. In that case we defer until + // |OnStart()| is received. + request_closure_on_start_ = true; + } +} + +void ServiceBinding::Close() { + DCHECK(is_bound()); + binding_.Close(); + service_control_.reset(); + connector_.reset(); +} + +void ServiceBinding::OnConnectionError() { + service_->OnDisconnected(); +} + +void ServiceBinding::OnStart(const Identity& identity, + OnStartCallback callback) { + identity_ = identity; + service_->OnStart(); + + if (!pending_connector_request_.is_pending()) + connector_ = Connector::Create(&pending_connector_request_); + std::move(callback).Run(std::move(pending_connector_request_), + mojo::MakeRequest(&service_control_)); + + // Execute any prior |RequestClose()| request on the service's behalf. + if (request_closure_on_start_) + service_control_->RequestQuit(); +} + +void ServiceBinding::OnBindInterface( + const BindSourceInfo& source_info, + const std::string& interface_name, + mojo::ScopedMessagePipeHandle interface_pipe, + OnBindInterfaceCallback callback) { + // Acknowledge this request. + std::move(callback).Run(); + + service_->OnBindInterface(source_info, interface_name, + std::move(interface_pipe)); +} + +} // namespace service_manager diff --git a/chromium/services/service_manager/public/cpp/service_binding.h b/chromium/services/service_manager/public/cpp/service_binding.h new file mode 100644 index 00000000000..6ed95480f9d --- /dev/null +++ b/chromium/services/service_manager/public/cpp/service_binding.h @@ -0,0 +1,139 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_SERVICE_MANAGER_PUBLIC_CPP_SERVICE_BINDING_H_ +#define SERVICES_SERVICE_MANAGER_PUBLIC_CPP_SERVICE_BINDING_H_ + +#include <memory> + +#include "base/callback.h" +#include "base/component_export.h" +#include "base/macros.h" +#include "mojo/public/cpp/bindings/binding.h" +#include "services/service_manager/public/cpp/connector.h" +#include "services/service_manager/public/mojom/connector.mojom.h" +#include "services/service_manager/public/mojom/service.mojom.h" +#include "services/service_manager/public/mojom/service_control.mojom.h" + +namespace service_manager { + +class Service; + +// Encapsulates service-side bindings to Service Manager interfaces. Namely, +// this helps receive and dispatch Service interface events to a service +// implementation, while also exposing a working Connector interface the service +// can use to make outgoing interface requests. +// +// A ServiceBinding is considered to be "bound" after |Bind()| is invoked with a +// valid ServiceRequest (or the equivalent constructor is used -- see below). +// Upon connection error or an explicit call to |Close()|, the ServiceBinding +// will be considered "unbound" until another call to |Bind()| is made. +// +// NOTE: A well-behaved service should aim to always close its ServiceBinding +// gracefully by calling |RequestClose()|. Closing a ServiceBinding abruptly +// (by either destroying it or explicitly calling |Close()|) introduces inherent +// flakiness into the system unless the Service's |OnDisconnected()| has already +// been invoked, because otherwise the Service Manager may have in-flight +// interface requests directed at your service instance and these will be +// dropped to the dismay of the service instance which issued them. Exceptions +// can reasonably be made for system-wide shutdown situations where even the +// Service Manager itself will be imminently torn down. +class COMPONENT_EXPORT(SERVICE_MANAGER_CPP) ServiceBinding + : public mojom::Service { + public: + // Creates a new ServiceBinding bound to |service|. The service will not + // receive any Service interface calls until |Bind()| is called, but its + // |connector()| is usable immediately upon construction. + // + // |service| is not owned and must outlive this ServiceBinding. + explicit ServiceBinding(service_manager::Service* service); + + // Same as above, but behaves as if |Bind(request)| is also called immediately + // after construction. See below. + ServiceBinding(service_manager::Service* service, + mojom::ServiceRequest request); + + ~ServiceBinding() override; + + bool is_bound() const { return binding_.is_bound(); } + + Identity identity() const { return identity_; } + + // Returns a usable Connector which can make outgoing interface requests + // identifying as the service to which this ServiceBinding is bound. + Connector* GetConnector(); + + // Binds this ServiceBinding to a new ServiceRequest. Once a ServiceBinding + // is bound, its target Service will begin receiving Service events. The + // order of events received is: + // + // - OnStart() exactly once + // - OnIdentityKnown() exactly once + // - OnBindInterface() zero or more times + // + // The target Service will be able to receive these events until this + // ServiceBinding is either unbound or destroyed. + // + // If |request| is invalid, this call does nothing. + // + // Must only be called on an unbound ServiceBinding. + void Bind(mojom::ServiceRequest request); + + // Asks the Service Manager nicely if it's OK for this service instance to + // disappear now. If the Service Manager thinks it's OK, it will sever the + // binding's connection, ultimately triggering an |OnDisconnected()| call on + // the bound Service object. + // + // Must only be called on a bound ServiceBinding. + void RequestClose(); + + // Immediately severs the connection to the Service Manager. No further + // incoming interface requests will be received until this ServiceBinding is + // bound again. Always prefer |RequestClose()| under normal circumstances, + // unless |OnDisconnected()| has already been invoked on the Service. See the + // note in the class documentation above regarding graceful binding closure. + // + // Must only be called on a bound ServiceBinding. + void Close(); + + private: + void OnConnectionError(); + + // mojom::Service: + void OnStart(const Identity& identity, OnStartCallback callback) override; + void OnBindInterface(const BindSourceInfo& source_info, + const std::string& interface_name, + mojo::ScopedMessagePipeHandle interface_pipe, + OnBindInterfaceCallback callback) override; + + // The Service instance to which all incoming events from the Service Manager + // should be directed. Typically this is the object which owns this + // ServiceBinding. + service_manager::Service* const service_; + + // A pending Connector request which will eventually be passed to the Service + // Manager. Created preemptively by every unbound ServiceBinding so that + // |connector()| may begin pipelining outgoing requests even before the + // ServiceBinding is bound to a ServiceRequest. + mojom::ConnectorRequest pending_connector_request_; + + mojo::Binding<mojom::Service> binding_; + Identity identity_; + std::unique_ptr<Connector> connector_; + + // This instance's control interface to the service manager. Note that this + // is unbound and therefore invalid until OnStart() is called. + mojom::ServiceControlAssociatedPtr service_control_; + + // Tracks whether |RequestClose()| has been called at least once prior to + // receiving |OnStart()| on a bound ServiceBinding. This ensures that the + // closure request is actually issued once |OnStart()| is invoked. + bool request_closure_on_start_ = false; + + DISALLOW_COPY_AND_ASSIGN(ServiceBinding); +}; + +} // namespace service_manager + +#endif // SERVICES_SERVICE_MANAGER_PUBLIC_CPP_SERVICE_CONTEXT_H_ diff --git a/chromium/services/service_manager/public/cpp/service_keepalive.cc b/chromium/services/service_manager/public/cpp/service_keepalive.cc index 4eb0d60b143..d8f30aa3072 100644 --- a/chromium/services/service_manager/public/cpp/service_keepalive.cc +++ b/chromium/services/service_manager/public/cpp/service_keepalive.cc @@ -11,7 +11,7 @@ namespace service_manager { ServiceKeepalive::ServiceKeepalive(ServiceContext* context, - base::TimeDelta idle_timeout, + base::Optional<base::TimeDelta> idle_timeout, TimeoutObserver* timeout_observer) : context_(context), idle_timeout_(idle_timeout), @@ -40,7 +40,9 @@ void ServiceKeepalive::OnRefAdded() { } void ServiceKeepalive::OnRefCountZero() { - idle_timer_.Start(FROM_HERE, idle_timeout_, + if (!idle_timeout_.has_value()) + return; + idle_timer_.Start(FROM_HERE, idle_timeout_.value(), base::BindRepeating(&ServiceKeepalive::OnTimerExpired, weak_ptr_factory_.GetWeakPtr())); } diff --git a/chromium/services/service_manager/public/cpp/service_keepalive.h b/chromium/services/service_manager/public/cpp/service_keepalive.h index af5db0c7cbe..9fb0b9df746 100644 --- a/chromium/services/service_manager/public/cpp/service_keepalive.h +++ b/chromium/services/service_manager/public/cpp/service_keepalive.h @@ -9,6 +9,7 @@ #include "base/macros.h" #include "base/memory/weak_ptr.h" +#include "base/optional.h" #include "base/timer/timer.h" #include "services/service_manager/public/cpp/service_context_ref.h" @@ -39,10 +40,12 @@ class SERVICE_MANAGER_PUBLIC_CPP_EXPORT ServiceKeepalive { }; // Creates a keepalive which allows the service to be idle for |idle_timeout| - // before requesting termination. Both |context| and |timeout_observer| are - // not owned and must outlive the ServiceKeepalive instance. + // before requesting termination. If |idle_timeout| is not given, the + // ServiceKeepalive will never request termination, i.e. the service will + // stay alive indefinitely. Both |context| and |timeout_observer| are not + // owned and must outlive the ServiceKeepalive instance. ServiceKeepalive(ServiceContext* context, - base::TimeDelta idle_timeout, + base::Optional<base::TimeDelta> idle_timeout, TimeoutObserver* timeout_observer = nullptr); ~ServiceKeepalive(); @@ -55,7 +58,7 @@ class SERVICE_MANAGER_PUBLIC_CPP_EXPORT ServiceKeepalive { void OnTimerExpired(); ServiceContext* const context_; - const base::TimeDelta idle_timeout_; + const base::Optional<base::TimeDelta> idle_timeout_; TimeoutObserver* const timeout_observer_; base::OneShotTimer idle_timer_; ServiceContextRefFactory ref_factory_; diff --git a/chromium/services/service_manager/public/java/src/org/chromium/services/service_manager/Connector.java b/chromium/services/service_manager/public/java/src/org/chromium/services/service_manager/Connector.java new file mode 100644 index 00000000000..ccf55823735 --- /dev/null +++ b/chromium/services/service_manager/public/java/src/org/chromium/services/service_manager/Connector.java @@ -0,0 +1,55 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.services.service_manager; + +import org.chromium.mojo.bindings.ConnectionErrorHandler; +import org.chromium.mojo.bindings.Interface; +import org.chromium.mojo.bindings.InterfaceRequest; +import org.chromium.mojo.system.MessagePipeHandle; +import org.chromium.mojo.system.MojoException; +import org.chromium.service_manager.mojom.ConstantsConstants; +import org.chromium.service_manager.mojom.Identity; + +/** + * This class exposes the ability to bind interfaces from other services in the system. + */ +public class Connector implements ConnectionErrorHandler { + private org.chromium.service_manager.mojom.Connector.Proxy mConnector; + + private static class ConnectorBindInterfaceResponseImpl + implements org.chromium.service_manager.mojom.Connector.BindInterfaceResponse { + @Override + public void call(Integer result, Identity userId) {} + } + + public Connector(MessagePipeHandle handle) { + mConnector = org.chromium.service_manager.mojom.Connector.MANAGER.attachProxy(handle, 0); + mConnector.getProxyHandler().setErrorHandler(this); + } + + /** + * Asks a service to bind an interface request. + * + * @param serviceName The name of the service. + * @param interfaceName The name of interface I. + * @param request The request for the interface I. + */ + public <I extends Interface, P extends Interface.Proxy> void bindInterface( + String serviceName, String interfaceName, InterfaceRequest<I> request) { + Identity target = new Identity(); + target.name = serviceName; + target.userId = ConstantsConstants.INHERIT_USER_ID; + target.instance = ""; + + org.chromium.service_manager.mojom.Connector.BindInterfaceResponse callback = + new ConnectorBindInterfaceResponseImpl(); + mConnector.bindInterface(target, interfaceName, request.passHandle(), callback); + } + + @Override + public void onConnectionError(MojoException e) { + mConnector.close(); + } +} diff --git a/chromium/services/service_manager/public/java/src/org/chromium/services/service_manager/InterfaceFactory.java b/chromium/services/service_manager/public/java/src/org/chromium/services/service_manager/InterfaceFactory.java new file mode 100644 index 00000000000..f5b919ab039 --- /dev/null +++ b/chromium/services/service_manager/public/java/src/org/chromium/services/service_manager/InterfaceFactory.java @@ -0,0 +1,19 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.services.service_manager; + +import org.chromium.mojo.bindings.Interface; + +/** + * A factory that creates implementations of a mojo interface. + * + * @param <I> the mojo interface + */ +public interface InterfaceFactory<I extends Interface> { + /** + * Returns an implementation of the mojo interface. + */ + I createImpl(); +} diff --git a/chromium/services/service_manager/public/java/src/org/chromium/services/service_manager/InterfaceProvider.java b/chromium/services/service_manager/public/java/src/org/chromium/services/service_manager/InterfaceProvider.java new file mode 100644 index 00000000000..649d3c43c24 --- /dev/null +++ b/chromium/services/service_manager/public/java/src/org/chromium/services/service_manager/InterfaceProvider.java @@ -0,0 +1,57 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.services.service_manager; + +import org.chromium.mojo.bindings.ConnectionErrorHandler; +import org.chromium.mojo.bindings.Interface; +import org.chromium.mojo.bindings.InterfaceRequest; +import org.chromium.mojo.system.Core; +import org.chromium.mojo.system.MessagePipeHandle; +import org.chromium.mojo.system.MojoException; +import org.chromium.mojo.system.Pair; + +/** + * Provides access to interfaces exposed by an InterfaceProvider mojo interface. + */ +public class InterfaceProvider implements ConnectionErrorHandler { + private Core mCore; + private org.chromium.service_manager.mojom.InterfaceProvider.Proxy mInterfaceProvider; + + public InterfaceProvider(MessagePipeHandle pipe) { + mCore = pipe.getCore(); + mInterfaceProvider = + org.chromium.service_manager.mojom.InterfaceProvider.MANAGER.attachProxy(pipe, 0); + mInterfaceProvider.getProxyHandler().setErrorHandler(this); + } + + /** + * Binds |request| to an implementation of I in the remote application. + * + * @param manager The Manager for interface I. + * @param request The request for the interface I. + */ + public <I extends Interface> void getInterface( + Interface.Manager<I, ? extends Interface.Proxy> manager, InterfaceRequest<I> request) { + mInterfaceProvider.getInterface(manager.getName(), request.passHandle()); + } + + /** + * Binds and returns a proxy to an implementation of I in the remote application. + * + * @param manager The Manager for interface I. + * @return A bound Proxy for interface I. + */ + public <I extends Interface, P extends Interface.Proxy> P getInterface( + Interface.Manager<I, P> manager) { + Pair<P, InterfaceRequest<I>> result = manager.getInterfaceRequest(mCore); + getInterface(manager, result.second); + return result.first; + } + + @Override + public void onConnectionError(MojoException e) { + mInterfaceProvider.close(); + } +} diff --git a/chromium/services/service_manager/public/java/src/org/chromium/services/service_manager/InterfaceRegistry.java b/chromium/services/service_manager/public/java/src/org/chromium/services/service_manager/InterfaceRegistry.java new file mode 100644 index 00000000000..72f5091ec9a --- /dev/null +++ b/chromium/services/service_manager/public/java/src/org/chromium/services/service_manager/InterfaceRegistry.java @@ -0,0 +1,78 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.services.service_manager; + +import org.chromium.mojo.bindings.Interface; +import org.chromium.mojo.system.MessagePipeHandle; +import org.chromium.mojo.system.MojoException; +import org.chromium.service_manager.mojom.InterfaceProvider; + +import java.util.HashMap; +import java.util.Map; + +/** + * A registry where interfaces may be registered to be exposed to another application. + * + * To use, define a class that implements your specific interface. Then + * implement an InterfaceFactory that creates instances of your implementation + * and register that on the registry with a Manager for the interface like this: + * + * registry.addInterface(InterfaceType.MANAGER, factory); + */ +public class InterfaceRegistry implements InterfaceProvider { + private final Map<String, InterfaceBinder> mBinders = new HashMap<String, InterfaceBinder>(); + + public <I extends Interface> void addInterface( + Interface.Manager<I, ? extends Interface.Proxy> manager, InterfaceFactory<I> factory) { + mBinders.put(manager.getName(), new InterfaceBinder<I>(manager, factory)); + } + + public static InterfaceRegistry create(MessagePipeHandle pipe) { + InterfaceRegistry registry = new InterfaceRegistry(); + InterfaceProvider.MANAGER.bind(registry, pipe); + return registry; + } + + @Override + public void getInterface(String name, MessagePipeHandle pipe) { + InterfaceBinder binder = mBinders.get(name); + if (binder == null) { + return; + } + binder.bindToMessagePipe(pipe); + } + + @Override + public void close() { + mBinders.clear(); + } + + @Override + public void onConnectionError(MojoException e) { + close(); + } + + InterfaceRegistry() {} + + private static class InterfaceBinder<I extends Interface> { + private Interface.Manager<I, ? extends Interface.Proxy> mManager; + private InterfaceFactory<I> mFactory; + + public InterfaceBinder(Interface.Manager<I, ? extends Interface.Proxy> manager, + InterfaceFactory<I> factory) { + mManager = manager; + mFactory = factory; + } + + public void bindToMessagePipe(MessagePipeHandle pipe) { + I impl = mFactory.createImpl(); + if (impl == null) { + pipe.close(); + return; + } + mManager.bind(impl, pipe); + } + } +} diff --git a/chromium/services/service_manager/public/service_manifest.gni b/chromium/services/service_manager/public/service_manifest.gni index 1a55cde7692..1bf94631690 100644 --- a/chromium/services/service_manager/public/service_manifest.gni +++ b/chromium/services/service_manager/public/service_manifest.gni @@ -45,6 +45,8 @@ import("//build/config/dcheck_always_on.gni") # within this output manifest, specifically within a toplevel "services" # list. # +# testonly (optional) +# # Outputs: # # An instantiation of this template produces a meta manifest from the source @@ -61,6 +63,8 @@ template("service_manifest") { "Only one of \"source\" or \"source_manifest\" must be defined for the $target_name target") action(target_name) { + testonly = defined(invoker.testonly) && invoker.testonly + script = "//services/service_manager/public/tools/manifest/manifest_collator.py" diff --git a/chromium/services/service_manager/sandbox/features.cc b/chromium/services/service_manager/sandbox/features.cc index 85e43faefaf..22b4e13ebf2 100644 --- a/chromium/services/service_manager/sandbox/features.cc +++ b/chromium/services/service_manager/sandbox/features.cc @@ -24,6 +24,9 @@ const base::Feature kNetworkServiceWindowsSandbox{ // sandbox::MITIGATION_EXTENSION_POINT_DISABLE. const base::Feature kWinSboxDisableExtensionPoints{ "WinSboxDisableExtensionPoint", base::FEATURE_ENABLED_BY_DEFAULT}; + +// Controls whether the isolated XR service is sandboxed. +const base::Feature kXRSandbox{"XRSandbox", base::FEATURE_DISABLED_BY_DEFAULT}; #endif // defined(OS_WIN) } // namespace features diff --git a/chromium/services/service_manager/sandbox/features.h b/chromium/services/service_manager/sandbox/features.h index 919806cfa1b..d67ffa47132 100644 --- a/chromium/services/service_manager/sandbox/features.h +++ b/chromium/services/service_manager/sandbox/features.h @@ -23,6 +23,8 @@ SERVICE_MANAGER_SANDBOX_EXPORT extern const base::Feature SERVICE_MANAGER_SANDBOX_EXPORT extern const base::Feature kWinSboxDisableExtensionPoints; + +SERVICE_MANAGER_SANDBOX_EXPORT extern const base::Feature kXRSandbox; #endif // defined(OS_WIN) } // namespace features diff --git a/chromium/services/service_manager/sandbox/linux/bpf_audio_policy_linux.cc b/chromium/services/service_manager/sandbox/linux/bpf_audio_policy_linux.cc index b69da9725a8..f9c155ce377 100644 --- a/chromium/services/service_manager/sandbox/linux/bpf_audio_policy_linux.cc +++ b/chromium/services/service_manager/sandbox/linux/bpf_audio_policy_linux.cc @@ -36,6 +36,9 @@ ResultExpr AudioProcessPolicy::EvaluateSyscall(int system_call_number) const { #if defined(__NR_ftruncate) case __NR_ftruncate: #endif +#if defined(__NR_ftruncate64) + case __NR_ftruncate64: +#endif #if defined(__NR_getdents) case __NR_getdents: #endif diff --git a/chromium/services/service_manager/sandbox/linux/sandbox_linux.cc b/chromium/services/service_manager/sandbox/linux/sandbox_linux.cc index 2e0a3ea6b0d..09dbba50467 100644 --- a/chromium/services/service_manager/sandbox/linux/sandbox_linux.cc +++ b/chromium/services/service_manager/sandbox/linux/sandbox_linux.cc @@ -192,26 +192,11 @@ void SandboxLinux::PreinitializeSandbox() { } void SandboxLinux::EngageNamespaceSandbox(bool from_zygote) { - CHECK(pre_initialized_); - if (from_zygote) { - // Check being in a new PID namespace created by the namespace sandbox and - // being the init process. - CHECK(sandbox::NamespaceSandbox::InNewPidNamespace()); - const pid_t pid = getpid(); - CHECK_EQ(1, pid); - } - - CHECK(sandbox::Credentials::MoveToNewUserNS()); - - // Note: this requires SealSandbox() to be called later in this process to be - // safe, as this class is keeping a file descriptor to /proc/. - CHECK(sandbox::Credentials::DropFileSystemAccess(proc_fd_)); + CHECK(EngageNamespaceSandboxInternal(from_zygote)); +} - // We do not drop CAP_SYS_ADMIN because we need it to place each child process - // in its own PID namespace later on. - std::vector<sandbox::Credentials::Capability> caps; - caps.push_back(sandbox::Credentials::Capability::SYS_ADMIN); - CHECK(sandbox::Credentials::SetCapabilities(proc_fd_, caps)); +bool SandboxLinux::EngageNamespaceSandboxIfPossible() { + return EngageNamespaceSandboxInternal(false /* from_zygote */); } std::vector<int> SandboxLinux::GetFileDescriptorsToClose() { @@ -497,4 +482,35 @@ void SandboxLinux::StopThreadAndEnsureNotCounted(base::Thread* thread) const { sandbox::ThreadHelpers::StopThreadAndWatchProcFS(proc_fd.get(), thread)); } +bool SandboxLinux::EngageNamespaceSandboxInternal(bool from_zygote) { + CHECK(pre_initialized_); + if (from_zygote) { + // Check being in a new PID namespace created by the namespace sandbox and + // being the init process. + CHECK(sandbox::NamespaceSandbox::InNewPidNamespace()); + const pid_t pid = getpid(); + CHECK_EQ(1, pid); + } + + // After we successfully move to a new user ns, we don't allow this function + // to fail. + if (!sandbox::Credentials::MoveToNewUserNS()) { + return false; + } + + // Note: this requires SealSandbox() to be called later in this process to be + // safe, as this class is keeping a file descriptor to /proc/. + CHECK(sandbox::Credentials::DropFileSystemAccess(proc_fd_)); + + // Now we drop all capabilities that we can. In the zygote process, we need + // to keep CAP_SYS_ADMIN, to place each child in its own PID namespace + // later on. + std::vector<sandbox::Credentials::Capability> caps; + if (from_zygote) { + caps.push_back(sandbox::Credentials::Capability::SYS_ADMIN); + } + CHECK(sandbox::Credentials::SetCapabilities(proc_fd_, caps)); + return true; +} + } // namespace service_manager diff --git a/chromium/services/service_manager/sandbox/linux/sandbox_linux.h b/chromium/services/service_manager/sandbox/linux/sandbox_linux.h index dcf4eeeee18..21817971a45 100644 --- a/chromium/services/service_manager/sandbox/linux/sandbox_linux.h +++ b/chromium/services/service_manager/sandbox/linux/sandbox_linux.h @@ -126,8 +126,17 @@ class SERVICE_MANAGER_SANDBOX_EXPORT SandboxLinux { // a new unprivileged namespace. This is a layer-1 sandbox. // In order for this sandbox to be effective, it must be "sealed" by calling // InitializeSandbox(). + // Terminates the process in case the sandboxing operations cannot complete + // successfully. void EngageNamespaceSandbox(bool from_zygote); + // Performs the same actions as EngageNamespaceSandbox, but is allowed to + // to fail. This is useful when sandboxed non-renderer processes could + // benefit from extra sandboxing but is not strictly required on systems that + // don't support unprivileged user namespaces. + // Zygote should use EngageNamespaceSandbox instead. + bool EngageNamespaceSandboxIfPossible(); + // Return a list of file descriptors to close if PreinitializeSandbox() ran // but InitializeSandbox() won't. Avoid using. // TODO(jln): get rid of this hack. @@ -242,6 +251,12 @@ class SERVICE_MANAGER_SANDBOX_EXPORT SandboxLinux { // anymore. void StopThreadAndEnsureNotCounted(base::Thread* thread) const; + // Engages the namespace sandbox as described for EngageNamespaceSandbox. + // Returns false if it fails to transition to a new user namespace, but + // after transitioning to a new user namespace we don't allow this function + // to fail. + bool EngageNamespaceSandboxInternal(bool from_zygote); + // A file descriptor to /proc. It's dangerous to have it around as it could // allow for sandbox bypasses. It needs to be closed before we consider // ourselves sandboxed. diff --git a/chromium/services/service_manager/sandbox/mac/cdm.sb b/chromium/services/service_manager/sandbox/mac/cdm.sb index a7ba200db8b..dbec5d1b779 100644 --- a/chromium/services/service_manager/sandbox/mac/cdm.sb +++ b/chromium/services/service_manager/sandbox/mac/cdm.sb @@ -7,9 +7,5 @@ ; Allow preloading of the CDM using seatbelt extension. (allow file-read* (extension "com.apple.app-sandbox.read")) -; Allow to read framework and CDM resources files for CDM host verification -(define bundle-version-path "BUNDLE_VERSION_PATH") -(allow file-read* (subpath (param bundle-version-path))) - ; mach IPC (allow mach-lookup (global-name "com.apple.windowserver.active")) diff --git a/chromium/services/service_manager/sandbox/mac/common.sb b/chromium/services/service_manager/sandbox/mac/common.sb index 0e90c9ab2f6..51f6a4b910f 100644 --- a/chromium/services/service_manager/sandbox/mac/common.sb +++ b/chromium/services/service_manager/sandbox/mac/common.sb @@ -14,6 +14,7 @@ (define (param-defined? str) (string? (param str))) ; Define constants for all of the parameter strings passed in. +(define bundle-version-path "BUNDLE_VERSION_PATH") (define disable-sandbox-denial-logging "DISABLE_SANDBOX_DENIAL_LOGGING") (define enable-logging "ENABLE_LOGGING") (define homedir-as-literal "USER_HOMEDIR_AS_LITERAL") diff --git a/chromium/services/service_manager/sandbox/mac/gpu.sb b/chromium/services/service_manager/sandbox/mac/gpu.sb index 576976f0daf..619e630a384 100644 --- a/chromium/services/service_manager/sandbox/mac/gpu.sb +++ b/chromium/services/service_manager/sandbox/mac/gpu.sb @@ -28,4 +28,7 @@ (allow file-read* (subpath "/System/Library/Extensions"))) ; Needed for VideoToolbox usage - https://crbug.com/767037 -(allow mach-lookup (global-name "com.apple.coremedia.videodecoder"))
\ No newline at end of file +(allow mach-lookup (global-name "com.apple.coremedia.videodecoder")) + +; Needed for GPU process to fallback to SwiftShader - https://crbug.com/897914 +(allow file-read-data file-read-metadata (subpath (param bundle-version-path))) diff --git a/chromium/services/service_manager/sandbox/mac/ppapi_v2.sb b/chromium/services/service_manager/sandbox/mac/ppapi_v2.sb index 1a5b8e27b7c..341656f0bda 100644 --- a/chromium/services/service_manager/sandbox/mac/ppapi_v2.sb +++ b/chromium/services/service_manager/sandbox/mac/ppapi_v2.sb @@ -14,6 +14,10 @@ ; Needed for Fonts. (allow-font-access) +; Mach lookups. +(allow mach-lookup + (global-name "com.apple.windowserver.active")) + ; IOKit (allow iokit-open (iokit-registry-entry-class "IOSurfaceRootUserClient")) diff --git a/chromium/services/service_manager/sandbox/mac/renderer_v2.sb b/chromium/services/service_manager/sandbox/mac/renderer_v2.sb index 35010447a6e..4211576f11a 100644 --- a/chromium/services/service_manager/sandbox/mac/renderer_v2.sb +++ b/chromium/services/service_manager/sandbox/mac/renderer_v2.sb @@ -69,5 +69,4 @@ ; https://crbug.com/850021 (global-name "com.apple.cvmsServ") ; crbug.com/792217 - (global-name "com.apple.system.notification_center") - (global-name "com.apple.windowserver.active")) + (global-name "com.apple.system.notification_center")) diff --git a/chromium/services/service_manager/sandbox/mac/sandbox_mac.mm b/chromium/services/service_manager/sandbox/mac/sandbox_mac.mm index 36b90f62e13..eb01a0abddb 100644 --- a/chromium/services/service_manager/sandbox/mac/sandbox_mac.mm +++ b/chromium/services/service_manager/sandbox/mac/sandbox_mac.mm @@ -242,9 +242,9 @@ bool SandboxMac::Enable(SandboxType sandbox_type) { if (!compiler.InsertBooleanParam(kSandboxMacOS1013, macos_1013)) return false; - if (sandbox_type == service_manager::SANDBOX_TYPE_CDM) { - base::FilePath bundle_path = SandboxMac::GetCanonicalPath( - base::mac::FrameworkBundlePath().DirName()); + if (sandbox_type == service_manager::SANDBOX_TYPE_GPU) { + base::FilePath bundle_path = + SandboxMac::GetCanonicalPath(base::mac::FrameworkBundlePath()); if (!compiler.InsertStringParam(kSandboxBundleVersionPath, bundle_path.value())) return false; diff --git a/chromium/services/service_manager/sandbox/sandbox_type.cc b/chromium/services/service_manager/sandbox/sandbox_type.cc index 1c471e51799..da6fc047d41 100644 --- a/chromium/services/service_manager/sandbox/sandbox_type.cc +++ b/chromium/services/service_manager/sandbox/sandbox_type.cc @@ -19,6 +19,10 @@ bool IsUnsandboxedSandboxType(SandboxType sandbox_type) { #if defined(OS_WIN) case SANDBOX_TYPE_NO_SANDBOX_AND_ELEVATED_PRIVILEGES: return true; + + case SANDBOX_TYPE_XRCOMPOSITING: + return !base::FeatureList::IsEnabled( + service_manager::features::kXRSandbox); #endif case SANDBOX_TYPE_AUDIO: #if defined(OS_WIN) || defined(OS_MACOSX) || defined(OS_LINUX) @@ -34,6 +38,7 @@ bool IsUnsandboxedSandboxType(SandboxType sandbox_type) { #else return true; #endif + default: return false; } @@ -73,6 +78,9 @@ void SetCommandLineFlagsForSandboxType(base::CommandLine* command_line, case SANDBOX_TYPE_CDM: case SANDBOX_TYPE_PDF_COMPOSITOR: case SANDBOX_TYPE_PROFILING: +#if defined(OS_WIN) + case SANDBOX_TYPE_XRCOMPOSITING: +#endif case SANDBOX_TYPE_AUDIO: DCHECK(command_line->GetSwitchValueASCII(switches::kProcessType) == switches::kUtilityProcess); @@ -118,6 +126,11 @@ SandboxType SandboxTypeFromCommandLine(const base::CommandLine& command_line) { if (process_type == switches::kPpapiPluginProcess) return SANDBOX_TYPE_PPAPI; +#if defined(OS_MACOSX) + if (process_type == switches::kNaClLoaderProcess) + return SANDBOX_TYPE_NACL_LOADER; +#endif + // This is a process which we don't know about. return SANDBOX_TYPE_INVALID; } @@ -138,6 +151,10 @@ std::string StringFromUtilitySandboxType(SandboxType sandbox_type) { return switches::kProfilingSandbox; case SANDBOX_TYPE_UTILITY: return switches::kUtilitySandbox; +#if defined(OS_WIN) + case SANDBOX_TYPE_XRCOMPOSITING: + return switches::kXrCompositingSandbox; +#endif case SANDBOX_TYPE_AUDIO: return switches::kAudioSandbox; default: @@ -166,6 +183,10 @@ SandboxType UtilitySandboxTypeFromString(const std::string& sandbox_string) { return SANDBOX_TYPE_PDF_COMPOSITOR; if (sandbox_string == switches::kProfilingSandbox) return SANDBOX_TYPE_PROFILING; +#if defined(OS_WIN) + if (sandbox_string == switches::kXrCompositingSandbox) + return SANDBOX_TYPE_XRCOMPOSITING; +#endif if (sandbox_string == switches::kAudioSandbox) return SANDBOX_TYPE_AUDIO; return SANDBOX_TYPE_UTILITY; diff --git a/chromium/services/service_manager/sandbox/sandbox_type.h b/chromium/services/service_manager/sandbox/sandbox_type.h index 1c5832b1502..f0c04da8034 100644 --- a/chromium/services/service_manager/sandbox/sandbox_type.h +++ b/chromium/services/service_manager/sandbox/sandbox_type.h @@ -26,6 +26,9 @@ enum SandboxType { #if defined(OS_WIN) // Do not apply any sandboxing and elevate the privileges of the process. SANDBOX_TYPE_NO_SANDBOX_AND_ELEVATED_PRIVILEGES, + + // The XR Compositing process. + SANDBOX_TYPE_XRCOMPOSITING, #endif // Renderer or worker process. Most common case. diff --git a/chromium/services/service_manager/sandbox/switches.cc b/chromium/services/service_manager/sandbox/switches.cc index 7eecefdc8ee..91e5096b542 100644 --- a/chromium/services/service_manager/sandbox/switches.cc +++ b/chromium/services/service_manager/sandbox/switches.cc @@ -26,6 +26,7 @@ const char kNetworkSandbox[] = "network"; const char kPpapiSandbox[] = "ppapi"; const char kUtilitySandbox[] = "utility"; const char kCdmSandbox[] = "cdm"; +const char kXrCompositingSandbox[] = "xr_compositing"; const char kPdfCompositorSandbox[] = "pdf_compositor"; const char kProfilingSandbox[] = "profiling"; const char kAudioSandbox[] = "audio"; @@ -92,6 +93,10 @@ const char kEnableGpuAppContainer[] = "enable-gpu-appcontainer"; // Disables the sandbox and gives the process elevated privileges. const char kNoSandboxAndElevatedPrivileges[] = "no-sandbox-and-elevated"; + +// Add additional capabilities to the AppContainer sandbox used for XR +// compositing. +const char kAddXrAppContainerCaps[] = "add-xr-appcontainer-caps"; #endif #if defined(OS_MACOSX) @@ -102,6 +107,7 @@ const char kEnableSandboxLogging[] = "enable-sandbox-logging"; // Flags spied upon from other layers. const char kGpuProcess[] = "gpu-process"; +const char kNaClLoaderProcess[] = "nacl-loader"; const char kPpapiBrokerProcess[] = "ppapi-broker"; const char kPpapiPluginProcess[] = "ppapi"; const char kRendererProcess[] = "renderer"; diff --git a/chromium/services/service_manager/sandbox/switches.h b/chromium/services/service_manager/sandbox/switches.h index c5436d29757..aa7fe563c51 100644 --- a/chromium/services/service_manager/sandbox/switches.h +++ b/chromium/services/service_manager/sandbox/switches.h @@ -25,6 +25,7 @@ SERVICE_MANAGER_SANDBOX_EXPORT extern const char kNetworkSandbox[]; SERVICE_MANAGER_SANDBOX_EXPORT extern const char kPpapiSandbox[]; SERVICE_MANAGER_SANDBOX_EXPORT extern const char kUtilitySandbox[]; SERVICE_MANAGER_SANDBOX_EXPORT extern const char kCdmSandbox[]; +SERVICE_MANAGER_SANDBOX_EXPORT extern const char kXrCompositingSandbox[]; SERVICE_MANAGER_SANDBOX_EXPORT extern const char kPdfCompositorSandbox[]; SERVICE_MANAGER_SANDBOX_EXPORT extern const char kProfilingSandbox[]; SERVICE_MANAGER_SANDBOX_EXPORT extern const char kAudioSandbox[]; @@ -50,6 +51,7 @@ SERVICE_MANAGER_SANDBOX_EXPORT extern const char kDisableGpuLpac[]; SERVICE_MANAGER_SANDBOX_EXPORT extern const char kEnableGpuAppContainer[]; SERVICE_MANAGER_SANDBOX_EXPORT extern const char kNoSandboxAndElevatedPrivileges[]; +SERVICE_MANAGER_SANDBOX_EXPORT extern const char kAddXrAppContainerCaps[]; #endif #if defined(OS_MACOSX) SERVICE_MANAGER_SANDBOX_EXPORT extern const char kEnableSandboxLogging[]; @@ -57,6 +59,7 @@ SERVICE_MANAGER_SANDBOX_EXPORT extern const char kEnableSandboxLogging[]; // Flags spied upon from other layers. SERVICE_MANAGER_SANDBOX_EXPORT extern const char kGpuProcess[]; +SERVICE_MANAGER_SANDBOX_EXPORT extern const char kNaClLoaderProcess[]; SERVICE_MANAGER_SANDBOX_EXPORT extern const char kPpapiBrokerProcess[]; SERVICE_MANAGER_SANDBOX_EXPORT extern const char kPpapiPluginProcess[]; SERVICE_MANAGER_SANDBOX_EXPORT extern const char kRendererProcess[]; diff --git a/chromium/services/service_manager/sandbox/win/sandbox_win.cc b/chromium/services/service_manager/sandbox/win/sandbox_win.cc index d330dd221cb..95cffa1b4ee 100644 --- a/chromium/services/service_manager/sandbox/win/sandbox_win.cc +++ b/chromium/services/service_manager/sandbox/win/sandbox_win.cc @@ -594,10 +594,19 @@ sandbox::ResultCode SetJobMemoryLimit(const base::CommandLine& cmd_line, base::string16 GetAppContainerProfileName( const std::string& appcontainer_id, service_manager::SandboxType sandbox_type) { - DCHECK(sandbox_type == service_manager::SANDBOX_TYPE_GPU); + DCHECK(sandbox_type == service_manager::SANDBOX_TYPE_GPU || + sandbox_type == service_manager::SANDBOX_TYPE_XRCOMPOSITING); auto sha1 = base::SHA1HashString(appcontainer_id); - auto profile_name = base::StrCat( - {"chrome.sandbox.gpu", base::HexEncode(sha1.data(), sha1.size())}); + std::string sandbox_base_name = + (sandbox_type == service_manager::SANDBOX_TYPE_XRCOMPOSITING) + ? std::string("chrome.sandbox.xrdevice") + : std::string("chrome.sandbox.gpu"); + std::string profile_name = base::StrCat( + {sandbox_base_name, base::HexEncode(sha1.data(), sha1.size())}); + // CreateAppContainerProfile requires that the profile name is at most 64 + // characters. The size of sha1 is a constant 40, so validate that the base + // names are sufficiently short that the total length is valid. + DCHECK(profile_name.length() <= 64); return base::UTF8ToWide(profile_name); } @@ -605,23 +614,41 @@ sandbox::ResultCode SetupAppContainerProfile( sandbox::AppContainerProfile* profile, const base::CommandLine& command_line, service_manager::SandboxType sandbox_type) { - if (sandbox_type != service_manager::SANDBOX_TYPE_GPU) + if (sandbox_type != service_manager::SANDBOX_TYPE_GPU && + sandbox_type != service_manager::SANDBOX_TYPE_XRCOMPOSITING) return sandbox::SBOX_ERROR_UNSUPPORTED; - if (!profile->AddImpersonationCapability(L"chromeInstallFiles")) { + if (sandbox_type == service_manager::SANDBOX_TYPE_GPU && + !profile->AddImpersonationCapability(L"chromeInstallFiles")) { DLOG(ERROR) << "AppContainerProfile::AddImpersonationCapability() failed"; return sandbox::SBOX_ERROR_CREATE_APPCONTAINER_PROFILE_CAPABILITY; } + if (sandbox_type == service_manager::SANDBOX_TYPE_XRCOMPOSITING && + !profile->AddCapability(L"chromeInstallFiles")) { + DLOG(ERROR) << "AppContainerProfile::AddCapability() failed"; + return sandbox::SBOX_ERROR_CREATE_APPCONTAINER_PROFILE_CAPABILITY; + } + std::vector<base::string16> base_caps = { L"lpacChromeInstallFiles", L"registryRead", }; - auto cmdline_caps = - base::SplitString(command_line.GetSwitchValueNative( - service_manager::switches::kAddGpuAppContainerCaps), - L",", base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY); - base_caps.insert(base_caps.end(), cmdline_caps.begin(), cmdline_caps.end()); + if (sandbox_type == service_manager::SANDBOX_TYPE_GPU) { + auto cmdline_caps = base::SplitString( + command_line.GetSwitchValueNative( + service_manager::switches::kAddGpuAppContainerCaps), + L",", base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY); + base_caps.insert(base_caps.end(), cmdline_caps.begin(), cmdline_caps.end()); + } + + if (sandbox_type == service_manager::SANDBOX_TYPE_XRCOMPOSITING) { + auto cmdline_caps = base::SplitString( + command_line.GetSwitchValueNative( + service_manager::switches::kAddXrAppContainerCaps), + L",", base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY); + base_caps.insert(base_caps.end(), cmdline_caps.begin(), cmdline_caps.end()); + } for (const auto& cap : base_caps) { if (!profile->AddCapability(cap.c_str())) { @@ -630,7 +657,9 @@ sandbox::ResultCode SetupAppContainerProfile( } } - if (!command_line.HasSwitch(service_manager::switches::kDisableGpuLpac)) { + // Enable LPAC for GPU process, but not for XRCompositor service. + if (sandbox_type == service_manager::SANDBOX_TYPE_GPU && + !command_line.HasSwitch(service_manager::switches::kDisableGpuLpac)) { profile->SetEnableLowPrivilegeAppContainer(true); } @@ -938,9 +967,10 @@ sandbox::ResultCode SandboxWin::StartSandboxedProcess( return result; } - // Allow the renderer and gpu processes to access the log file. + // Allow the renderer, gpu and utility processes to access the log file. if (process_type == service_manager::switches::kRendererProcess || - process_type == service_manager::switches::kGpuProcess) { + process_type == service_manager::switches::kGpuProcess || + process_type == service_manager::switches::kUtilityProcess) { if (logging::IsLoggingToFileEnabled()) { DCHECK(base::FilePath(logging::GetLogFileFullPath()).IsAbsolute()); result = policy->AddRule(sandbox::TargetPolicy::SUBSYS_FILES, diff --git a/chromium/services/service_manager/service_manager.cc b/chromium/services/service_manager/service_manager.cc index 31e8182bb55..710d95dad2b 100644 --- a/chromium/services/service_manager/service_manager.cc +++ b/chromium/services/service_manager/service_manager.cc @@ -47,9 +47,6 @@ namespace service_manager { namespace { -const char kCapability_UserID[] = "service_manager:user_id"; -const char kCapability_ClientProcess[] = "service_manager:client_process"; -const char kCapability_InstanceName[] = "service_manager:instance_name"; const char kCapability_ServiceManager[] = "service_manager:service_manager"; bool Succeeded(mojom::ConnectResult result) { @@ -577,12 +574,11 @@ class ServiceManager::Instance const Identity& target) { if (service && pid_receiver_request && (service->is_bound() || pid_receiver_request->is_pending())) { - if (!HasCapability(GetConnectionSpec(), kCapability_ClientProcess)) { + if (!options_.can_create_other_service_instances) { LOG(ERROR) << "Instance: " << identity_.name() << " attempting " << "to register an instance for a process it created for " << "target: " << target.name() << " without the " - << "service_manager{client_process} capability " - << "class."; + << "'can_create_other_service_instances' option."; return mojom::ConnectResult::ACCESS_DENIED; } @@ -614,25 +610,24 @@ class ServiceManager::Instance options_.instance_sharing == catalog::ServiceOptions::InstanceSharingType:: SHARED_INSTANCE_ACROSS_USERS || - HasCapability(connection_spec, kCapability_UserID); + options_.can_connect_to_other_services_as_any_user; if (!skip_user_check && target.user_id() != identity_.user_id() && target.user_id() != mojom::kRootUserID) { LOG(ERROR) << "Instance: " << identity_.name() << " running as: " << identity_.user_id() << " attempting to connect to: " << target.name() - << " as: " << target.user_id() << " without " - << " the service:service_manager{user_id} capability."; + << " as: " << target.user_id() << " without" + << " the 'can_connect_to_other_services_as_any_user' option."; return mojom::ConnectResult::ACCESS_DENIED; } - if (!target.instance().empty() && - target.instance() != target.name() && - !HasCapability(connection_spec, kCapability_InstanceName)) { - LOG(ERROR) << "Instance: " << identity_.name() << " attempting to " - << "connect to " << target.name() - << " using Instance name: " << target.instance() - << " without the " - << "service_manager{instance_name} capability."; + if (!target.instance().empty() && target.instance() != target.name() && + !options_.can_connect_to_other_services_with_any_instance_name) { + LOG(ERROR) + << "Instance: " << identity_.name() << " attempting to" + << " connect to " << target.name() + << " using Instance name: " << target.instance() << " without the" + << " 'can_connect_to_other_services_with_any_instance_name' option."; return mojom::ConnectResult::ACCESS_DENIED; } @@ -950,12 +945,13 @@ void ServiceManager::Connect(std::unique_ptr<ConnectParams> params) { catalog::ServiceOptions::InstanceSharingType::SINGLETON; const Identity original_target(params->target()); - // Services that request "all_users" class from the Service Manager are + // Services that have "shared_instance_across_users" value of + // "instance_sharing" option are // allowed to field connection requests from any user. They also run with a // synthetic user id generated here. The user id provided via Connect() is - // ignored. Additionally services with the "all_users" class are not tied to - // the lifetime of the service that started them, instead they are owned by - // the Service Manager. + // ignored. Additionally services with the "shared_instance_across_users" + // value are not tied to the lifetime of the service that started them, + // instead they are owned by the Service Manager. Identity source_identity_for_creation; InstanceType instance_type; diff --git a/chromium/services/service_manager/service_manifests.md b/chromium/services/service_manager/service_manifests.md new file mode 100644 index 00000000000..097ae3cd807 --- /dev/null +++ b/chromium/services/service_manager/service_manifests.md @@ -0,0 +1,299 @@ +# Service Manifests + +[TOC] + +## Overview + +Manifest files are used to configure security properties and +permissions for services, such as listing allowed sets of interfaces or +specifying a sandbox type. The files use JSON format and are usually +placed in the same directory as the service source files, but the path +is configurable in the BUILD.gn file for the service +(see [README.md](README.md#build-targets) for an example). + +## Terminology + +The Service Manager is responsible for starting new service instances on-demand, +and a given service may have any number of concurrently running instances. +The Service Manager disambiguates service instances by their unique identity. +A service's **identity** is represented by the 3-tuple of its service name, +user ID, and instance qualifier: + +### Service name + +A free-form -- typically short -- string identifying the the specific service +being run in the instance. + +### User ID + +A GUID string representing the identity of a user of the Service Manager. +Every running service instance is associated with a specific user ID. +This user ID is not related to any OS user ID or Chrome profile name. + +There is a special `kInheritUserID` id which causes the target identity of a +connection request to inherit the source identity. In Chrome, each +`BrowserContext` is modeled as a separate unique user ID so that renderer +instances running on behalf of different `BrowserContext`s run as different +"users". + +### Instance name + +An arbitrary free-form string used to disambiguate multiple instances of a +service for the same user. + +### Connections + +Every service instance has a Connector API it can use to issue requests to the +Service Manager. One such request is `BindInterface`, which allows the service +to bind an interface within another service instance. + +When binding an interface, the **source identity** refers to the service +initiating the bind request, and the **target identity** refers to the +destination service instance. Based on both the source and target identities, +the Service Manager may choose to start a new service instance, reuse an +existing instance as the destination for the bind request or deny the request. + +### Interface provider + +InterfaceProvider is a Mojo +[interface](https://cs.chromium.org/chromium/src/services/service_manager/public/mojom/interface_provider.mojom) +for providing an implementation of an interface by name. It is implemented by +different hosts (for frames and workers) in the browser, and the manifest +allows to specify different sets of capabilities exposed by these hosts. + +## File structure + +### name (string) + +A unique identifier that is used to refer to the service. + +### display\_name (string) + +A human-readable name which can have any descriptive value. Not user-visible. + +### sandbox\_type (string) + +An optional field that provides access to several types of sandboxes. +Inside a sandbox, by default the service is essentially restricted to CPU and +memory access only. Common values are: + +* `utility` (default) - also allows full access to one configurable directory; +* `none` - no sandboxing is applied; +* `none_and_elevated` - under Windows, no sandboxing is applied and privileges +are elevated. + +If the service cannot run with a sandbox type of utility, elevated, or none, +please reach out to the security team. + +### options (dictionary) + +This field can be used to specify values for any of the following options: + +#### instance\_sharing (string) + +Determines which parameters result in the creation of a new service +instance on an incoming service start/interface bind request. + +Possible values: + +##### none (default) + +When one service is connecting to another, checks are performed to either find +an existing instance that matches the target identity or create a new one if +no match is found. +By default, all three identity components (service name, user id and instance +name) are used to find a match. + +See +[advice](https://chromium.googlesource.com/chromium/src/+/master/docs/servicification.md#is-your-service-global-or-per_browsercontext) +on using this option. + +Example: +[identity](https://cs.chromium.org/chromium/src/services/identity/manifest.json) + +##### shared\_instance\_across\_users + +In this case, the user id parameter is ignored when looking for a matching +target instance, so when connecting with different user IDs (but the same +service name and instance name), an existing instance (if any) will be reused. + +Example: +[data_decoder](https://cs.chromium.org/chromium/src/services/data_decoder/manifest.json) + +##### singleton + +In this case, both user id and instance name parameters are ignored. +Only one service instance is created and used for all connections to this +service. + +Example: +[download_manager](https://cs.chromium.org/chromium/src/chrome/browser/android/download/manifest.json) + +#### can\_connect\_to\_other\_services\_as\_any\_user (bool) + +This option allows a service to make outgoing requests with a user id +other than the one it was created with. + +**Note:** this privilege must only be granted to services that are trusted +at the same level as the browser process itself. + +Example: +[content_browser](https://cs.chromium.org/chromium/src/content/public/app/mojo/content_browser_manifest.json) + +The browser process manages all `BrowserContexts`, so it needs this privilege +to act on behalf of different users. + +#### can\_connect\_to\_other\_services\_with\_any\_instance\_name (bool) + +This option allows a service to specify an instance name that is +different from the service name when connecting. + +**Note:** this privilege must only be granted to services that are trusted +at the same level as the browser process itself. + +Example: +[chrome_browser](https://cs.chromium.org/chromium/src/chrome/app/chrome_manifest.json) + +Code in chrome_browser calls an XML parsing library function, which generates a +random instance name to +[isolate unrelated decode operations](https://cs.chromium.org/chromium/src/services/data_decoder/public/cpp/safe_xml_parser.cc?l=50). + +#### can\_create\_other\_service\_instances (bool) + +This option allows a service to register arbitrary new service instances it +creates on its own. + +**Note:** this privilege must only be granted to services that are trusted +at least at the same level as the Service Manager itself. + +Example: +[content_browser](https://cs.chromium.org/chromium/src/content/public/app/mojo/content_browser_manifest.json) + +The browser manages render processes, and thus needs this privilege to manage +the content_renderer instances on behalf of the service manager. + +### interface\_provider\_specs (dictionary) + +The interface provider spec is a dictionary keyed by interface provider +name, with each value representing the capability spec for that +provider. +Each capability spec defines an optional "provides" key and an optional +"requires" key. + +Every interface provider spec (often exclusively) contains one standard +capability spec named “service_manager:connector”. This is the +capability spec enforced when inter-service connections are made from a +service's `Connector` interface. + +Some other examples of capability specs are things like "navigation:frame", +which enforces capability specs for interfaces retrieved through a +frame's `InterfaceProvider`. + +See [README.md](README.md#service-manifests) for some examples. + +**Note:** Since multiple interface provider support makes the manifest files +harder to understand, there is a plan to simplify this section +(see https://crbug.com/718652 for more details). + +#### provides (dictionary) + +This optional section specifies a set of capabilities provided by the service. +A capability is a set of accessible interfaces. + +For example, suppose we have the following capabilities: + +* useful_capability + * useful\_interface\_1 + * useful\_interface\_2 +* better\_capability + * better\_interface + +The `provides` section would be: +``` json + "provides": { + "useful_capability": [ + "useful_interface_1", + "useful_interface_2" ], + "better_capability": [ + "better_interface" ], + } +``` + +#### requires (dictionary) + +This optional section is also a dictionary, keyed by remote service +names (the service name must match the "name" value in the remote service's +manifest). Each value is a list of capabilities required by this service +from the listed remote service. This section does not name interfaces, +only capabilities. + +For example, if our capability requires service "some_capability" from +service "another_service", the "requires" section will look like this: + +``` json +"requires": { + "another_service": [ "some_capability" ] +``` + +An asterisk is a wildcard which means that any listed capabilities are +required of any service that provides them. For example: + +``` json +"requires": { + "*": [ "some_capability" ] +``` + +In the above example, this service can access any interface provided as part +of a capability named "some_capability" in any service on the system. + +While generally discouraged, there are use cases for wildcards. +Consider building something like a UI shell with a launcher that wants to +tell any service "please launch with some default UI". The services providing +a "launch" capability would be expected to include access to an +"`app_launcher.mojom.Launcher`" interface as part of that capability, with an +implementation that does something useful like open some default UI for the +service: + +``` json +"provides": { + "launch": [ "app_launcher.mojom.Launcher" ] +} +``` + +Then our app launcher service would expect to be able to bind +`app_launcher.mojom.Launcher` in any service that provides that capability: + + +``` json +"requires": { + "*" : [ "launch" ] +} +``` + +### required\_files (dictionary) + +Allows the (sandboxed) service to specify a list of platform-specific files it +needs to access from disk while running. Each file is keyed by an arbitrary +name chosen by the service, and references a file path relative to the Service +Manager embedder (e.g. the Chrome binary) at runtime. + +Files specified here will be opened by the Service Manager before launching a +new instance of this service, and their opened file descriptors will be passed +to the new sandboxed process. The file descriptors may be accessed via +`base::FileDescriptorStore` using the corresponding key string from the +manifest. + +**Note:** This is currently only supported on Android and Linux-based desktop +builds. + +#### path (string) + +Path to the file relative to the executable location at runtime. + +#### platform (string) + +The platform this file is required on. +Possible values: + +* `linux` +* `android` diff --git a/chromium/services/service_manager/tests/connect/connect_test.mojom b/chromium/services/service_manager/tests/connect/connect_test.mojom index 24fa8b8073b..666d1d095f8 100644 --- a/chromium/services/service_manager/tests/connect/connect_test.mojom +++ b/chromium/services/service_manager/tests/connect/connect_test.mojom @@ -32,11 +32,11 @@ interface StandaloneApp { ConnectToClassInterface() => (string class_interface_response, string title); }; -interface UserIdTest { - // Attempts to connect to mojo:connect_test_class_app as |user_id|. +interface IdentityTest { + // Attempts to connect to mojo:connect_test_class_app as |target|. // The callback takes the connection response result, and the identity - // mojo:connect_test_class_app was run as, which should match |user_id|. - ConnectToClassAppAsDifferentUser(service_manager.mojom.Identity target) => + // mojo:connect_test_class_app was run as, which should match |target|. + ConnectToClassAppWithIdentity(service_manager.mojom.Identity target) => (int32 connection_result, service_manager.mojom.Identity target); }; diff --git a/chromium/services/shape_detection/BUILD.gn b/chromium/services/shape_detection/BUILD.gn index d7722279b6e..5a84a11d115 100644 --- a/chromium/services/shape_detection/BUILD.gn +++ b/chromium/services/shape_detection/BUILD.gn @@ -56,6 +56,8 @@ source_set("lib") { ] } + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ "//mojo/public/cpp/bindings", "//ui/gfx", diff --git a/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/BarcodeDetectionImpl.java b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/BarcodeDetectionImpl.java new file mode 100644 index 00000000000..2d34df7340b --- /dev/null +++ b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/BarcodeDetectionImpl.java @@ -0,0 +1,92 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.shape_detection; + +import android.graphics.Point; +import android.graphics.Rect; +import android.util.SparseArray; + +import com.google.android.gms.vision.Frame; +import com.google.android.gms.vision.barcode.Barcode; +import com.google.android.gms.vision.barcode.BarcodeDetector; + +import org.chromium.base.ContextUtils; +import org.chromium.base.Log; +import org.chromium.gfx.mojom.PointF; +import org.chromium.gfx.mojom.RectF; +import org.chromium.mojo.system.MojoException; +import org.chromium.shape_detection.mojom.BarcodeDetection; +import org.chromium.shape_detection.mojom.BarcodeDetectionResult; +import org.chromium.shape_detection.mojom.BarcodeDetectorOptions; + +/** + * Implementation of mojo BarcodeDetection, using Google Play Services vision package. + */ +public class BarcodeDetectionImpl implements BarcodeDetection { + private static final String TAG = "BarcodeDetectionImpl"; + + private BarcodeDetector mBarcodeDetector; + + public BarcodeDetectionImpl(BarcodeDetectorOptions options) { + // TODO(mcasas): extract the barcode formats to hunt for out of + // |options| and use them for building |mBarcodeDetector|. + // https://crbug.com/582266. + mBarcodeDetector = + new BarcodeDetector.Builder(ContextUtils.getApplicationContext()).build(); + } + + @Override + public void detect(org.chromium.skia.mojom.Bitmap bitmapData, DetectResponse callback) { + // The vision library will be downloaded the first time the API is used + // on the device; this happens "fast", but it might have not completed, + // bail in this case. Also, the API was disabled between and v.9.0 and + // v.9.2, see https://developers.google.com/android/guides/releases. + if (!mBarcodeDetector.isOperational()) { + Log.e(TAG, "BarcodeDetector is not operational"); + callback.call(new BarcodeDetectionResult[0]); + return; + } + + Frame frame = BitmapUtils.convertToFrame(bitmapData); + if (frame == null) { + Log.e(TAG, "Error converting Mojom Bitmap to Frame"); + callback.call(new BarcodeDetectionResult[0]); + return; + } + + final SparseArray<Barcode> barcodes = mBarcodeDetector.detect(frame); + + BarcodeDetectionResult[] barcodeArray = new BarcodeDetectionResult[barcodes.size()]; + for (int i = 0; i < barcodes.size(); i++) { + barcodeArray[i] = new BarcodeDetectionResult(); + final Barcode barcode = barcodes.valueAt(i); + barcodeArray[i].rawValue = barcode.rawValue; + final Rect rect = barcode.getBoundingBox(); + barcodeArray[i].boundingBox = new RectF(); + barcodeArray[i].boundingBox.x = rect.left; + barcodeArray[i].boundingBox.y = rect.top; + barcodeArray[i].boundingBox.width = rect.width(); + barcodeArray[i].boundingBox.height = rect.height(); + final Point[] corners = barcode.cornerPoints; + barcodeArray[i].cornerPoints = new PointF[corners.length]; + for (int j = 0; j < corners.length; j++) { + barcodeArray[i].cornerPoints[j] = new PointF(); + barcodeArray[i].cornerPoints[j].x = corners[j].x; + barcodeArray[i].cornerPoints[j].y = corners[j].y; + } + } + callback.call(barcodeArray); + } + + @Override + public void close() { + mBarcodeDetector.release(); + } + + @Override + public void onConnectionError(MojoException e) { + close(); + } +} diff --git a/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/BarcodeDetectionProviderImpl.java b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/BarcodeDetectionProviderImpl.java new file mode 100644 index 00000000000..f4410a88f3e --- /dev/null +++ b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/BarcodeDetectionProviderImpl.java @@ -0,0 +1,65 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.shape_detection; + +import com.google.android.gms.common.ConnectionResult; +import com.google.android.gms.common.GoogleApiAvailability; +import com.google.android.gms.vision.barcode.Barcode; + +import org.chromium.base.ContextUtils; +import org.chromium.base.Log; +import org.chromium.mojo.bindings.InterfaceRequest; +import org.chromium.mojo.system.MojoException; +import org.chromium.services.service_manager.InterfaceFactory; +import org.chromium.shape_detection.mojom.BarcodeDetection; +import org.chromium.shape_detection.mojom.BarcodeDetectionProvider; +import org.chromium.shape_detection.mojom.BarcodeDetectorOptions; + +/** + * Service provider to create BarcodeDetection services + */ +public class BarcodeDetectionProviderImpl implements BarcodeDetectionProvider { + private static final String TAG = "BarcodeProviderImpl"; + + public BarcodeDetectionProviderImpl() {} + + @Override + public void createBarcodeDetection( + InterfaceRequest<BarcodeDetection> request, BarcodeDetectorOptions options) { + BarcodeDetection.MANAGER.bind(new BarcodeDetectionImpl(options), request); + } + + @Override + public void enumerateSupportedFormats(EnumerateSupportedFormatsResponse callback) { + int[] supportedFormats = {Barcode.AZTEC, Barcode.CODE_128, Barcode.CODE_39, Barcode.CODE_93, + Barcode.CODABAR, Barcode.DATA_MATRIX, Barcode.EAN_13, Barcode.EAN_8, Barcode.ITF, + Barcode.PDF417, Barcode.QR_CODE, Barcode.UPC_A, Barcode.UPC_E}; + callback.call(supportedFormats); + } + + @Override + public void close() {} + + @Override + public void onConnectionError(MojoException e) {} + + /** + * A factory class to register BarcodeDetectionProvider interface. + */ + public static class Factory implements InterfaceFactory<BarcodeDetectionProvider> { + public Factory() {} + + @Override + public BarcodeDetectionProvider createImpl() { + if (GoogleApiAvailability.getInstance().isGooglePlayServicesAvailable( + ContextUtils.getApplicationContext()) + != ConnectionResult.SUCCESS) { + Log.e(TAG, "Google Play Services not available"); + return null; + } + return new BarcodeDetectionProviderImpl(); + } + } +} diff --git a/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/BitmapUtils.java b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/BitmapUtils.java new file mode 100644 index 00000000000..492f44def81 --- /dev/null +++ b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/BitmapUtils.java @@ -0,0 +1,57 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.shape_detection; + +import android.graphics.Bitmap; + +import com.google.android.gms.vision.Frame; + +import org.chromium.mojo_base.BigBufferUtil; +import org.chromium.skia.mojom.ColorType; + +import java.nio.ByteBuffer; + +/** + * Utility class to convert a Bitmap to a GMS core YUV Frame. + */ +public class BitmapUtils { + public static Bitmap convertToBitmap(org.chromium.skia.mojom.Bitmap bitmapData) { + if (bitmapData.imageInfo == null) return null; + int width = bitmapData.imageInfo.width; + int height = bitmapData.imageInfo.height; + final long numPixels = (long) width * height; + // TODO(mcasas): https://crbug.com/670028 homogeneize overflow checking. + if (bitmapData.pixelData == null || width <= 0 || height <= 0 + || numPixels > (Long.MAX_VALUE / 4)) { + return null; + } + + if (bitmapData.imageInfo.colorType != ColorType.RGBA_8888 + && bitmapData.imageInfo.colorType != ColorType.BGRA_8888) { + return null; + } + + ByteBuffer imageBuffer = + ByteBuffer.wrap(BigBufferUtil.getBytesFromBigBuffer(bitmapData.pixelData)); + if (imageBuffer.capacity() <= 0) { + return null; + } + + Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888); + bitmap.copyPixelsFromBuffer(imageBuffer); + + return bitmap; + } + + public static Frame convertToFrame(org.chromium.skia.mojom.Bitmap bitmapData) { + Bitmap bitmap = convertToBitmap(bitmapData); + if (bitmap == null) { + return null; + } + + // This constructor implies a pixel format conversion to YUV. + return new Frame.Builder().setBitmap(bitmap).build(); + } +} diff --git a/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/FaceDetectionImpl.java b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/FaceDetectionImpl.java new file mode 100644 index 00000000000..6cbbd6591cd --- /dev/null +++ b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/FaceDetectionImpl.java @@ -0,0 +1,113 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.shape_detection; + +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.PointF; +import android.media.FaceDetector; +import android.media.FaceDetector.Face; + +import org.chromium.base.Log; +import org.chromium.base.task.AsyncTask; +import org.chromium.gfx.mojom.RectF; +import org.chromium.mojo.system.MojoException; +import org.chromium.shape_detection.mojom.FaceDetection; +import org.chromium.shape_detection.mojom.FaceDetectionResult; +import org.chromium.shape_detection.mojom.FaceDetectorOptions; +import org.chromium.shape_detection.mojom.Landmark; + +/** + * Android implementation of the FaceDetection service defined in + * services/shape_detection/public/mojom/facedetection.mojom + */ +public class FaceDetectionImpl implements FaceDetection { + private static final String TAG = "FaceDetectionImpl"; + private static final int MAX_FACES = 32; + private final boolean mFastMode; + private final int mMaxFaces; + + FaceDetectionImpl(FaceDetectorOptions options) { + mFastMode = options.fastMode; + mMaxFaces = Math.min(options.maxDetectedFaces, MAX_FACES); + } + + @Override + public void detect(org.chromium.skia.mojom.Bitmap bitmapData, final DetectResponse callback) { + Bitmap bitmap = BitmapUtils.convertToBitmap(bitmapData); + if (bitmap == null) { + Log.e(TAG, "Error converting Mojom Bitmap to Android Bitmap"); + callback.call(new FaceDetectionResult[0]); + return; + } + + // FaceDetector requires an even width, so pad the image if the width is odd. + // https://developer.android.com/reference/android/media/FaceDetector.html#FaceDetector(int, int, int) + final int width = bitmapData.imageInfo.width + (bitmapData.imageInfo.width % 2); + final int height = bitmapData.imageInfo.height; + if (width != bitmapData.imageInfo.width) { + Bitmap paddedBitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888); + Canvas canvas = new Canvas(paddedBitmap); + canvas.drawBitmap(bitmap, 0, 0, null); + bitmap = paddedBitmap; + } + + // A Bitmap must be in 565 format for findFaces() to work. See + // http://androidxref.com/7.0.0_r1/xref/frameworks/base/media/java/android/media/FaceDetector.java#124 + // + // It turns out that FaceDetector is not able to detect correctly if + // simply using pixmap.setConfig(). The reason might be that findFaces() + // needs non-premultiplied ARGB arrangement, while the alpha type in the + // original image is premultiplied. We can use getPixels() which does + // the unmultiplication while copying to a new array. See + // http://androidxref.com/7.0.0_r1/xref/frameworks/base/graphics/java/android/graphics/Bitmap.java#538 + int[] pixels = new int[width * height]; + bitmap.getPixels(pixels, 0, width, 0, 0, width, height); + final Bitmap unPremultipliedBitmap = + Bitmap.createBitmap(pixels, width, height, Bitmap.Config.RGB_565); + + // FaceDetector creation and findFaces() might take a long time and trigger a + // "StrictMode policy violation": they should happen in a background thread. + AsyncTask.THREAD_POOL_EXECUTOR.execute(new Runnable() { + @Override + public void run() { + final FaceDetector detector = new FaceDetector(width, height, mMaxFaces); + Face[] detectedFaces = new Face[mMaxFaces]; + // findFaces() will stop at |mMaxFaces|. + final int numberOfFaces = detector.findFaces(unPremultipliedBitmap, detectedFaces); + + FaceDetectionResult[] faceArray = new FaceDetectionResult[numberOfFaces]; + + for (int i = 0; i < numberOfFaces; i++) { + faceArray[i] = new FaceDetectionResult(); + + final Face face = detectedFaces[i]; + final PointF midPoint = new PointF(); + face.getMidPoint(midPoint); + final float eyesDistance = face.eyesDistance(); + + faceArray[i].boundingBox = new RectF(); + faceArray[i].boundingBox.x = midPoint.x - eyesDistance; + faceArray[i].boundingBox.y = midPoint.y - eyesDistance; + faceArray[i].boundingBox.width = 2 * eyesDistance; + faceArray[i].boundingBox.height = 2 * eyesDistance; + // TODO(xianglu): Consider adding Face.confidence and Face.pose. + + faceArray[i].landmarks = new Landmark[0]; + } + + callback.call(faceArray); + } + }); + } + + @Override + public void close() {} + + @Override + public void onConnectionError(MojoException e) { + close(); + } +} diff --git a/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/FaceDetectionImplGmsCore.java b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/FaceDetectionImplGmsCore.java new file mode 100644 index 00000000000..d29cb056fd3 --- /dev/null +++ b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/FaceDetectionImplGmsCore.java @@ -0,0 +1,172 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.shape_detection; + +import android.graphics.PointF; +import android.util.SparseArray; + +import com.google.android.gms.vision.Frame; +import com.google.android.gms.vision.face.Face; +import com.google.android.gms.vision.face.FaceDetector; +import com.google.android.gms.vision.face.Landmark; + +import org.chromium.base.ContextUtils; +import org.chromium.base.Log; +import org.chromium.gfx.mojom.RectF; +import org.chromium.mojo.system.MojoException; +import org.chromium.shape_detection.mojom.FaceDetection; +import org.chromium.shape_detection.mojom.FaceDetectionResult; +import org.chromium.shape_detection.mojom.FaceDetectorOptions; +import org.chromium.shape_detection.mojom.LandmarkType; + +import java.util.ArrayList; +import java.util.List; + +/** + * Google Play services implementation of the FaceDetection service defined in + * services/shape_detection/public/mojom/facedetection.mojom + */ +public class FaceDetectionImplGmsCore implements FaceDetection { + private static final String TAG = "FaceDetectionImpl"; + private static final int MAX_FACES = 32; + // Maximum rotation around the z-axis allowed when computing a tighter bounding box for the + // detected face. + private static final int MAX_EULER_Z = 15; + private final int mMaxFaces; + private final boolean mFastMode; + private final FaceDetector mFaceDetector; + + FaceDetectionImplGmsCore(FaceDetectorOptions options) { + FaceDetector.Builder builder = + new FaceDetector.Builder(ContextUtils.getApplicationContext()); + mMaxFaces = Math.min(options.maxDetectedFaces, MAX_FACES); + mFastMode = options.fastMode; + + try { + builder.setMode(mFastMode ? FaceDetector.FAST_MODE : FaceDetector.ACCURATE_MODE); + builder.setLandmarkType(FaceDetector.ALL_LANDMARKS); + if (mMaxFaces == 1) { + builder.setProminentFaceOnly(true); + } + } catch (IllegalArgumentException e) { + Log.e(TAG, "Unexpected exception " + e); + assert false; + } + + mFaceDetector = builder.build(); + } + + @Override + public void detect(org.chromium.skia.mojom.Bitmap bitmapData, DetectResponse callback) { + // The vision library will be downloaded the first time the API is used + // on the device; this happens "fast", but it might have not completed, + // bail in this case. + if (!mFaceDetector.isOperational()) { + Log.e(TAG, "FaceDetector is not operational"); + + // Fallback to Android's FaceDetectionImpl. + FaceDetectorOptions options = new FaceDetectorOptions(); + options.fastMode = mFastMode; + options.maxDetectedFaces = mMaxFaces; + FaceDetectionImpl detector = new FaceDetectionImpl(options); + detector.detect(bitmapData, callback); + return; + } + + Frame frame = BitmapUtils.convertToFrame(bitmapData); + if (frame == null) { + Log.e(TAG, "Error converting Mojom Bitmap to Frame"); + callback.call(new FaceDetectionResult[0]); + return; + } + + final SparseArray<Face> faces = mFaceDetector.detect(frame); + + FaceDetectionResult[] faceArray = new FaceDetectionResult[faces.size()]; + for (int i = 0; i < faces.size(); i++) { + faceArray[i] = new FaceDetectionResult(); + final Face face = faces.valueAt(i); + + final List<Landmark> landmarks = face.getLandmarks(); + ArrayList<org.chromium.shape_detection.mojom.Landmark> mojoLandmarks = + new ArrayList<org.chromium.shape_detection.mojom.Landmark>(landmarks.size()); + + int leftEyeIndex = -1; + int rightEyeIndex = -1; + int bottomMouthIndex = -1; + for (int j = 0; j < landmarks.size(); j++) { + final Landmark landmark = landmarks.get(j); + final int landmarkType = landmark.getType(); + if (landmarkType != Landmark.LEFT_EYE && landmarkType != Landmark.RIGHT_EYE + && landmarkType != Landmark.BOTTOM_MOUTH + && landmarkType != Landmark.NOSE_BASE) { + continue; + } + + org.chromium.shape_detection.mojom.Landmark mojoLandmark = + new org.chromium.shape_detection.mojom.Landmark(); + mojoLandmark.locations = new org.chromium.gfx.mojom.PointF[1]; + mojoLandmark.locations[0] = new org.chromium.gfx.mojom.PointF(); + mojoLandmark.locations[0].x = landmark.getPosition().x; + mojoLandmark.locations[0].y = landmark.getPosition().y; + + if (landmarkType == Landmark.LEFT_EYE) { + mojoLandmark.type = LandmarkType.EYE; + leftEyeIndex = j; + } else if (landmarkType == Landmark.RIGHT_EYE) { + mojoLandmark.type = LandmarkType.EYE; + rightEyeIndex = j; + } else if (landmarkType == Landmark.BOTTOM_MOUTH) { + mojoLandmark.type = LandmarkType.MOUTH; + bottomMouthIndex = j; + } else { + assert landmarkType == Landmark.NOSE_BASE; + mojoLandmark.type = LandmarkType.NOSE; + } + mojoLandmarks.add(mojoLandmark); + } + faceArray[i].landmarks = mojoLandmarks.toArray( + new org.chromium.shape_detection.mojom.Landmark[mojoLandmarks.size()]); + + final PointF corner = face.getPosition(); + faceArray[i].boundingBox = new RectF(); + if (leftEyeIndex != -1 && rightEyeIndex != -1 + && Math.abs(face.getEulerZ()) < MAX_EULER_Z) { + // Tighter calculation of the bounding box because the GMScore + // and Android Face APIs give different results. + final PointF leftEyePoint = landmarks.get(leftEyeIndex).getPosition(); + final PointF rightEyePoint = landmarks.get(rightEyeIndex).getPosition(); + final float eyesDistance = leftEyePoint.x - rightEyePoint.x; + final float eyeMouthDistance = bottomMouthIndex != -1 + ? landmarks.get(bottomMouthIndex).getPosition().y - leftEyePoint.y + : -1; + final PointF midEyePoint = + new PointF(corner.x + face.getWidth() / 2, leftEyePoint.y); + faceArray[i].boundingBox.x = 2 * rightEyePoint.x - midEyePoint.x; + faceArray[i].boundingBox.y = midEyePoint.y - eyesDistance; + faceArray[i].boundingBox.width = 2 * eyesDistance; + faceArray[i].boundingBox.height = eyeMouthDistance > eyesDistance + ? eyeMouthDistance + eyesDistance + : 2 * eyesDistance; + } else { + faceArray[i].boundingBox.x = corner.x; + faceArray[i].boundingBox.y = corner.y; + faceArray[i].boundingBox.width = face.getWidth(); + faceArray[i].boundingBox.height = face.getHeight(); + } + } + callback.call(faceArray); + } + + @Override + public void close() { + mFaceDetector.release(); + } + + @Override + public void onConnectionError(MojoException e) { + close(); + } +} diff --git a/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/FaceDetectionProviderImpl.java b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/FaceDetectionProviderImpl.java new file mode 100644 index 00000000000..2195fb1fe02 --- /dev/null +++ b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/FaceDetectionProviderImpl.java @@ -0,0 +1,56 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.shape_detection; + +import com.google.android.gms.common.ConnectionResult; +import com.google.android.gms.common.GoogleApiAvailability; + +import org.chromium.base.ContextUtils; +import org.chromium.mojo.bindings.InterfaceRequest; +import org.chromium.mojo.system.MojoException; +import org.chromium.services.service_manager.InterfaceFactory; +import org.chromium.shape_detection.mojom.FaceDetection; +import org.chromium.shape_detection.mojom.FaceDetectionProvider; +import org.chromium.shape_detection.mojom.FaceDetectorOptions; + +/** + * Service provider to create FaceDetection services + */ +public class FaceDetectionProviderImpl implements FaceDetectionProvider { + public FaceDetectionProviderImpl() {} + + @Override + public void createFaceDetection( + InterfaceRequest<FaceDetection> request, FaceDetectorOptions options) { + final boolean isGmsCoreSupported = + GoogleApiAvailability.getInstance().isGooglePlayServicesAvailable( + ContextUtils.getApplicationContext()) + == ConnectionResult.SUCCESS; + + if (isGmsCoreSupported) { + FaceDetection.MANAGER.bind(new FaceDetectionImplGmsCore(options), request); + } else { + FaceDetection.MANAGER.bind(new FaceDetectionImpl(options), request); + } + } + + @Override + public void close() {} + + @Override + public void onConnectionError(MojoException e) {} + + /** + * A factory class to register FaceDetectionProvider interface. + */ + public static class Factory implements InterfaceFactory<FaceDetectionProvider> { + public Factory() {} + + @Override + public FaceDetectionProvider createImpl() { + return new FaceDetectionProviderImpl(); + } + } +} diff --git a/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/InterfaceRegistrar.java b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/InterfaceRegistrar.java new file mode 100644 index 00000000000..fd6b559e0cb --- /dev/null +++ b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/InterfaceRegistrar.java @@ -0,0 +1,29 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.shape_detection; + +import org.chromium.base.annotations.CalledByNative; +import org.chromium.base.annotations.JNINamespace; +import org.chromium.mojo.system.impl.CoreImpl; +import org.chromium.services.service_manager.InterfaceRegistry; +import org.chromium.shape_detection.mojom.BarcodeDetectionProvider; +import org.chromium.shape_detection.mojom.FaceDetectionProvider; +import org.chromium.shape_detection.mojom.TextDetection; + +@JNINamespace("shape_detection") +class InterfaceRegistrar { + @CalledByNative + static void createInterfaceRegistryForContext(int nativeHandle) { + // Note: The bindings code manages the lifetime of this object, so it + // is not necessary to hold on to a reference to it explicitly. + InterfaceRegistry registry = InterfaceRegistry.create( + CoreImpl.getInstance().acquireNativeHandle(nativeHandle).toMessagePipeHandle()); + registry.addInterface( + BarcodeDetectionProvider.MANAGER, new BarcodeDetectionProviderImpl.Factory()); + registry.addInterface( + FaceDetectionProvider.MANAGER, new FaceDetectionProviderImpl.Factory()); + registry.addInterface(TextDetection.MANAGER, new TextDetectionImpl.Factory()); + } +} diff --git a/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/TextDetectionImpl.java b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/TextDetectionImpl.java new file mode 100644 index 00000000000..ec20ea42f19 --- /dev/null +++ b/chromium/services/shape_detection/android/java/src/org/chromium/shape_detection/TextDetectionImpl.java @@ -0,0 +1,109 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.shape_detection; + +import android.graphics.Point; +import android.graphics.Rect; +import android.util.SparseArray; + +import com.google.android.gms.common.ConnectionResult; +import com.google.android.gms.common.GoogleApiAvailability; +import com.google.android.gms.vision.Frame; +import com.google.android.gms.vision.text.TextBlock; +import com.google.android.gms.vision.text.TextRecognizer; + +import org.chromium.base.ContextUtils; +import org.chromium.base.Log; +import org.chromium.gfx.mojom.PointF; +import org.chromium.gfx.mojom.RectF; +import org.chromium.mojo.system.MojoException; +import org.chromium.services.service_manager.InterfaceFactory; +import org.chromium.shape_detection.mojom.TextDetection; +import org.chromium.shape_detection.mojom.TextDetectionResult; + + +/** + * Implementation of mojo TextDetection, using Google Play Services vision package. + */ +public class TextDetectionImpl implements TextDetection { + private static final String TAG = "TextDetectionImpl"; + + private TextRecognizer mTextRecognizer; + + public TextDetectionImpl() { + mTextRecognizer = new TextRecognizer.Builder(ContextUtils.getApplicationContext()).build(); + } + + @Override + public void detect(org.chromium.skia.mojom.Bitmap bitmapData, DetectResponse callback) { + // The vision library will be downloaded the first time the API is used + // on the device; this happens "fast", but it might have not completed, + // bail in this case. Also, the API was disabled between and v.9.0 and + // v.9.2, see https://developers.google.com/android/guides/releases. + if (!mTextRecognizer.isOperational()) { + Log.e(TAG, "TextDetector is not operational"); + callback.call(new TextDetectionResult[0]); + return; + } + + Frame frame = BitmapUtils.convertToFrame(bitmapData); + if (frame == null) { + Log.e(TAG, "Error converting Mojom Bitmap to Frame"); + callback.call(new TextDetectionResult[0]); + return; + } + + final SparseArray<TextBlock> textBlocks = mTextRecognizer.detect(frame); + + TextDetectionResult[] detectedTextArray = new TextDetectionResult[textBlocks.size()]; + for (int i = 0; i < textBlocks.size(); i++) { + detectedTextArray[i] = new TextDetectionResult(); + final TextBlock textBlock = textBlocks.valueAt(i); + detectedTextArray[i].rawValue = textBlock.getValue(); + final Rect rect = textBlock.getBoundingBox(); + detectedTextArray[i].boundingBox = new RectF(); + detectedTextArray[i].boundingBox.x = rect.left; + detectedTextArray[i].boundingBox.y = rect.top; + detectedTextArray[i].boundingBox.width = rect.width(); + detectedTextArray[i].boundingBox.height = rect.height(); + final Point[] corners = textBlock.getCornerPoints(); + detectedTextArray[i].cornerPoints = new PointF[corners.length]; + for (int j = 0; j < corners.length; j++) { + detectedTextArray[i].cornerPoints[j] = new PointF(); + detectedTextArray[i].cornerPoints[j].x = corners[j].x; + detectedTextArray[i].cornerPoints[j].y = corners[j].y; + } + } + callback.call(detectedTextArray); + } + + @Override + public void close() { + mTextRecognizer.release(); + } + + @Override + public void onConnectionError(MojoException e) { + close(); + } + + /** + * A factory class to register TextDetection interface. + */ + public static class Factory implements InterfaceFactory<TextDetection> { + public Factory() {} + + @Override + public TextDetection createImpl() { + if (GoogleApiAvailability.getInstance().isGooglePlayServicesAvailable( + ContextUtils.getApplicationContext()) + != ConnectionResult.SUCCESS) { + Log.e(TAG, "Google Play Services not available"); + return null; + } + return new TextDetectionImpl(); + } + } +} diff --git a/chromium/services/shape_detection/android/javatests/src/org/chromium/shape_detection/BarcodeDetectionImplTest.java b/chromium/services/shape_detection/android/javatests/src/org/chromium/shape_detection/BarcodeDetectionImplTest.java new file mode 100644 index 00000000000..cfdc07f6987 --- /dev/null +++ b/chromium/services/shape_detection/android/javatests/src/org/chromium/shape_detection/BarcodeDetectionImplTest.java @@ -0,0 +1,66 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.shape_detection; + +import android.support.test.filters.SmallTest; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.chromium.base.test.BaseJUnit4ClassRunner; +import org.chromium.base.test.util.Feature; +import org.chromium.shape_detection.mojom.BarcodeDetection; +import org.chromium.shape_detection.mojom.BarcodeDetectionResult; +import org.chromium.shape_detection.mojom.BarcodeDetectorOptions; + +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** + * Test suite for BarcodeDetectionImpl. + */ +@RunWith(BaseJUnit4ClassRunner.class) +public class BarcodeDetectionImplTest { + private static final org.chromium.skia.mojom.Bitmap QR_CODE_BITMAP = + TestUtils.mojoBitmapFromFile("qr_code.png"); + + private static BarcodeDetectionResult[] detect(org.chromium.skia.mojom.Bitmap mojoBitmap) { + BarcodeDetectorOptions options = new BarcodeDetectorOptions(); + BarcodeDetection detector = new BarcodeDetectionImpl(options); + + final ArrayBlockingQueue<BarcodeDetectionResult[]> queue = new ArrayBlockingQueue<>(1); + detector.detect(mojoBitmap, new BarcodeDetection.DetectResponse() { + @Override + public void call(BarcodeDetectionResult[] results) { + queue.add(results); + } + }); + BarcodeDetectionResult[] toReturn = null; + try { + toReturn = queue.poll(5L, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Assert.fail("Could not get BarcodeDetectionResult: " + e.toString()); + } + Assert.assertNotNull(toReturn); + return toReturn; + } + + @Test + @SmallTest + @Feature({"ShapeDetection"}) + public void testDetectBase64ValidImageString() { + if (!TestUtils.IS_GMS_CORE_SUPPORTED) { + return; + } + BarcodeDetectionResult[] results = detect(QR_CODE_BITMAP); + Assert.assertEquals(1, results.length); + Assert.assertEquals("https://chromium.org", results[0].rawValue); + Assert.assertEquals(40.0, results[0].boundingBox.x, 0.0); + Assert.assertEquals(40.0, results[0].boundingBox.y, 0.0); + Assert.assertEquals(250.0, results[0].boundingBox.width, 0.0); + Assert.assertEquals(250.0, results[0].boundingBox.height, 0.0); + } +} diff --git a/chromium/services/shape_detection/android/javatests/src/org/chromium/shape_detection/FaceDetectionImplTest.java b/chromium/services/shape_detection/android/javatests/src/org/chromium/shape_detection/FaceDetectionImplTest.java new file mode 100644 index 00000000000..f57610a3646 --- /dev/null +++ b/chromium/services/shape_detection/android/javatests/src/org/chromium/shape_detection/FaceDetectionImplTest.java @@ -0,0 +1,204 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.shape_detection; + +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Matrix; +import android.graphics.RectF; +import android.support.test.filters.SmallTest; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.chromium.base.test.BaseJUnit4ClassRunner; +import org.chromium.base.test.util.Feature; +import org.chromium.shape_detection.mojom.FaceDetection; +import org.chromium.shape_detection.mojom.FaceDetectionResult; +import org.chromium.shape_detection.mojom.FaceDetectorOptions; + +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** + * Test suite for FaceDetectionImpl. + */ +@RunWith(BaseJUnit4ClassRunner.class) +public class FaceDetectionImplTest { + private static final org.chromium.skia.mojom.Bitmap MONA_LISA_BITMAP = + TestUtils.mojoBitmapFromFile("mona_lisa.jpg"); + private static final org.chromium.skia.mojom.Bitmap FACE_POSE_BITMAP = + TestUtils.mojoBitmapFromFile("face_pose.png"); + // Different versions of Android have different implementations of FaceDetector.findFaces(), so + // we have to use a large error threshold. + private static final double BOUNDING_BOX_POSITION_ERROR = 10.0; + private static final double BOUNDING_BOX_SIZE_ERROR = 5.0; + private static final float ACCURATE_MODE_SIZE = 2.0f; + private static enum DetectionProviderType { ANDROID, GMS_CORE } + + public FaceDetectionImplTest() {} + + private static FaceDetectionResult[] detect(org.chromium.skia.mojom.Bitmap mojoBitmap, + boolean fastMode, DetectionProviderType api) { + FaceDetectorOptions options = new FaceDetectorOptions(); + options.fastMode = fastMode; + options.maxDetectedFaces = 32; + FaceDetection detector = null; + if (api == DetectionProviderType.ANDROID) { + detector = new FaceDetectionImpl(options); + } else if (api == DetectionProviderType.GMS_CORE) { + detector = new FaceDetectionImplGmsCore(options); + } else { + assert false; + return null; + } + + final ArrayBlockingQueue<FaceDetectionResult[]> queue = new ArrayBlockingQueue<>(1); + detector.detect(mojoBitmap, new FaceDetection.DetectResponse() { + @Override + public void call(FaceDetectionResult[] results) { + queue.add(results); + } + }); + FaceDetectionResult[] toReturn = null; + try { + toReturn = queue.poll(5L, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Assert.fail("Could not get FaceDetectionResult: " + e.toString()); + } + Assert.assertNotNull(toReturn); + return toReturn; + } + + private void detectSucceedsOnValidImage(DetectionProviderType api) { + FaceDetectionResult[] results = detect(MONA_LISA_BITMAP, true, api); + Assert.assertEquals(1, results.length); + Assert.assertEquals( + api == DetectionProviderType.GMS_CORE ? 4 : 0, results[0].landmarks.length); + Assert.assertEquals(40.0, results[0].boundingBox.width, BOUNDING_BOX_SIZE_ERROR); + Assert.assertEquals(40.0, results[0].boundingBox.height, BOUNDING_BOX_SIZE_ERROR); + Assert.assertEquals(24.0, results[0].boundingBox.x, BOUNDING_BOX_POSITION_ERROR); + Assert.assertEquals(20.0, results[0].boundingBox.y, BOUNDING_BOX_POSITION_ERROR); + } + + @Test + @SmallTest + @Feature({"ShapeDetection"}) + public void testDetectValidImageWithAndroidAPI() { + detectSucceedsOnValidImage(DetectionProviderType.ANDROID); + } + + @Test + @SmallTest + @Feature({"ShapeDetection"}) + public void testDetectValidImageWithGmsCore() { + if (TestUtils.IS_GMS_CORE_SUPPORTED) { + detectSucceedsOnValidImage(DetectionProviderType.GMS_CORE); + } + } + + @Test + @SmallTest + @Feature({"ShapeDetection"}) + public void testDetectHandlesOddWidthWithAndroidAPI() throws Exception { + // Pad the image so that the width is odd. + Bitmap paddedBitmap = Bitmap.createBitmap(MONA_LISA_BITMAP.imageInfo.width + 1, + MONA_LISA_BITMAP.imageInfo.height, Bitmap.Config.ARGB_8888); + Canvas canvas = new Canvas(paddedBitmap); + canvas.drawBitmap(BitmapUtils.convertToBitmap(MONA_LISA_BITMAP), 0, 0, null); + org.chromium.skia.mojom.Bitmap mojoBitmap = TestUtils.mojoBitmapFromBitmap(paddedBitmap); + Assert.assertEquals(1, mojoBitmap.imageInfo.width % 2); + + FaceDetectionResult[] results = detect(mojoBitmap, true, DetectionProviderType.ANDROID); + Assert.assertEquals(1, results.length); + Assert.assertEquals(40.0, results[0].boundingBox.width, BOUNDING_BOX_SIZE_ERROR); + Assert.assertEquals(40.0, results[0].boundingBox.height, BOUNDING_BOX_SIZE_ERROR); + Assert.assertEquals(24.0, results[0].boundingBox.x, BOUNDING_BOX_POSITION_ERROR); + Assert.assertEquals(20.0, results[0].boundingBox.y, BOUNDING_BOX_POSITION_ERROR); + } + + @Test + @SmallTest + @Feature({"ShapeDetection"}) + public void testDetectFacesInProfileWithGmsCore() { + if (!TestUtils.IS_GMS_CORE_SUPPORTED) { + return; + } + FaceDetectionResult[] fastModeResults = + detect(FACE_POSE_BITMAP, true, DetectionProviderType.GMS_CORE); + Assert.assertEquals(4, fastModeResults.length); + + FaceDetectionResult[] unorderedResults = + detect(FACE_POSE_BITMAP, false, DetectionProviderType.GMS_CORE); + FaceDetectionResult[] accurateModeResults = + new FaceDetectionResult[unorderedResults.length]; + for (int i = 0; i < accurateModeResults.length; i++) { + accurateModeResults[i] = new FaceDetectionResult(); + } + Assert.assertEquals(4, accurateModeResults.length); + // Order face results align with fast mode's order which is different from accurate mode. + accurateModeResults[0].boundingBox = unorderedResults[1].boundingBox; + accurateModeResults[1].boundingBox = unorderedResults[2].boundingBox; + accurateModeResults[2].boundingBox = unorderedResults[0].boundingBox; + accurateModeResults[3].boundingBox = unorderedResults[3].boundingBox; + + // The face bounding box of using ACCURATE_MODE is smaller than FAST_MODE + for (int i = 0; i < accurateModeResults.length; i++) { + RectF fastModeRect = new RectF(); + RectF accurateModeRect = new RectF(); + + fastModeRect.set(fastModeResults[i].boundingBox.x, fastModeResults[i].boundingBox.y, + fastModeResults[i].boundingBox.x + fastModeResults[i].boundingBox.width, + fastModeResults[i].boundingBox.y + fastModeResults[i].boundingBox.height); + + accurateModeRect.set(accurateModeResults[i].boundingBox.x + ACCURATE_MODE_SIZE, + accurateModeResults[i].boundingBox.y + ACCURATE_MODE_SIZE, + accurateModeResults[i].boundingBox.x + accurateModeResults[i].boundingBox.width + - ACCURATE_MODE_SIZE, + accurateModeResults[i].boundingBox.y + accurateModeResults[i].boundingBox.height + - ACCURATE_MODE_SIZE); + + Assert.assertEquals(true, fastModeRect.contains(accurateModeRect)); + } + } + + private void detectRotatedFace(Matrix matrix) { + // Get the bitmap of fourth face in face_pose.png + Bitmap fourthFace = Bitmap.createBitmap( + BitmapUtils.convertToBitmap(FACE_POSE_BITMAP), 508, 0, 182, 194); + int width = fourthFace.getWidth(); + int height = fourthFace.getHeight(); + + Bitmap rotationBitmap = Bitmap.createBitmap(fourthFace, 0, 0, width, height, matrix, true); + FaceDetectionResult[] results = detect(TestUtils.mojoBitmapFromBitmap(rotationBitmap), + false, DetectionProviderType.GMS_CORE); + Assert.assertEquals(1, results.length); + Assert.assertEquals(197.0, results[0].boundingBox.width, BOUNDING_BOX_SIZE_ERROR); + Assert.assertEquals(246.0, results[0].boundingBox.height, BOUNDING_BOX_SIZE_ERROR); + } + + @Test + @SmallTest + @Feature({"ShapeDetection"}) + public void testDetectRotatedFaceWithGmsCore() { + if (!TestUtils.IS_GMS_CORE_SUPPORTED) { + return; + } + Matrix matrix = new Matrix(); + + // Rotate the Bitmap. + matrix.postRotate(15); + detectRotatedFace(matrix); + + matrix.reset(); + matrix.postRotate(30); + detectRotatedFace(matrix); + + matrix.reset(); + matrix.postRotate(40); + detectRotatedFace(matrix); + } +} diff --git a/chromium/services/shape_detection/android/javatests/src/org/chromium/shape_detection/TestUtils.java b/chromium/services/shape_detection/android/javatests/src/org/chromium/shape_detection/TestUtils.java new file mode 100644 index 00000000000..fd8fea95a0d --- /dev/null +++ b/chromium/services/shape_detection/android/javatests/src/org/chromium/shape_detection/TestUtils.java @@ -0,0 +1,75 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.shape_detection; + +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Paint; + +import com.google.android.gms.common.ConnectionResult; +import com.google.android.gms.common.GoogleApiAvailability; + +import org.chromium.base.ContextUtils; +import org.chromium.base.test.util.UrlUtils; +import org.chromium.skia.mojom.ColorType; +import org.chromium.skia.mojom.ImageInfo; + +import java.nio.ByteBuffer; + +/** + * Utility class for ShapeDetection instrumentation tests, + * provides support for e.g. reading files and converting + * Bitmaps to mojom.Bitmaps. + */ +public class TestUtils { + public static final boolean IS_GMS_CORE_SUPPORTED = isGmsCoreSupported(); + + private static boolean isGmsCoreSupported() { + return GoogleApiAvailability.getInstance().isGooglePlayServicesAvailable( + ContextUtils.getApplicationContext()) + == ConnectionResult.SUCCESS; + } + + public static org.chromium.skia.mojom.Bitmap mojoBitmapFromBitmap(Bitmap bitmap) { + ByteBuffer buffer = ByteBuffer.allocate(bitmap.getByteCount()); + bitmap.copyPixelsToBuffer(buffer); + + org.chromium.skia.mojom.Bitmap mojoBitmap = new org.chromium.skia.mojom.Bitmap(); + mojoBitmap.imageInfo = new ImageInfo(); + mojoBitmap.imageInfo.width = bitmap.getWidth(); + mojoBitmap.imageInfo.height = bitmap.getHeight(); + mojoBitmap.imageInfo.colorType = ColorType.RGBA_8888; + mojoBitmap.pixelData = new org.chromium.mojo_base.mojom.BigBuffer(); + mojoBitmap.pixelData.setBytes(buffer.array()); + return mojoBitmap; + } + + public static org.chromium.skia.mojom.Bitmap mojoBitmapFromFile(String relPath) { + String path = UrlUtils.getIsolatedTestFilePath("services/test/data/" + relPath); + Bitmap bitmap = BitmapFactory.decodeFile(path); + return mojoBitmapFromBitmap(bitmap); + } + + public static org.chromium.skia.mojom.Bitmap mojoBitmapFromText(String[] texts) { + final int x = 10; + final int baseline = 100; + + Paint paint = new Paint(Paint.ANTI_ALIAS_FLAG); + paint.setTextSize(36.0f); + paint.setTextAlign(Paint.Align.LEFT); + + Bitmap bitmap = Bitmap.createBitmap(1080, 480, Bitmap.Config.ARGB_8888); + Canvas canvas = new Canvas(bitmap); + canvas.drawColor(Color.WHITE); + + for (int i = 0; i < texts.length; i++) { + canvas.drawText(texts[i], x, baseline * (i + 1), paint); + } + + return mojoBitmapFromBitmap(bitmap); + } +} diff --git a/chromium/services/shape_detection/android/javatests/src/org/chromium/shape_detection/TextDetectionImplTest.java b/chromium/services/shape_detection/android/javatests/src/org/chromium/shape_detection/TextDetectionImplTest.java new file mode 100644 index 00000000000..91b867cb532 --- /dev/null +++ b/chromium/services/shape_detection/android/javatests/src/org/chromium/shape_detection/TextDetectionImplTest.java @@ -0,0 +1,81 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.shape_detection; + +import android.support.test.filters.SmallTest; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; + +import org.chromium.base.test.BaseJUnit4ClassRunner; +import org.chromium.base.test.util.Feature; +import org.chromium.gfx.mojom.RectF; +import org.chromium.shape_detection.mojom.TextDetection; +import org.chromium.shape_detection.mojom.TextDetectionResult; + +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** + * Test suite for TextDetectionImpl. + */ +@RunWith(BaseJUnit4ClassRunner.class) +public class TextDetectionImplTest { + private static final String[] DETECTION_EXPECTED_TEXT = { + "The quick brown fox jumped over the lazy dog.", "Helvetica Neue 36."}; + private static final float[][] TEXT_BOUNDING_BOX = { + {0.0f, 71.0f, 753.0f, 36.0f}, {4.0f, 173.0f, 307.0f, 28.0f}}; + private static final org.chromium.skia.mojom.Bitmap TEXT_DETECTION_BITMAP = + TestUtils.mojoBitmapFromText(DETECTION_EXPECTED_TEXT); + + private static TextDetectionResult[] detect(org.chromium.skia.mojom.Bitmap mojoBitmap) { + TextDetection detector = new TextDetectionImpl(); + + final ArrayBlockingQueue<TextDetectionResult[]> queue = new ArrayBlockingQueue<>(1); + detector.detect(mojoBitmap, new TextDetection.DetectResponse() { + @Override + public void call(TextDetectionResult[] results) { + queue.add(results); + } + }); + TextDetectionResult[] toReturn = null; + try { + toReturn = queue.poll(5L, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Assert.fail("Could not get TextDetectionResult: " + e.toString()); + } + Assert.assertNotNull(toReturn); + return toReturn; + } + + @Test + @SmallTest + @Feature({"ShapeDetection"}) + public void testDetectSucceedsOnValidBitmap() { + if (!TestUtils.IS_GMS_CORE_SUPPORTED) { + return; + } + TextDetectionResult[] results = detect(TEXT_DETECTION_BITMAP); + Assert.assertEquals(DETECTION_EXPECTED_TEXT.length, results.length); + + for (int i = 0; i < DETECTION_EXPECTED_TEXT.length; i++) { + Assert.assertEquals(results[i].rawValue, DETECTION_EXPECTED_TEXT[i]); + Assert.assertEquals(TEXT_BOUNDING_BOX[i][0], results[i].boundingBox.x, 0.0); + Assert.assertEquals(TEXT_BOUNDING_BOX[i][1], results[i].boundingBox.y, 0.0); + Assert.assertEquals(TEXT_BOUNDING_BOX[i][2], results[i].boundingBox.width, 0.0); + Assert.assertEquals(TEXT_BOUNDING_BOX[i][3], results[i].boundingBox.height, 0.0); + + RectF cornerRectF = new RectF(); + cornerRectF.x = results[i].cornerPoints[0].x; + cornerRectF.y = results[i].cornerPoints[0].y; + cornerRectF.width = results[i].cornerPoints[1].x - cornerRectF.x; + cornerRectF.height = results[i].cornerPoints[2].y - cornerRectF.y; + Assert.assertEquals(results[i].boundingBox, cornerRectF); + Assert.assertEquals(results[i].cornerPoints[3].x, results[i].cornerPoints[1].x, 0.0); + Assert.assertEquals(results[i].cornerPoints[3].y, results[i].cornerPoints[2].y, 0.0); + } + } +} diff --git a/chromium/services/shape_detection/android/junit/src/org/chromium/shape_detection/BitmapUtilsTest.java b/chromium/services/shape_detection/android/junit/src/org/chromium/shape_detection/BitmapUtilsTest.java new file mode 100644 index 00000000000..19f96b32d2c --- /dev/null +++ b/chromium/services/shape_detection/android/junit/src/org/chromium/shape_detection/BitmapUtilsTest.java @@ -0,0 +1,86 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package org.chromium.shape_detection; + +import static org.junit.Assert.assertNull; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.MockitoAnnotations; +import org.robolectric.annotation.Config; +import org.robolectric.shadows.ShadowLog; + +import org.chromium.base.test.BaseRobolectricTestRunner; +import org.chromium.base.test.util.Feature; +import org.chromium.mojo_base.BigBufferUtil; +import org.chromium.skia.mojom.Bitmap; +import org.chromium.skia.mojom.ColorType; +import org.chromium.skia.mojom.ImageInfo; + +/** + * Test suite for conversion-to-Frame utils. + */ +@RunWith(BaseRobolectricTestRunner.class) +@Config(manifest = Config.NONE) +public class BitmapUtilsTest { + private static final int VALID_WIDTH = 1; + private static final int VALID_HEIGHT = 1; + private static final int INVALID_WIDTH = 0; + private static final long NUM_BYTES = VALID_WIDTH * VALID_HEIGHT * 4; + private static final byte[] EMPTY_DATA = new byte[0]; + + public BitmapUtilsTest() {} + + @Before + public void setUp() { + ShadowLog.stream = System.out; + MockitoAnnotations.initMocks(this); + } + + /** + * Verify conversion fails if the Bitmap is invalid. + */ + @Test + @Feature({"ShapeDetection"}) + public void testConversionFailsWithInvalidBitmap() { + Bitmap bitmap = new Bitmap(); + bitmap.pixelData = null; + bitmap.imageInfo = new ImageInfo(); + + assertNull(BitmapUtils.convertToFrame(bitmap)); + } + + /** + * Verify conversion fails if the sent dimensions are ugly. + */ + @Test + @Feature({"ShapeDetection"}) + public void testConversionFailsWithInvalidDimensions() { + Bitmap bitmap = new Bitmap(); + bitmap.imageInfo = new ImageInfo(); + bitmap.pixelData = BigBufferUtil.createBigBufferFromBytes(EMPTY_DATA); + bitmap.imageInfo.width = INVALID_WIDTH; + bitmap.imageInfo.height = VALID_HEIGHT; + + assertNull(BitmapUtils.convertToFrame(bitmap)); + } + + /** + * Verify conversion fails if Bitmap fails to wrap(). + */ + @Test + @Feature({"ShapeDetection"}) + public void testConversionFailsWithWronglyWrappedData() { + Bitmap bitmap = new Bitmap(); + bitmap.imageInfo = new ImageInfo(); + bitmap.pixelData = BigBufferUtil.createBigBufferFromBytes(EMPTY_DATA); + bitmap.imageInfo.width = VALID_WIDTH; + bitmap.imageInfo.height = VALID_HEIGHT; + bitmap.imageInfo.colorType = ColorType.RGBA_8888; + + assertNull(BitmapUtils.convertToFrame(bitmap)); + } +} diff --git a/chromium/services/tracing/BUILD.gn b/chromium/services/tracing/BUILD.gn index 323c756b719..6102956b6be 100644 --- a/chromium/services/tracing/BUILD.gn +++ b/chromium/services/tracing/BUILD.gn @@ -22,6 +22,8 @@ source_set("lib") { "tracing_service.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + public_deps = [ "//base", "//mojo/public/cpp/bindings", @@ -58,6 +60,8 @@ service("tracing") { "service_main.cc", ] + configs = [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ ":lib", "//mojo/public/cpp/system", @@ -105,7 +109,7 @@ source_set("tests") { deps += [ "//third_party/perfetto/include/perfetto/protozero:protozero", - "//third_party/perfetto/protos/perfetto/common", + "//third_party/perfetto/protos/perfetto/common:lite", "//third_party/perfetto/protos/perfetto/trace:lite", "//third_party/perfetto/protos/perfetto/trace/chrome:lite", ] diff --git a/chromium/services/tracing/perfetto/README.md b/chromium/services/tracing/perfetto/README.md index af3cc67870b..58d4a3c5e4c 100644 --- a/chromium/services/tracing/perfetto/README.md +++ b/chromium/services/tracing/perfetto/README.md @@ -49,8 +49,8 @@ To add a new data source: * Add a new string identifier in [perfetto_service.mojom](/services/tracing/public/mojom/perfetto_service.mojom). * Register the data source in [ProducerHost::OnConnect](/services/tracing/perfetto/producer_host.cc). -* Set up the data source in [ProducerClient::CreateDataSourceInstance](/services/tracing/public/cpp/perfetto/producer_client.cc). -* Tear down the data source in [ProducerClient::TearDownDataSourceInstance](/services/tracing/public/cpp/perfetto/producer_client.cc). +* Set up the data source in [ProducerClient::StartDataSource](/services/tracing/public/cpp/perfetto/producer_client.cc). +* Tear down the data source in [ProducerClient::StopDataSource](/services/tracing/public/cpp/perfetto/producer_client.cc). * For each thread that wants to log a proto, use a separate TraceWriter created using [ProducerClient::CreateTraceWriter](/services/tracing/public/cpp/perfetto/producer_client.cc). diff --git a/chromium/services/tracing/perfetto/json_trace_exporter.cc b/chromium/services/tracing/perfetto/json_trace_exporter.cc index bc19b9bbcd6..df3f0103cac 100644 --- a/chromium/services/tracing/perfetto/json_trace_exporter.cc +++ b/chromium/services/tracing/perfetto/json_trace_exporter.cc @@ -4,6 +4,7 @@ #include "services/tracing/perfetto/json_trace_exporter.h" +#include <unordered_map> #include <utility> #include "base/json/json_reader.h" @@ -23,17 +24,36 @@ using TraceEvent = base::trace_event::TraceEvent; namespace { +const char* GetStringFromStringTable( + const std::unordered_map<int, std::string>& string_table, + int index) { + auto it = string_table.find(index); + DCHECK(it != string_table.end()); + + return it->second.c_str(); +} void OutputJSONFromTraceEventProto( const perfetto::protos::ChromeTraceEvent& event, - std::string* out) { + std::string* out, + const std::unordered_map<int, std::string>& string_table) { char phase = static_cast<char>(event.phase()); + const char* name = + event.has_name_index() + ? GetStringFromStringTable(string_table, event.name_index()) + : event.name().c_str(); + const char* category_group_name = + event.has_category_group_name_index() + ? GetStringFromStringTable(string_table, + event.category_group_name_index()) + : event.category_group_name().c_str(); + base::StringAppendF(out, "{\"pid\":%i,\"tid\":%i,\"ts\":%" PRId64 ",\"ph\":\"%c\",\"cat\":\"%s\",\"name\":", event.process_id(), event.thread_id(), event.timestamp(), - phase, event.category_group_name().c_str()); - base::EscapeJSONString(event.name(), true, out); + phase, category_group_name); + base::EscapeJSONString(name, true, out); if (event.has_duration()) { base::StringAppendF(out, ",\"dur\":%" PRId64, event.duration()); @@ -126,7 +146,9 @@ void OutputJSONFromTraceEventProto( } *out += "\""; - *out += arg.name(); + *out += arg.has_name_index() + ? GetStringFromStringTable(string_table, arg.name_index()) + : arg.name(); *out += "\":"; TraceEvent::TraceValue value; @@ -244,19 +266,24 @@ void JSONTraceExporter::OnTraceData(std::vector<perfetto::TracePacket> packets, continue; } - const perfetto::protos::ChromeEventBundle& bundle = packet.chrome_events(); - for (const perfetto::protos::ChromeTraceEvent& event : - bundle.trace_events()) { + auto& bundle = packet.chrome_events(); + + std::unordered_map<int, std::string> string_table; + for (auto& string_table_entry : bundle.string_table()) { + string_table[string_table_entry.index()] = string_table_entry.value(); + } + + for (auto& event : bundle.trace_events()) { if (has_output_first_event_) { out += ","; } else { has_output_first_event_ = true; } - OutputJSONFromTraceEventProto(event, &out); + OutputJSONFromTraceEventProto(event, &out, string_table); } - for (const perfetto::protos::ChromeMetadata& metadata : bundle.metadata()) { + for (auto& metadata : bundle.metadata()) { if (metadata.has_string_value()) { metadata_->SetString(metadata.name(), metadata.string_value()); } else if (metadata.has_int_value()) { diff --git a/chromium/services/tracing/perfetto/json_trace_exporter_unittest.cc b/chromium/services/tracing/perfetto/json_trace_exporter_unittest.cc index 47923162619..f46987d18af 100644 --- a/chromium/services/tracing/perfetto/json_trace_exporter_unittest.cc +++ b/chromium/services/tracing/perfetto/json_trace_exporter_unittest.cc @@ -70,7 +70,6 @@ class MockConsumerEndpoint : public perfetto::TracingService::ConsumerEndpoint { mock_service_->OnTracingEnabled( config.data_sources()[0].config().chrome_config().trace_config()); } - void DisableTracing() override { mock_service_->OnTracingDisabled(); } void ReadBuffers() override {} void FreeBuffers() override {} @@ -78,6 +77,9 @@ class MockConsumerEndpoint : public perfetto::TracingService::ConsumerEndpoint { callback(true); } + // Unused in chrome, only meaningful when using TraceConfig.deferred_start. + void StartTracing() override {} + private: MockService* mock_service_; }; @@ -221,9 +223,10 @@ class JSONTraceExporterTest : public testing::Test { return trace_event; } - const trace_analyzer::TraceAnalyzer* trace_analyzer() const { + trace_analyzer::TraceAnalyzer* trace_analyzer() { return trace_analyzer_.get(); } + MockService* service() { return service_.get(); } const base::DictionaryValue* parsed_trace_data() const { return parsed_trace_data_.get(); @@ -307,6 +310,58 @@ TEST_F(JSONTraceExporterTest, TestBasicEvent) { ValidateAndGetBasicTestPacket(); } +TEST_F(JSONTraceExporterTest, TestStringTable) { + CreateJSONTraceExporter("foo"); + service()->WaitForTracingEnabled(); + StopAndFlush(); + + perfetto::protos::TracePacket trace_packet_proto; + auto* new_trace_event = + trace_packet_proto.mutable_chrome_events()->add_trace_events(); + + { + auto* string_table_entry = + trace_packet_proto.mutable_chrome_events()->add_string_table(); + string_table_entry->set_index(1); + string_table_entry->set_value("foo_name"); + } + + { + auto* string_table_entry = + trace_packet_proto.mutable_chrome_events()->add_string_table(); + string_table_entry->set_index(2); + string_table_entry->set_value("foo_cat"); + } + + { + auto* string_table_entry = + trace_packet_proto.mutable_chrome_events()->add_string_table(); + string_table_entry->set_index(3); + string_table_entry->set_value("foo_arg"); + } + + new_trace_event->set_name_index(1); + new_trace_event->set_category_group_name_index(2); + + auto* new_arg = new_trace_event->add_args(); + new_arg->set_name_index(3); + new_arg->set_bool_value(true); + + FinalizePacket(trace_packet_proto); + + service()->WaitForTracingDisabled(); + + auto* trace_event = trace_analyzer()->FindFirstOf( + trace_analyzer::Query(trace_analyzer::Query::EVENT_NAME) == + trace_analyzer::Query::String("foo_name")); + EXPECT_TRUE(trace_event); + + EXPECT_EQ("foo_name", trace_event->name); + EXPECT_EQ("foo_cat", trace_event->category); + + EXPECT_TRUE(trace_event->GetKnownArgAsBool("foo_arg")); +} + TEST_F(JSONTraceExporterTest, TestEventWithBoolArgs) { CreateJSONTraceExporter("foo"); service()->WaitForTracingEnabled(); diff --git a/chromium/services/tracing/perfetto/perfetto_integration_unittest.cc b/chromium/services/tracing/perfetto/perfetto_integration_unittest.cc index 102aa2f273e..94c6ce25ead 100644 --- a/chromium/services/tracing/perfetto/perfetto_integration_unittest.cc +++ b/chromium/services/tracing/perfetto/perfetto_integration_unittest.cc @@ -109,9 +109,8 @@ class MockProducerClient : public ProducerClient { size_t send_packet_count() const { return send_packet_count_; } - void CreateDataSourceInstance( - uint64_t id, - mojom::DataSourceConfigPtr data_source_config) override { + void StartDataSource(uint64_t id, + mojom::DataSourceConfigPtr data_source_config) override { enabled_data_source_ = std::make_unique<TestDataSource>( this, send_packet_count_, data_source_config->trace_config, data_source_config->target_buffer); @@ -121,9 +120,7 @@ class MockProducerClient : public ProducerClient { } } - void TearDownDataSourceInstance( - uint64_t id, - TearDownDataSourceInstanceCallback callback) override { + void StopDataSource(uint64_t id, StopDataSourceCallback callback) override { enabled_data_source_.reset(); if (client_disabled_callback_) { diff --git a/chromium/services/tracing/perfetto/producer_host.cc b/chromium/services/tracing/perfetto/producer_host.cc index 2fbe67a55b8..470fac649d9 100644 --- a/chromium/services/tracing/perfetto/producer_host.cc +++ b/chromium/services/tracing/perfetto/producer_host.cc @@ -71,9 +71,13 @@ void ProducerHost::OnTracingSetup() { producer_client_->OnTracingStart(std::move(shm)); } -void ProducerHost::CreateDataSourceInstance( - perfetto::DataSourceInstanceID id, - const perfetto::DataSourceConfig& config) { +void ProducerHost::SetupDataSource(perfetto::DataSourceInstanceID, + const perfetto::DataSourceConfig&) { + // TODO(primiano): plumb call through mojo. +} + +void ProducerHost::StartDataSource(perfetto::DataSourceInstanceID id, + const perfetto::DataSourceConfig& config) { // TODO(oysteine): Send full DataSourceConfig, not just the name/target_buffer // and Chrome Tracing string. auto data_source_config = mojom::DataSourceConfig::New(); @@ -81,13 +85,12 @@ void ProducerHost::CreateDataSourceInstance( data_source_config->target_buffer = config.target_buffer(); data_source_config->trace_config = config.chrome_config().trace_config(); - producer_client_->CreateDataSourceInstance(id, std::move(data_source_config)); + producer_client_->StartDataSource(id, std::move(data_source_config)); } -void ProducerHost::TearDownDataSourceInstance( - perfetto::DataSourceInstanceID id) { +void ProducerHost::StopDataSource(perfetto::DataSourceInstanceID id) { if (producer_client_) { - producer_client_->TearDownDataSourceInstance( + producer_client_->StopDataSource( id, base::BindOnce( [](ProducerHost* producer_host, perfetto::DataSourceInstanceID id) { diff --git a/chromium/services/tracing/perfetto/producer_host.h b/chromium/services/tracing/perfetto/producer_host.h index 00efaef8f7d..5a477c1c2af 100644 --- a/chromium/services/tracing/perfetto/producer_host.h +++ b/chromium/services/tracing/perfetto/producer_host.h @@ -60,11 +60,13 @@ class ProducerHost : public tracing::mojom::ProducerHost, void OnConnect() override; void OnDisconnect() override; - void CreateDataSourceInstance( - perfetto::DataSourceInstanceID id, - const perfetto::DataSourceConfig& config) override; + void SetupDataSource(perfetto::DataSourceInstanceID id, + const perfetto::DataSourceConfig& config) override; - void TearDownDataSourceInstance(perfetto::DataSourceInstanceID) override; + void StartDataSource(perfetto::DataSourceInstanceID id, + const perfetto::DataSourceConfig& config) override; + + void StopDataSource(perfetto::DataSourceInstanceID) override; void OnTracingSetup() override; void Flush(perfetto::FlushRequestID, const perfetto::DataSourceInstanceID* raw_data_source_ids, diff --git a/chromium/services/tracing/public/cpp/BUILD.gn b/chromium/services/tracing/public/cpp/BUILD.gn index 09226fd9030..950482aa37d 100644 --- a/chromium/services/tracing/public/cpp/BUILD.gn +++ b/chromium/services/tracing/public/cpp/BUILD.gn @@ -15,6 +15,8 @@ component("cpp") { defines = [ "IS_TRACING_CPP_IMPL" ] output_name = "tracing_cpp" + configs += [ "//build/config/compiler:wexit_time_destructors" ] + public_deps = [ "//base", "//mojo/public/cpp/bindings", diff --git a/chromium/services/tracing/public/cpp/perfetto/producer_client.cc b/chromium/services/tracing/public/cpp/perfetto/producer_client.cc index 2e2fc068b25..c759b781874 100644 --- a/chromium/services/tracing/public/cpp/perfetto/producer_client.cc +++ b/chromium/services/tracing/public/cpp/perfetto/producer_client.cc @@ -138,7 +138,7 @@ void ProducerClient::OnTracingStart( } } -void ProducerClient::CreateDataSourceInstance( +void ProducerClient::StartDataSource( uint64_t id, mojom::DataSourceConfigPtr data_source_config) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); @@ -153,9 +153,8 @@ void ProducerClient::CreateDataSourceInstance( } } -void ProducerClient::TearDownDataSourceInstance( - uint64_t id, - TearDownDataSourceInstanceCallback callback) { +void ProducerClient::StopDataSource(uint64_t id, + StopDataSourceCallback callback) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); for (auto* data_source : data_sources_) { diff --git a/chromium/services/tracing/public/cpp/perfetto/producer_client.h b/chromium/services/tracing/public/cpp/perfetto/producer_client.h index ca7757cfd2c..92a050de64e 100644 --- a/chromium/services/tracing/public/cpp/perfetto/producer_client.h +++ b/chromium/services/tracing/public/cpp/perfetto/producer_client.h @@ -39,7 +39,7 @@ class MojoSharedMemory; // * Add a new data source name in perfetto_service.mojom. // * Register the data source with Perfetto in ProducerHost::OnConnect. // * Construct the new implementation when requested to -// in ProducerClient::CreateDataSourceInstance. +// in ProducerClient::StartDataSource. class COMPONENT_EXPORT(TRACING_CPP) ProducerClient : public mojom::ProducerClient, public perfetto::TracingService::ProducerEndpoint { @@ -96,13 +96,10 @@ class COMPONENT_EXPORT(TRACING_CPP) ProducerClient // Called through Mojo by the ProducerHost on the service-side to control // tracing and toggle specific DataSources. void OnTracingStart(mojo::ScopedSharedBufferHandle shared_memory) override; - void CreateDataSourceInstance( - uint64_t id, - mojom::DataSourceConfigPtr data_source_config) override; + void StartDataSource(uint64_t id, + mojom::DataSourceConfigPtr data_source_config) override; - void TearDownDataSourceInstance( - uint64_t id, - TearDownDataSourceInstanceCallback callback) override; + void StopDataSource(uint64_t id, StopDataSourceCallback callback) override; void Flush(uint64_t flush_request_id, const std::vector<uint64_t>& data_source_ids) override; diff --git a/chromium/services/tracing/public/cpp/perfetto/trace_event_data_source.cc b/chromium/services/tracing/public/cpp/perfetto/trace_event_data_source.cc index 1a288cea7e0..f7bec193abf 100644 --- a/chromium/services/tracing/public/cpp/perfetto/trace_event_data_source.cc +++ b/chromium/services/tracing/public/cpp/perfetto/trace_event_data_source.cc @@ -4,6 +4,7 @@ #include "services/tracing/public/cpp/perfetto/trace_event_data_source.h" +#include <map> #include <utility> #include "base/json/json_writer.h" @@ -23,6 +24,9 @@ using TraceConfig = base::trace_event::TraceConfig; namespace tracing { +using ChromeEventBundleHandle = + protozero::MessageHandle<perfetto::protos::pbzero::ChromeEventBundle>; + TraceEventMetadataSource::TraceEventMetadataSource() : DataSourceBase(mojom::kMetaDataSourceName), origin_task_runner_(base::SequencedTaskRunnerHandle::Get()) {} @@ -41,8 +45,7 @@ void TraceEventMetadataSource::GenerateMetadata( DCHECK(origin_task_runner_->RunsTasksInCurrentSequence()); auto trace_packet = trace_writer->NewTracePacket(); - protozero::MessageHandle<perfetto::protos::pbzero::ChromeEventBundle> - event_bundle(trace_packet->set_chrome_events()); + ChromeEventBundleHandle event_bundle(trace_packet->set_chrome_events()); base::AutoLock lock(lock_); for (auto& generator : generator_functions_) { @@ -102,6 +105,11 @@ class TraceEventDataSource::ThreadLocalEventSink { : trace_writer_(std::move(trace_writer)) {} ~ThreadLocalEventSink() { + // Finalize the current message before posting the |trace_writer_| for + // destruction, to avoid data races. + event_bundle_ = ChromeEventBundleHandle(); + trace_packet_handle_ = perfetto::TraceWriter::TracePacketHandle(); + // Delete the TraceWriter on the sequence that Perfetto runs on, needed // as the ThreadLocalEventSink gets deleted on thread // shutdown and we can't safely call TaskRunnerHandle::Get() at that point @@ -111,6 +119,46 @@ class TraceEventDataSource::ThreadLocalEventSink { std::move(trace_writer_)); } + void EnsureValidHandles() { + if (trace_packet_handle_) { + return; + } + + trace_packet_handle_ = trace_writer_->NewTracePacket(); + event_bundle_ = + ChromeEventBundleHandle(trace_packet_handle_->set_chrome_events()); + string_table_.clear(); + next_string_table_index_ = 0; + } + + int GetStringTableIndexForString(const char* str_value) { + EnsureValidHandles(); + + auto it = string_table_.find(reinterpret_cast<intptr_t>(str_value)); + if (it != string_table_.end()) { + CHECK_EQ(std::string(reinterpret_cast<const char*>(it->first)), + std::string(str_value)); + + return it->second; + } + + int string_table_index = ++next_string_table_index_; + string_table_[reinterpret_cast<intptr_t>(str_value)] = string_table_index; + + auto* new_string_table_entry = event_bundle_->add_string_table(); + new_string_table_entry->set_value(str_value); + new_string_table_entry->set_index(string_table_index); + + return string_table_index; + } + + void AddConvertableToTraceFormat( + const base::trace_event::ConvertableToTraceFormat* value, + perfetto::protos::pbzero::ChromeTraceEvent_Arg* arg) { + std::string json = value->ToString(); + arg->set_json_value(json.c_str()); + } + void AddTraceEvent(const TraceEvent& trace_event) { // TODO(oysteine): Adding trace events to Perfetto will // stall in some situations, specifically when we overflow @@ -126,14 +174,46 @@ class TraceEventDataSource::ThreadLocalEventSink { return; } - // TODO(oysteine): Consider batching several trace events per trace packet, - // and only add repeated data once per batch. - auto trace_packet = trace_writer_->NewTracePacket(); - protozero::MessageHandle<perfetto::protos::pbzero::ChromeEventBundle> - event_bundle(trace_packet->set_chrome_events()); + EnsureValidHandles(); + + int name_index = 0; + int category_name_index = 0; + int arg_name_indices[base::trace_event::kTraceMaxNumArgs] = {0}; + + // Populate any new string table parts first; has to be done before + // the add_trace_events() call (as the string table is part of the outer + // proto message). + // If the TRACE_EVENT_FLAG_COPY flag is set, the char* pointers aren't + // necessarily valid after the TRACE_EVENT* call, and so we need to store + // the string every time. + bool string_table_enabled = !(trace_event.flags() & TRACE_EVENT_FLAG_COPY); + if (string_table_enabled) { + name_index = GetStringTableIndexForString(trace_event.name()); + category_name_index = GetStringTableIndexForString( + TraceLog::GetCategoryGroupName(trace_event.category_group_enabled())); + + for (int i = 0; + i < base::trace_event::kTraceMaxNumArgs && trace_event.arg_name(i); + ++i) { + arg_name_indices[i] = + GetStringTableIndexForString(trace_event.arg_name(i)); + } + } + + auto* new_trace_event = event_bundle_->add_trace_events(); - auto* new_trace_event = event_bundle->add_trace_events(); - new_trace_event->set_name(trace_event.name()); + if (name_index) { + new_trace_event->set_name_index(name_index); + } else { + new_trace_event->set_name(trace_event.name()); + } + + if (category_name_index) { + new_trace_event->set_category_group_name_index(category_name_index); + } else { + new_trace_event->set_category_group_name( + TraceLog::GetCategoryGroupName(trace_event.category_group_enabled())); + } new_trace_event->set_timestamp( trace_event.timestamp().since_origin().InMicroseconds()); @@ -155,9 +235,6 @@ class TraceEventDataSource::ThreadLocalEventSink { new_trace_event->set_process_id(process_id); new_trace_event->set_thread_id(thread_id); - new_trace_event->set_category_group_name( - TraceLog::GetCategoryGroupName(trace_event.category_group_enabled())); - char phase = trace_event.phase(); new_trace_event->set_phase(phase); @@ -166,11 +243,16 @@ class TraceEventDataSource::ThreadLocalEventSink { ++i) { auto type = trace_event.arg_type(i); auto* new_arg = new_trace_event->add_args(); - new_arg->set_name(trace_event.arg_name(i)); + + if (arg_name_indices[i]) { + new_arg->set_name_index(arg_name_indices[i]); + } else { + new_arg->set_name(trace_event.arg_name(i)); + } if (type == TRACE_VALUE_TYPE_CONVERTABLE) { - std::string json = trace_event.arg_convertible_value(i)->ToString(); - new_arg->set_json_value(json.c_str()); + AddConvertableToTraceFormat(trace_event.arg_convertible_value(i), + new_arg); continue; } @@ -243,10 +325,20 @@ class TraceEventDataSource::ThreadLocalEventSink { } } - void Flush() { trace_writer_->Flush(); } + void Flush() { + event_bundle_ = ChromeEventBundleHandle(); + trace_packet_handle_ = perfetto::TraceWriter::TracePacketHandle(); + trace_writer_->Flush(); + token_ = ""; + } private: std::unique_ptr<perfetto::TraceWriter> trace_writer_; + ChromeEventBundleHandle event_bundle_; + perfetto::TraceWriter::TracePacketHandle trace_packet_handle_; + std::map<intptr_t, int> string_table_; + int next_string_table_index_ = 0; + std::string token_; }; namespace { diff --git a/chromium/services/tracing/public/cpp/perfetto/trace_event_data_source_unittest.cc b/chromium/services/tracing/public/cpp/perfetto/trace_event_data_source_unittest.cc index a43ea2be629..77b678f93e4 100644 --- a/chromium/services/tracing/public/cpp/perfetto/trace_event_data_source_unittest.cc +++ b/chromium/services/tracing/public/cpp/perfetto/trace_event_data_source_unittest.cc @@ -4,6 +4,7 @@ #include "services/tracing/public/cpp/perfetto/trace_event_data_source.h" +#include <map> #include <utility> #include <vector> @@ -31,12 +32,10 @@ const char kCategoryGroup[] = "foo"; class MockProducerClient : public ProducerClient { public: explicit MockProducerClient( - scoped_refptr<base::SequencedTaskRunner> main_thread_task_runner, - const char* wanted_event_category) + scoped_refptr<base::SequencedTaskRunner> main_thread_task_runner) : delegate_(perfetto::base::kPageSize), stream_(&delegate_), - main_thread_task_runner_(std::move(main_thread_task_runner)), - wanted_event_category_(wanted_event_category) { + main_thread_task_runner_(std::move(main_thread_task_runner)) { trace_packet_.Reset(&stream_); } @@ -56,9 +55,7 @@ class MockProducerClient : public ProducerClient { auto proto = std::make_unique<perfetto::protos::TracePacket>(); EXPECT_TRUE(proto->ParseFromArray(buffer.begin, message_size)); if (proto->has_chrome_events() && - proto->chrome_events().trace_events().size() > 0 && - proto->chrome_events().trace_events()[0].category_group_name() == - wanted_event_category_) { + proto->chrome_events().trace_events().size() > 0) { finalized_packets_.push_back(std::move(proto)); } else if (proto->has_chrome_events() && proto->chrome_events().metadata().size() > 0) { @@ -87,6 +84,7 @@ class MockProducerClient : public ProducerClient { EXPECT_GT(finalized_packets_.size(), packet_index); auto event_bundle = finalized_packets_[packet_index]->chrome_events(); + PopulateStringTable(event_bundle); return event_bundle.trace_events(); } @@ -99,7 +97,25 @@ class MockProducerClient : public ProducerClient { return event_bundle.metadata(); } + void PopulateStringTable( + const perfetto::protos::ChromeEventBundle& event_bundle) { + string_table_.clear(); + + auto& string_table = event_bundle.string_table(); + for (int i = 0; i < string_table.size(); ++i) { + string_table_[string_table[i].index()] = string_table[i].value(); + } + } + + std::string GetStringTableEntry(int index) { + auto it = string_table_.find(index); + CHECK(it != string_table_.end()); + + return it->second; + } + private: + std::map<int, std::string> string_table_; std::vector<std::unique_ptr<perfetto::protos::TracePacket>> finalized_packets_; std::vector<std::unique_ptr<perfetto::protos::TracePacket>> metadata_packets_; @@ -107,7 +123,6 @@ class MockProducerClient : public ProducerClient { protozero::ScatteredStreamWriterNullDelegate delegate_; protozero::ScatteredStreamWriter stream_; scoped_refptr<base::SequencedTaskRunner> main_thread_task_runner_; - const char* wanted_event_category_; }; // For sequences/threads other than our own, we just want to ignore @@ -196,11 +211,9 @@ class TraceEventDataSourceTest : public testing::Test { producer_client_.reset(); } - void CreateTraceEventDataSource( - const char* wanted_event_category = kCategoryGroup) { + void CreateTraceEventDataSource() { producer_client_ = std::make_unique<MockProducerClient>( - scoped_task_environment_.GetMainThreadTaskRunner(), - wanted_event_category); + scoped_task_environment_.GetMainThreadTaskRunner()); auto data_source_config = mojom::DataSourceConfig::New(); TraceEventDataSource::GetInstance()->StartTracing(producer_client(), @@ -292,7 +305,7 @@ TEST_F(TraceEventDataSourceTest, MetadataSourceBasicTypes) { } TEST_F(TraceEventDataSourceTest, TraceLogMetadataEvents) { - CreateTraceEventDataSource("__metadata"); + CreateTraceEventDataSource(); base::RunLoop wait_for_flush; TraceEventDataSource::GetInstance()->StopTracing( @@ -303,7 +316,8 @@ TEST_F(TraceEventDataSourceTest, TraceLogMetadataEvents) { for (size_t i = 0; i < producer_client()->GetFinalizedPacketCount(); ++i) { auto trace_events = producer_client()->GetChromeTraceEvents(i); for (auto& event : trace_events) { - if (event.name() == "process_uptime_seconds") { + if (producer_client()->GetStringTableEntry(event.name_index()) == + "process_uptime_seconds") { has_process_uptime_event = true; break; } @@ -322,8 +336,10 @@ TEST_F(TraceEventDataSourceTest, BasicTraceEvent) { EXPECT_EQ(trace_events.size(), 1); auto trace_event = trace_events[0]; - EXPECT_EQ("bar", trace_event.name()); - EXPECT_EQ(kCategoryGroup, trace_event.category_group_name()); + EXPECT_EQ("bar", + producer_client()->GetStringTableEntry(trace_event.name_index())); + EXPECT_EQ(kCategoryGroup, producer_client()->GetStringTableEntry( + trace_event.category_group_name_index())); EXPECT_EQ(TRACE_EVENT_PHASE_BEGIN, trace_event.phase()); } @@ -338,8 +354,10 @@ TEST_F(TraceEventDataSourceTest, TimestampedTraceEvent) { EXPECT_EQ(trace_events.size(), 1); auto trace_event = trace_events[0]; - EXPECT_EQ("bar", trace_event.name()); - EXPECT_EQ(kCategoryGroup, trace_event.category_group_name()); + EXPECT_EQ("bar", + producer_client()->GetStringTableEntry(trace_event.name_index())); + EXPECT_EQ(kCategoryGroup, producer_client()->GetStringTableEntry( + trace_event.category_group_name_index())); EXPECT_EQ(42u, trace_event.id()); EXPECT_EQ(4242, trace_event.thread_id()); EXPECT_EQ(424242, trace_event.timestamp()); @@ -355,8 +373,10 @@ TEST_F(TraceEventDataSourceTest, InstantTraceEvent) { EXPECT_EQ(trace_events.size(), 1); auto trace_event = trace_events[0]; - EXPECT_EQ("bar", trace_event.name()); - EXPECT_EQ(kCategoryGroup, trace_event.category_group_name()); + EXPECT_EQ("bar", + producer_client()->GetStringTableEntry(trace_event.name_index())); + EXPECT_EQ(kCategoryGroup, producer_client()->GetStringTableEntry( + trace_event.category_group_name_index())); EXPECT_EQ(TRACE_EVENT_SCOPE_THREAD, trace_event.flags()); EXPECT_EQ(TRACE_EVENT_PHASE_INSTANT, trace_event.phase()); } @@ -373,12 +393,34 @@ TEST_F(TraceEventDataSourceTest, EventWithStringArgs) { auto trace_args = trace_events[0].args(); EXPECT_EQ(trace_args.size(), 2); - EXPECT_EQ("arg1_name", trace_args[0].name()); + EXPECT_EQ("arg1_name", + producer_client()->GetStringTableEntry(trace_args[0].name_index())); EXPECT_EQ("arg1_val", trace_args[0].string_value()); - EXPECT_EQ("arg2_name", trace_args[1].name()); + EXPECT_EQ("arg2_name", + producer_client()->GetStringTableEntry(trace_args[1].name_index())); EXPECT_EQ("arg2_val", trace_args[1].string_value()); } +TEST_F(TraceEventDataSourceTest, NoStringTableTest) { + CreateTraceEventDataSource(); + + TRACE_EVENT_INSTANT2(kCategoryGroup, "bar", + TRACE_EVENT_SCOPE_THREAD | TRACE_EVENT_FLAG_COPY, + "arg1_name", "arg1_val", "arg2_name", "arg2_val"); + + auto trace_events = producer_client()->GetChromeTraceEvents(); + EXPECT_EQ(trace_events.size(), 1); + + EXPECT_EQ("bar", trace_events[0].name()); + EXPECT_EQ(kCategoryGroup, trace_events[0].category_group_name()); + + auto trace_args = trace_events[0].args(); + EXPECT_EQ(trace_args.size(), 2); + + EXPECT_EQ("arg1_name", trace_args[0].name()); + EXPECT_EQ("arg2_name", trace_args[1].name()); +} + TEST_F(TraceEventDataSourceTest, EventWithUIntArgs) { CreateTraceEventDataSource(); @@ -532,19 +574,14 @@ TEST_F(TraceEventDataSourceTest, CompleteTraceEventsIntoSeparateBeginAndEnd) { // TRACE_EVENT_PHASE_COMPLETE events should internally emit a // TRACE_EVENT_PHASE_BEGIN event first, and then a TRACE_EVENT_PHASE_END event // when the duration is attempted set on the first event. - EXPECT_EQ(2u, producer_client()->GetFinalizedPacketCount()); - - auto events_from_first_packet = producer_client()->GetChromeTraceEvents(0); - EXPECT_EQ(events_from_first_packet.size(), 1); + auto events = producer_client()->GetChromeTraceEvents(0); + EXPECT_EQ(events.size(), 2); - auto begin_trace_event = events_from_first_packet[0]; + auto begin_trace_event = events[0]; EXPECT_EQ(TRACE_EVENT_PHASE_BEGIN, begin_trace_event.phase()); EXPECT_EQ(10, begin_trace_event.timestamp()); - auto events_from_second_packet = producer_client()->GetChromeTraceEvents(1); - EXPECT_EQ(events_from_second_packet.size(), 1); - - auto end_trace_event = events_from_second_packet[0]; + auto end_trace_event = events[1]; EXPECT_EQ(TRACE_EVENT_PHASE_END, end_trace_event.phase()); EXPECT_EQ(20, end_trace_event.timestamp()); EXPECT_EQ(50, end_trace_event.thread_timestamp()); diff --git a/chromium/services/tracing/public/mojom/perfetto_service.mojom b/chromium/services/tracing/public/mojom/perfetto_service.mojom index 056a94e39a0..e76198e7f8e 100644 --- a/chromium/services/tracing/public/mojom/perfetto_service.mojom +++ b/chromium/services/tracing/public/mojom/perfetto_service.mojom @@ -7,14 +7,13 @@ module tracing.mojom; const string kTraceEventDataSourceName = "org.chromium.trace_event"; const string kMetaDataSourceName = "org.chromium.trace_metadata"; -// Brief description of the flow: There's a per-process ProducerClient -// which connects to the central PerfettoService and establishes a two-way -// connection with a ProducerHost. The latter will then pass a -// SharedMemoryBuffer to the ProducerClient and tell it to start logging -// events of a given type into it. As chunks of the buffer gets filled -// up, the client will communicate this to the ProducerHost which will -// tell Perfetto to copy the finished chunks into its central storage -// and pass to any consumers. +// Brief description of the flow: There's a per-process ProducerClient which +// connects to the central PerfettoService and establishes a two-way connection +// with a ProducerHost. The latter will then pass a SharedMemoryBuffer to the +// ProducerClient and tell it to start logging events of a given type into it. +// As chunks of the buffer get filled up, the client will communicate this to +// the ProducerHost, which will tell Perfetto to copy the finished chunks into +// its central storage and pass them on to its consumers. // For a more complete explanation of a Perfetto tracing session, see // https://android.googlesource.com/platform/external/perfetto/+/master/docs/life-of-a-tracing-session.md @@ -22,22 +21,22 @@ const string kMetaDataSourceName = "org.chromium.trace_metadata"; // See https://android.googlesource.com/platform/external/perfetto/ // for the full documentation of Perfetto concepts and its shared memory ABI. -// Used by the CommitDataRequest() method (client process->service) to signal when -// a chunk is (segment/page of the shared memory buffer which is -// owned by a per-thread TraceWriter) the central Perfetto service that it's -// ready for consumption (flushed or fully written). +// Passed by the client process as part of CommitDataRequest() to the central +// Perfetto service. Signals that a chunk (segment/page of the shared memory +// buffer which is owned by a per-thread TraceWriter) is ready for consumption +// (flushed or fully written). struct ChunksToMove { // The page index within the producer:service shared memory buffer. uint32 page; // The chunk index within the given page. uint32 chunk; - // The target ring-buffer in the service where the chunk should be copied into. + // The target ring-buffer in the service that the chunk should be copied into. uint32 target_buffer; }; -// Used by the CommitDataRequest method (client process -> service) to -// patch previously written chunks (to fill in size fields when protos -// span multiple chunks, for example). +// Passed by the client process as part of CommitDataRequest() to the central +// Perfetto service. Used to patch previously written chunks (for example, to +// fill in size fields when protos span multiple chunks). struct ChunkPatch { // Offset relative to the chunk defined in ChunksToPatch. uint32 offset; @@ -45,8 +44,9 @@ struct ChunkPatch { }; struct ChunksToPatch { - // The triplet {target_buffer, writer_id, chunk_id} uniquely identified a chunk that has - // been copied over into the main, non-shared, trace buffer owned by the service. + // The triplet {target_buffer, writer_id, chunk_id} uniquely identifies a + // chunk that has been copied over into the main, non-shared, trace buffer + // owned by the service. uint32 target_buffer; uint32 writer_id; uint32 chunk_id; @@ -73,7 +73,7 @@ struct DataSourceRegistration { }; interface ProducerHost { - // Called by a ProducerClient to asks the service to: + // Called by a ProducerClient to ask the service to: // 1) Move data from the shared memory buffer into the final tracing buffer // owned by the service (through the |chunks_to_move|). // 2) Patch data (i.e. apply diff) that has been previously copied into the @@ -82,8 +82,8 @@ interface ProducerHost { // requests. CommitData(CommitDataRequest data_request); - // Called by a ProducerClient to let the Host know it can provide a - // specific datasource. + // Called by a ProducerClient to let the Host know it can provide a specific + // datasource. RegisterDataSource(DataSourceRegistration registration_info); // Called to let the Service know that a flush is complete. @@ -96,13 +96,14 @@ interface ProducerClient { // TODO(oysteine): Make a TypeTrait for sending the full DataSourceConfig. // Called by Perfetto (via ProducerHost) to request a data source to start // logging. - CreateDataSourceInstance(uint64 id, DataSourceConfig data_source_config); + StartDataSource(uint64 id, DataSourceConfig data_source_config); // Requesting a data source to stop logging again, with the id previously - // sent in the CreateDataSourceInstance call. - TearDownDataSourceInstance(uint64 id) => (); + // sent in the StartDataSource call. + StopDataSource(uint64 id) => (); Flush(uint64 flush_request_id, array<uint64> data_source_ids); }; interface PerfettoService { - ConnectToProducerHost(ProducerClient producer_client, ProducerHost& producer_host); + ConnectToProducerHost(ProducerClient producer_client, + ProducerHost& producer_host); }; diff --git a/chromium/services/video_capture/BUILD.gn b/chromium/services/video_capture/BUILD.gn index 84c11a86fd8..836a1a685a0 100644 --- a/chromium/services/video_capture/BUILD.gn +++ b/chromium/services/video_capture/BUILD.gn @@ -49,6 +49,8 @@ source_set("lib") { "virtual_device_enabled_device_factory.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + public_deps = [ "//base", "//media", @@ -80,17 +82,11 @@ source_set("tests") { "test/fake_device_test.cc", "test/fake_device_test.h", "test/fake_device_unittest.cc", - "test/mock_device.cc", - "test/mock_device.h", - "test/mock_device_factory.cc", - "test/mock_device_factory.h", "test/mock_device_test.cc", "test/mock_device_test.h", "test/mock_device_unittest.cc", - "test/mock_producer.cc", - "test/mock_producer.h", - "test/mock_receiver.cc", - "test/mock_receiver.h", + "test/mock_devices_changed_observer.cc", + "test/mock_devices_changed_observer.h", "test/virtual_device_unittest.cc", "texture_virtual_device_mojo_adapter_unittest.cc", ] @@ -99,10 +95,12 @@ source_set("tests") { ":lib", ":video_capture", "//base/test:test_support", + "//media/capture:test_support", "//media/capture/mojom:video_capture", "//services/service_manager/public/cpp", "//services/service_manager/public/cpp:service_test_support", "//services/service_manager/public/cpp/test:test_support", + "//services/video_capture/public/cpp:mocks", "//testing/gmock", "//testing/gtest", "//ui/gfx:test_support", diff --git a/chromium/services/video_capture/device_factory_media_to_mojo_adapter.cc b/chromium/services/video_capture/device_factory_media_to_mojo_adapter.cc index 1cd5fc8e834..63bca2db8cd 100644 --- a/chromium/services/video_capture/device_factory_media_to_mojo_adapter.cc +++ b/chromium/services/video_capture/device_factory_media_to_mojo_adapter.cc @@ -77,12 +77,10 @@ DeviceFactoryMediaToMojoAdapter::ActiveDeviceEntry::operator=( DeviceFactoryMediaToMojoAdapter::ActiveDeviceEntry&& other) = default; DeviceFactoryMediaToMojoAdapter::DeviceFactoryMediaToMojoAdapter( - std::unique_ptr<service_manager::ServiceContextRef> service_ref, std::unique_ptr<media::VideoCaptureSystem> capture_system, media::MojoJpegDecodeAcceleratorFactoryCB jpeg_decoder_factory_callback, scoped_refptr<base::SequencedTaskRunner> jpeg_decoder_task_runner) - : service_ref_(std::move(service_ref)), - capture_system_(std::move(capture_system)), + : capture_system_(std::move(capture_system)), jpeg_decoder_factory_callback_(std::move(jpeg_decoder_factory_callback)), jpeg_decoder_task_runner_(std::move(jpeg_decoder_task_runner)), has_called_get_device_infos_(false), @@ -90,6 +88,11 @@ DeviceFactoryMediaToMojoAdapter::DeviceFactoryMediaToMojoAdapter( DeviceFactoryMediaToMojoAdapter::~DeviceFactoryMediaToMojoAdapter() = default; +void DeviceFactoryMediaToMojoAdapter::SetServiceRef( + std::unique_ptr<service_manager::ServiceContextRef> service_ref) { + service_ref_ = std::move(service_ref); +} + void DeviceFactoryMediaToMojoAdapter::GetDeviceInfos( GetDeviceInfosCallback callback) { capture_system_->GetDeviceInfosAsync( @@ -129,6 +132,7 @@ void DeviceFactoryMediaToMojoAdapter::CreateDevice( capture_system_->GetDeviceInfosAsync( base::Bind(&DiscardDeviceInfosAndCallContinuation, base::Passed(&create_and_add_new_device_cb))); + has_called_get_device_infos_ = true; } void DeviceFactoryMediaToMojoAdapter::AddSharedMemoryVirtualDevice( @@ -145,10 +149,16 @@ void DeviceFactoryMediaToMojoAdapter::AddTextureVirtualDevice( NOTIMPLEMENTED(); } +void DeviceFactoryMediaToMojoAdapter::RegisterVirtualDevicesChangedObserver( + mojom::DevicesChangedObserverPtr observer) { + NOTIMPLEMENTED(); +} + void DeviceFactoryMediaToMojoAdapter::CreateAndAddNewDevice( const std::string& device_id, mojom::DeviceRequest device_request, CreateDeviceCallback callback) { + DCHECK(service_ref_); std::unique_ptr<media::VideoCaptureDevice> media_device = capture_system_->CreateDevice(device_id); if (media_device == nullptr) { diff --git a/chromium/services/video_capture/device_factory_media_to_mojo_adapter.h b/chromium/services/video_capture/device_factory_media_to_mojo_adapter.h index 65cde0d34e3..1f09720e916 100644 --- a/chromium/services/video_capture/device_factory_media_to_mojo_adapter.h +++ b/chromium/services/video_capture/device_factory_media_to_mojo_adapter.h @@ -24,12 +24,14 @@ class DeviceMediaToMojoAdapter; class DeviceFactoryMediaToMojoAdapter : public mojom::DeviceFactory { public: DeviceFactoryMediaToMojoAdapter( - std::unique_ptr<service_manager::ServiceContextRef> service_ref, std::unique_ptr<media::VideoCaptureSystem> capture_system, media::MojoJpegDecodeAcceleratorFactoryCB jpeg_decoder_factory_callback, scoped_refptr<base::SequencedTaskRunner> jpeg_decoder_task_runner); ~DeviceFactoryMediaToMojoAdapter() override; + void SetServiceRef( + std::unique_ptr<service_manager::ServiceContextRef> service_ref); + // mojom::DeviceFactory implementation. void GetDeviceInfos(GetDeviceInfosCallback callback) override; void CreateDevice(const std::string& device_id, @@ -43,6 +45,8 @@ class DeviceFactoryMediaToMojoAdapter : public mojom::DeviceFactory { void AddTextureVirtualDevice( const media::VideoCaptureDeviceInfo& device_info, mojom::TextureVirtualDeviceRequest virtual_device) override; + void RegisterVirtualDevicesChangedObserver( + mojom::DevicesChangedObserverPtr observer) override; private: struct ActiveDeviceEntry { @@ -63,7 +67,7 @@ class DeviceFactoryMediaToMojoAdapter : public mojom::DeviceFactory { CreateDeviceCallback callback); void OnClientConnectionErrorOrClose(const std::string& device_id); - const std::unique_ptr<service_manager::ServiceContextRef> service_ref_; + std::unique_ptr<service_manager::ServiceContextRef> service_ref_; const std::unique_ptr<media::VideoCaptureSystem> capture_system_; const media::MojoJpegDecodeAcceleratorFactoryCB jpeg_decoder_factory_callback_; diff --git a/chromium/services/video_capture/device_factory_provider_impl.cc b/chromium/services/video_capture/device_factory_provider_impl.cc index b62f196e75b..1e0176b30a0 100644 --- a/chromium/services/video_capture/device_factory_provider_impl.cc +++ b/chromium/services/video_capture/device_factory_provider_impl.cc @@ -67,9 +67,13 @@ class DeviceFactoryProviderImpl::GpuDependenciesContext { base::WeakPtrFactory<GpuDependenciesContext> weak_factory_for_gpu_io_thread_; }; -DeviceFactoryProviderImpl::DeviceFactoryProviderImpl( - std::unique_ptr<service_manager::ServiceContextRef> service_ref) - : service_ref_(std::move(service_ref)) {} +DeviceFactoryProviderImpl::DeviceFactoryProviderImpl() { + // Unretained |this| is safe because |factory_bindings_| is owned by + // |this|. + factory_bindings_.set_connection_error_handler(base::BindRepeating( + &DeviceFactoryProviderImpl::OnFactoryClientDisconnected, + base::Unretained(this))); +} DeviceFactoryProviderImpl::~DeviceFactoryProviderImpl() { factory_bindings_.CloseAllBindings(); @@ -80,6 +84,11 @@ DeviceFactoryProviderImpl::~DeviceFactoryProviderImpl() { } } +void DeviceFactoryProviderImpl::SetServiceRef( + std::unique_ptr<service_manager::ServiceContextRef> service_ref) { + service_ref_ = std::move(service_ref); +} + void DeviceFactoryProviderImpl::InjectGpuDependencies( mojom::AcceleratorFactoryPtr accelerator_factory) { LazyInitializeGpuDependenciesContext(); @@ -91,7 +100,10 @@ void DeviceFactoryProviderImpl::InjectGpuDependencies( void DeviceFactoryProviderImpl::ConnectToDeviceFactory( mojom::DeviceFactoryRequest request) { + DCHECK(service_ref_); LazyInitializeDeviceFactory(); + if (factory_bindings_.empty()) + device_factory_->SetServiceRef(service_ref_->Clone()); factory_bindings_.AddBinding(device_factory_.get(), std::move(request)); } @@ -119,13 +131,22 @@ void DeviceFactoryProviderImpl::LazyInitializeDeviceFactory() { std::move(media_device_factory)); device_factory_ = std::make_unique<VirtualDeviceEnabledDeviceFactory>( - service_ref_->Clone(), std::make_unique<DeviceFactoryMediaToMojoAdapter>( - service_ref_->Clone(), std::move(video_capture_system), + std::move(video_capture_system), base::BindRepeating( &GpuDependenciesContext::CreateJpegDecodeAccelerator, gpu_dependencies_context_->GetWeakPtr()), gpu_dependencies_context_->GetTaskRunner())); } +void DeviceFactoryProviderImpl::OnFactoryClientDisconnected() { + // If last client has disconnected, release service ref so that service + // shutdown timeout starts if no other references are still alive. + // We keep the |device_factory_| instance alive in order to avoid + // losing state that would be expensive to reinitialize, e.g. having + // already enumerated the available devices. + if (factory_bindings_.empty()) + device_factory_->SetServiceRef(nullptr); +} + } // namespace video_capture diff --git a/chromium/services/video_capture/device_factory_provider_impl.h b/chromium/services/video_capture/device_factory_provider_impl.h index ee9119c64bc..b0609b4cf22 100644 --- a/chromium/services/video_capture/device_factory_provider_impl.h +++ b/chromium/services/video_capture/device_factory_provider_impl.h @@ -17,12 +17,16 @@ namespace video_capture { +class VirtualDeviceEnabledDeviceFactory; + class DeviceFactoryProviderImpl : public mojom::DeviceFactoryProvider { public: - DeviceFactoryProviderImpl( - std::unique_ptr<service_manager::ServiceContextRef> service_ref); + DeviceFactoryProviderImpl(); ~DeviceFactoryProviderImpl() override; + void SetServiceRef( + std::unique_ptr<service_manager::ServiceContextRef> service_ref); + // mojom::DeviceFactoryProvider implementation. void InjectGpuDependencies( mojom::AcceleratorFactoryPtr accelerator_factory) override; @@ -33,10 +37,11 @@ class DeviceFactoryProviderImpl : public mojom::DeviceFactoryProvider { void LazyInitializeGpuDependenciesContext(); void LazyInitializeDeviceFactory(); + void OnFactoryClientDisconnected(); mojo::BindingSet<mojom::DeviceFactory> factory_bindings_; - std::unique_ptr<mojom::DeviceFactory> device_factory_; - const std::unique_ptr<service_manager::ServiceContextRef> service_ref_; + std::unique_ptr<VirtualDeviceEnabledDeviceFactory> device_factory_; + std::unique_ptr<service_manager::ServiceContextRef> service_ref_; std::unique_ptr<GpuDependenciesContext> gpu_dependencies_context_; DISALLOW_COPY_AND_ASSIGN(DeviceFactoryProviderImpl); diff --git a/chromium/services/video_capture/device_media_to_mojo_adapter_unittest.cc b/chromium/services/video_capture/device_media_to_mojo_adapter_unittest.cc index 939fadbc36a..879667d6b87 100644 --- a/chromium/services/video_capture/device_media_to_mojo_adapter_unittest.cc +++ b/chromium/services/video_capture/device_media_to_mojo_adapter_unittest.cc @@ -6,8 +6,8 @@ #include "base/run_loop.h" #include "base/test/scoped_task_environment.h" -#include "services/video_capture/test/mock_device.h" -#include "services/video_capture/test/mock_receiver.h" +#include "media/capture/video/mock_device.h" +#include "services/video_capture/public/cpp/mock_receiver.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" @@ -24,7 +24,7 @@ class DeviceMediaToMojoAdapterTest : public ::testing::Test { void SetUp() override { mock_receiver_ = std::make_unique<MockReceiver>(mojo::MakeRequest(&receiver_)); - auto mock_device = std::make_unique<MockDevice>(); + auto mock_device = std::make_unique<media::MockDevice>(); mock_device_ptr_ = mock_device.get(); adapter_ = std::make_unique<DeviceMediaToMojoAdapter>( std::unique_ptr<service_manager::ServiceContextRef>(), @@ -40,7 +40,7 @@ class DeviceMediaToMojoAdapterTest : public ::testing::Test { } protected: - MockDevice* mock_device_ptr_; + media::MockDevice* mock_device_ptr_; std::unique_ptr<DeviceMediaToMojoAdapter> adapter_; std::unique_ptr<MockReceiver> mock_receiver_; mojom::ReceiverPtr receiver_; diff --git a/chromium/services/video_capture/public/cpp/BUILD.gn b/chromium/services/video_capture/public/cpp/BUILD.gn index 7871eb604ba..d590a1e4b22 100644 --- a/chromium/services/video_capture/public/cpp/BUILD.gn +++ b/chromium/services/video_capture/public/cpp/BUILD.gn @@ -10,6 +10,8 @@ source_set("cpp") { "receiver_media_to_mojo_adapter.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + public_deps = [ "//base", "//media", @@ -22,3 +24,23 @@ source_set("cpp") { "//mojo/public/cpp/bindings:bindings", ] } + +source_set("mocks") { + testonly = true + + sources = [ + "mock_device_factory.cc", + "mock_device_factory.h", + "mock_device_factory_provider.cc", + "mock_device_factory_provider.h", + "mock_producer.cc", + "mock_producer.h", + "mock_receiver.cc", + "mock_receiver.h", + ] + + public_deps = [ + "//services/video_capture/public/mojom", + "//testing/gmock", + ] +} diff --git a/chromium/services/video_capture/public/cpp/mock_device_factory.cc b/chromium/services/video_capture/public/cpp/mock_device_factory.cc new file mode 100644 index 00000000000..37ef651b7b4 --- /dev/null +++ b/chromium/services/video_capture/public/cpp/mock_device_factory.cc @@ -0,0 +1,38 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/video_capture/public/cpp/mock_device_factory.h" + +namespace video_capture { + +MockDeviceFactory::MockDeviceFactory() = default; + +MockDeviceFactory::~MockDeviceFactory() = default; + +void MockDeviceFactory::GetDeviceInfos(GetDeviceInfosCallback callback) { + DoGetDeviceInfos(callback); +} + +void MockDeviceFactory::CreateDevice( + const std::string& device_id, + video_capture::mojom::DeviceRequest device_request, + CreateDeviceCallback callback) { + DoCreateDevice(device_id, &device_request, callback); +} + +void MockDeviceFactory::AddSharedMemoryVirtualDevice( + const media::VideoCaptureDeviceInfo& device_info, + video_capture::mojom::ProducerPtr producer, + bool send_buffer_handles_to_producer_as_raw_file_descriptors, + video_capture::mojom::SharedMemoryVirtualDeviceRequest virtual_device) { + DoAddVirtualDevice(device_info, producer.get(), &virtual_device); +} + +void MockDeviceFactory::AddTextureVirtualDevice( + const media::VideoCaptureDeviceInfo& device_info, + video_capture::mojom::TextureVirtualDeviceRequest virtual_device) { + DoAddTextureVirtualDevice(device_info, &virtual_device); +} + +} // namespace video_capture diff --git a/chromium/services/video_capture/public/cpp/mock_device_factory.h b/chromium/services/video_capture/public/cpp/mock_device_factory.h new file mode 100644 index 00000000000..dc8a79d910c --- /dev/null +++ b/chromium/services/video_capture/public/cpp/mock_device_factory.h @@ -0,0 +1,54 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_VIDEO_CAPTURE_PUBLIC_CPP_MOCK_DEVICE_FACTORY_H_ +#define SERVICES_VIDEO_CAPTURE_PUBLIC_CPP_MOCK_DEVICE_FACTORY_H_ + +#include "services/video_capture/public/mojom/device_factory.mojom.h" +#include "testing/gmock/include/gmock/gmock.h" + +namespace video_capture { + +class MockDeviceFactory : public video_capture::mojom::DeviceFactory { + public: + MockDeviceFactory(); + ~MockDeviceFactory() override; + + void GetDeviceInfos(GetDeviceInfosCallback callback) override; + void CreateDevice(const std::string& device_id, + video_capture::mojom::DeviceRequest device_request, + CreateDeviceCallback callback) override; + void AddSharedMemoryVirtualDevice( + const media::VideoCaptureDeviceInfo& device_info, + video_capture::mojom::ProducerPtr producer, + bool send_buffer_handles_to_producer_as_raw_file_descriptors, + video_capture::mojom::SharedMemoryVirtualDeviceRequest virtual_device) + override; + void AddTextureVirtualDevice(const media::VideoCaptureDeviceInfo& device_info, + video_capture::mojom::TextureVirtualDeviceRequest + virtual_device) override; + void RegisterVirtualDevicesChangedObserver( + video_capture::mojom::DevicesChangedObserverPtr observer) override { + NOTIMPLEMENTED(); + } + + MOCK_METHOD1(DoGetDeviceInfos, void(GetDeviceInfosCallback& callback)); + MOCK_METHOD3(DoCreateDevice, + void(const std::string& device_id, + video_capture::mojom::DeviceRequest* device_request, + CreateDeviceCallback& callback)); + MOCK_METHOD3(DoAddVirtualDevice, + void(const media::VideoCaptureDeviceInfo& device_info, + video_capture::mojom::ProducerProxy* producer, + video_capture::mojom::SharedMemoryVirtualDeviceRequest* + virtual_device_request)); + MOCK_METHOD2( + DoAddTextureVirtualDevice, + void(const media::VideoCaptureDeviceInfo& device_info, + video_capture::mojom::TextureVirtualDeviceRequest* virtual_device)); +}; + +} // namespace video_capture + +#endif // SERVICES_VIDEO_CAPTURE_PUBLIC_CPP_MOCK_DEVICE_FACTORY_H_ diff --git a/chromium/services/video_capture/public/cpp/mock_device_factory_provider.cc b/chromium/services/video_capture/public/cpp/mock_device_factory_provider.cc new file mode 100644 index 00000000000..187124c980c --- /dev/null +++ b/chromium/services/video_capture/public/cpp/mock_device_factory_provider.cc @@ -0,0 +1,23 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/video_capture/public/cpp/mock_device_factory_provider.h" + +namespace video_capture { + +MockDeviceFactoryProvider::MockDeviceFactoryProvider() {} + +MockDeviceFactoryProvider::~MockDeviceFactoryProvider() = default; + +void MockDeviceFactoryProvider::ConnectToDeviceFactory( + video_capture::mojom::DeviceFactoryRequest request) { + DoConnectToDeviceFactory(request); +} + +void MockDeviceFactoryProvider::InjectGpuDependencies( + video_capture::mojom::AcceleratorFactoryPtr accelerator_factory) { + DoInjectGpuDependencies(accelerator_factory); +} + +} // namespace video_capture diff --git a/chromium/services/video_capture/public/cpp/mock_device_factory_provider.h b/chromium/services/video_capture/public/cpp/mock_device_factory_provider.h new file mode 100644 index 00000000000..f812e5b2507 --- /dev/null +++ b/chromium/services/video_capture/public/cpp/mock_device_factory_provider.h @@ -0,0 +1,35 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_VIDEO_CAPTURE_PUBLIC_CPP_MOCK_DEVICE_FACTORY_PROVIDER_H_ +#define SERVICES_VIDEO_CAPTURE_PUBLIC_CPP_MOCK_DEVICE_FACTORY_PROVIDER_H_ + +#include "services/video_capture/public/mojom/device_factory_provider.mojom.h" +#include "testing/gmock/include/gmock/gmock.h" + +namespace video_capture { + +class MockDeviceFactoryProvider + : public video_capture::mojom::DeviceFactoryProvider { + public: + MockDeviceFactoryProvider(); + ~MockDeviceFactoryProvider() override; + + void ConnectToDeviceFactory( + video_capture::mojom::DeviceFactoryRequest request) override; + + void InjectGpuDependencies( + video_capture::mojom::AcceleratorFactoryPtr accelerator_factory) override; + + MOCK_METHOD1( + DoInjectGpuDependencies, + void(video_capture::mojom::AcceleratorFactoryPtr& accelerator_factory)); + MOCK_METHOD1(SetShutdownDelayInSeconds, void(float seconds)); + MOCK_METHOD1(DoConnectToDeviceFactory, + void(video_capture::mojom::DeviceFactoryRequest& request)); +}; + +} // namespace video_capture + +#endif // SERVICES_VIDEO_CAPTURE_PUBLIC_CPP_MOCK_DEVICE_FACTORY_PROVIDER_H_ diff --git a/chromium/services/video_capture/public/cpp/mock_producer.cc b/chromium/services/video_capture/public/cpp/mock_producer.cc new file mode 100644 index 00000000000..8d47eb47bf5 --- /dev/null +++ b/chromium/services/video_capture/public/cpp/mock_producer.cc @@ -0,0 +1,20 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/video_capture/public/cpp/mock_producer.h" + +namespace video_capture { + +MockProducer::MockProducer(mojom::ProducerRequest request) + : binding_(this, std::move(request)) {} + +MockProducer::~MockProducer() = default; + +void MockProducer::OnNewBuffer(int32_t buffer_id, + media::mojom::VideoBufferHandlePtr buffer_handle, + OnNewBufferCallback callback) { + DoOnNewBuffer(buffer_id, &buffer_handle, callback); +} + +} // namespace video_capture diff --git a/chromium/services/video_capture/public/cpp/mock_producer.h b/chromium/services/video_capture/public/cpp/mock_producer.h new file mode 100644 index 00000000000..084168cbfed --- /dev/null +++ b/chromium/services/video_capture/public/cpp/mock_producer.h @@ -0,0 +1,37 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_VIDEO_CAPTURE_PUBLIC_CPP_MOCK_PRODUCER_H_ +#define SERVICES_VIDEO_CAPTURE_PUBLIC_CPP_MOCK_PRODUCER_H_ + +#include "media/mojo/interfaces/media_types.mojom.h" +#include "mojo/public/cpp/bindings/binding.h" +#include "services/video_capture/public/mojom/producer.mojom.h" +#include "testing/gmock/include/gmock/gmock.h" + +namespace video_capture { + +class MockProducer : public mojom::Producer { + public: + MockProducer(mojom::ProducerRequest request); + ~MockProducer() override; + + // Use forwarding method to work around gmock not supporting move-only types. + void OnNewBuffer(int32_t buffer_id, + media::mojom::VideoBufferHandlePtr buffer_handle, + OnNewBufferCallback callback) override; + + MOCK_METHOD3(DoOnNewBuffer, + void(int32_t, + media::mojom::VideoBufferHandlePtr*, + OnNewBufferCallback& callback)); + MOCK_METHOD1(OnBufferRetired, void(int32_t)); + + private: + const mojo::Binding<mojom::Producer> binding_; +}; + +} // namespace video_capture + +#endif // SERVICES_VIDEO_CAPTURE_PUBLIC_CPP_MOCK_PRODUCER_H_ diff --git a/chromium/services/video_capture/public/cpp/mock_receiver.cc b/chromium/services/video_capture/public/cpp/mock_receiver.cc new file mode 100644 index 00000000000..54551a134cf --- /dev/null +++ b/chromium/services/video_capture/public/cpp/mock_receiver.cc @@ -0,0 +1,31 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/video_capture/public/cpp/mock_receiver.h" + +namespace video_capture { + +MockReceiver::MockReceiver() : binding_(this) {} + +MockReceiver::MockReceiver(mojom::ReceiverRequest request) + : binding_(this, std::move(request)) {} + +MockReceiver::~MockReceiver() = default; + +void MockReceiver::OnNewBuffer( + int32_t buffer_id, + media::mojom::VideoBufferHandlePtr buffer_handle) { + DoOnNewBuffer(buffer_id, &buffer_handle); +} + +void MockReceiver::OnFrameReadyInBuffer( + int32_t buffer_id, + int32_t frame_feedback_id, + mojom::ScopedAccessPermissionPtr access_permission, + media::mojom::VideoFrameInfoPtr frame_info) { + DoOnFrameReadyInBuffer(buffer_id, frame_feedback_id, &access_permission, + &frame_info); +} + +} // namespace video_capture diff --git a/chromium/services/video_capture/public/cpp/mock_receiver.h b/chromium/services/video_capture/public/cpp/mock_receiver.h new file mode 100644 index 00000000000..a5a9c306590 --- /dev/null +++ b/chromium/services/video_capture/public/cpp/mock_receiver.h @@ -0,0 +1,50 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_VIDEO_CAPTURE_PUBLIC_CPP_MOCK_RECEIVER_H_ +#define SERVICES_VIDEO_CAPTURE_PUBLIC_CPP_MOCK_RECEIVER_H_ + +#include "media/mojo/interfaces/media_types.mojom.h" +#include "mojo/public/cpp/bindings/binding.h" +#include "services/video_capture/public/mojom/receiver.mojom.h" +#include "testing/gmock/include/gmock/gmock.h" + +namespace video_capture { + +class MockReceiver : public mojom::Receiver { + public: + MockReceiver(); + explicit MockReceiver(mojom::ReceiverRequest request); + ~MockReceiver() override; + + // Use forwarding method to work around gmock not supporting move-only types. + void OnNewBuffer(int32_t buffer_id, + media::mojom::VideoBufferHandlePtr buffer_handle) override; + void OnFrameReadyInBuffer( + int32_t buffer_id, + int32_t frame_feedback_id, + mojom::ScopedAccessPermissionPtr access_permission, + media::mojom::VideoFrameInfoPtr frame_info) override; + + MOCK_METHOD2(DoOnNewBuffer, + void(int32_t, media::mojom::VideoBufferHandlePtr*)); + MOCK_METHOD4(DoOnFrameReadyInBuffer, + void(int32_t buffer_id, + int32_t frame_feedback_id, + mojom::ScopedAccessPermissionPtr*, + media::mojom::VideoFrameInfoPtr*)); + MOCK_METHOD1(OnBufferRetired, void(int32_t)); + MOCK_METHOD1(OnError, void(media::VideoCaptureError)); + MOCK_METHOD1(OnFrameDropped, void(media::VideoCaptureFrameDropReason)); + MOCK_METHOD1(OnLog, void(const std::string&)); + MOCK_METHOD0(OnStarted, void()); + MOCK_METHOD0(OnStartedUsingGpuDecode, void()); + + private: + const mojo::Binding<mojom::Receiver> binding_; +}; + +} // namespace video_capture + +#endif // SERVICES_VIDEO_CAPTURE_PUBLIC_CPP_MOCK_RECEIVER_H_ diff --git a/chromium/services/video_capture/public/mojom/device_factory.mojom b/chromium/services/video_capture/public/mojom/device_factory.mojom index e79248befb0..558c8dd599e 100644 --- a/chromium/services/video_capture/public/mojom/device_factory.mojom +++ b/chromium/services/video_capture/public/mojom/device_factory.mojom @@ -15,6 +15,10 @@ enum DeviceAccessResultCode { ERROR_DEVICE_NOT_FOUND }; +interface DevicesChangedObserver { + OnDevicesChanged(); +}; + // Enables access to a set of video capture devices. // Typical operation is to first call GetDeviceInfos() to obtain // information about available devices. The |device_id| of the infos can @@ -63,4 +67,10 @@ interface DeviceFactory { AddTextureVirtualDevice( media.mojom.VideoCaptureDeviceInfo device_info, TextureVirtualDevice& virtual_device); + + // Registered observers will get notified whenever a virtual device is added + // or removed. Note: Changes to non-virtual devices are currently being + // monitored outside the video capture service, and therefore the service + // does not offer such monitoring. + RegisterVirtualDevicesChangedObserver(DevicesChangedObserver observer); }; diff --git a/chromium/services/video_capture/public/mojom/virtual_device.mojom b/chromium/services/video_capture/public/mojom/virtual_device.mojom index 18559a884fb..71898688350 100644 --- a/chromium/services/video_capture/public/mojom/virtual_device.mojom +++ b/chromium/services/video_capture/public/mojom/virtual_device.mojom @@ -35,7 +35,8 @@ interface SharedMemoryVirtualDevice { // |Producer.OnNewBufferHandle| and/or |Producer.OnBufferRetired| // will be invoked. RequestFrameBuffer(gfx.mojom.Size dimension, - media.mojom.VideoCapturePixelFormat pixel_format) + media.mojom.VideoCapturePixelFormat pixel_format, + media.mojom.PlaneStrides? strides) => (int32 buffer_id); // Called to indicate that a video frame is ready in the given buffer @@ -54,6 +55,9 @@ interface TextureVirtualDevice { int32 buffer_id, media.mojom.MailboxBufferHandleSet mailbox_handles); // The invoker must guarantee that the textures with |buffer_id| stay valid // until |access_permission| is released by the invocation target. + // In |frame_info|, |visible_rect| must be equivalent to the full |coded_size| + // of the frame, i.e. using |visible_rect| to crop to subregions of the frame + // is not supported. OnFrameReadyInBuffer(int32 buffer_id, ScopedAccessPermission access_permission, media.mojom.VideoFrameInfo frame_info); diff --git a/chromium/services/video_capture/service_impl.cc b/chromium/services/video_capture/service_impl.cc index 52d58ca6722..0316e85c6fd 100644 --- a/chromium/services/video_capture/service_impl.cc +++ b/chromium/services/video_capture/service_impl.cc @@ -4,6 +4,7 @@ #include "services/video_capture/service_impl.h" +#include "build/build_config.h" #include "mojo/public/cpp/bindings/strong_binding.h" #include "services/service_manager/public/cpp/service_context.h" #include "services/video_capture/device_factory_provider_impl.h" @@ -13,9 +14,8 @@ namespace video_capture { -ServiceImpl::ServiceImpl(float shutdown_delay_in_seconds) - : shutdown_delay_in_seconds_(shutdown_delay_in_seconds), - weak_factory_(this) {} +ServiceImpl::ServiceImpl(base::Optional<base::TimeDelta> shutdown_delay) + : shutdown_delay_(shutdown_delay), weak_factory_(this) {} ServiceImpl::~ServiceImpl() { DCHECK(thread_checker_.CalledOnValidThread()); @@ -25,7 +25,14 @@ ServiceImpl::~ServiceImpl() { // static std::unique_ptr<service_manager::Service> ServiceImpl::Create() { - return std::make_unique<ServiceImpl>(); +#if defined(OS_ANDROID) + // On Android, we do not use automatic service shutdown, because when shutting + // down the service, we lose caching of the supported formats, and re-querying + // these can take several seconds on certain Android devices. + return std::make_unique<ServiceImpl>(base::Optional<base::TimeDelta>()); +#else + return std::make_unique<ServiceImpl>(base::TimeDelta::FromSeconds(5)); +#endif } void ServiceImpl::SetDestructionObserver(base::OnceClosure observer_cb) { @@ -66,8 +73,7 @@ void ServiceImpl::OnStart() { // SetServiceContextRefProviderForTesting(). if (!ref_factory_) { ref_factory_ = std::make_unique<service_manager::ServiceKeepalive>( - context(), base::TimeDelta::FromSecondsD(shutdown_delay_in_seconds_), - this); + context(), shutdown_delay_, this); } registry_.AddInterface<mojom::DeviceFactoryProvider>( @@ -79,6 +85,8 @@ void ServiceImpl::OnStart() { base::Bind(&ServiceImpl::OnTestingControlsRequest, base::Unretained(this))); + // Unretained |this| is safe because |factory_provider_bindings_| is owned by + // |this|. factory_provider_bindings_.set_connection_error_handler(base::BindRepeating( &ServiceImpl::OnProviderClientDisconnected, base::Unretained(this))); } @@ -112,6 +120,8 @@ void ServiceImpl::OnDeviceFactoryProviderRequest( mojom::DeviceFactoryProviderRequest request) { DCHECK(thread_checker_.CalledOnValidThread()); LazyInitializeDeviceFactoryProvider(); + if (factory_provider_bindings_.empty()) + device_factory_provider_->SetServiceRef(ref_factory_->CreateRef()); factory_provider_bindings_.AddBinding(device_factory_provider_.get(), std::move(request)); @@ -132,15 +142,17 @@ void ServiceImpl::LazyInitializeDeviceFactoryProvider() { if (device_factory_provider_) return; - device_factory_provider_ = - std::make_unique<DeviceFactoryProviderImpl>(ref_factory_->CreateRef()); + device_factory_provider_ = std::make_unique<DeviceFactoryProviderImpl>(); } void ServiceImpl::OnProviderClientDisconnected() { - // Reset factory provider if no client is connected. - if (factory_provider_bindings_.empty()) { - device_factory_provider_.reset(); - } + // If last client has disconnected, release service ref so that service + // shutdown timeout starts if no other references are still alive. + // We keep the |device_factory_provider_| instance alive in order to avoid + // losing state that would be expensive to reinitialize, e.g. having + // already enumerated the available devices. + if (factory_provider_bindings_.empty()) + device_factory_provider_->SetServiceRef(nullptr); if (!factory_provider_client_disconnected_cb_.is_null()) { factory_provider_client_disconnected_cb_.Run(); diff --git a/chromium/services/video_capture/service_impl.h b/chromium/services/video_capture/service_impl.h index 5fabb9ad396..e5b93d99d47 100644 --- a/chromium/services/video_capture/service_impl.h +++ b/chromium/services/video_capture/service_impl.h @@ -25,7 +25,9 @@ namespace video_capture { class ServiceImpl : public service_manager::Service, public service_manager::ServiceKeepalive::TimeoutObserver { public: - ServiceImpl(float shutdown_delay_in_seconds = 5.0f); + // If |shutdown_delay| is provided, the service will shut itself down as soon + // as no client was connect for the corresponding duration. + explicit ServiceImpl(base::Optional<base::TimeDelta> shutdown_delay); ~ServiceImpl() override; static std::unique_ptr<service_manager::Service> Create(); @@ -58,7 +60,7 @@ class ServiceImpl : public service_manager::Service, void LazyInitializeDeviceFactoryProvider(); void OnProviderClientDisconnected(); - const float shutdown_delay_in_seconds_; + const base::Optional<base::TimeDelta> shutdown_delay_; #if defined(OS_WIN) // COM must be initialized in order to access the video capture devices. base::win::ScopedCOMInitializer com_initializer_; diff --git a/chromium/services/video_capture/service_main.cc b/chromium/services/video_capture/service_main.cc index 52ae77d057c..520937296b3 100644 --- a/chromium/services/video_capture/service_main.cc +++ b/chromium/services/video_capture/service_main.cc @@ -7,6 +7,7 @@ #include "services/video_capture/service_impl.h" MojoResult ServiceMain(MojoHandle service_request_handle) { - return service_manager::ServiceRunner(new video_capture::ServiceImpl()) + return service_manager::ServiceRunner( + video_capture::ServiceImpl::Create().release()) .Run(service_request_handle); } diff --git a/chromium/services/video_capture/shared_memory_virtual_device_mojo_adapter.cc b/chromium/services/video_capture/shared_memory_virtual_device_mojo_adapter.cc index f14b8488ade..c85448c6d68 100644 --- a/chromium/services/video_capture/shared_memory_virtual_device_mojo_adapter.cc +++ b/chromium/services/video_capture/shared_memory_virtual_device_mojo_adapter.cc @@ -55,12 +55,15 @@ int SharedMemoryVirtualDeviceMojoAdapter::max_buffer_pool_buffer_count() { void SharedMemoryVirtualDeviceMojoAdapter::RequestFrameBuffer( const gfx::Size& dimension, media::VideoPixelFormat pixel_format, + media::mojom::PlaneStridesPtr strides, RequestFrameBufferCallback callback) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); int buffer_id_to_drop = media::VideoCaptureBufferPool::kInvalidId; - const int buffer_id = buffer_pool_->ReserveForProducer( - dimension, pixel_format, 0 /* frame_feedback_id */, &buffer_id_to_drop); + int buffer_id = media::VideoCaptureBufferPool::kInvalidId; + const auto reserve_result = buffer_pool_->ReserveForProducer( + dimension, pixel_format, strides, 0 /* frame_feedback_id */, &buffer_id, + &buffer_id_to_drop); // Remove dropped buffer if there is one. if (buffer_id_to_drop != media::VideoCaptureBufferPool::kInvalidId) { @@ -75,8 +78,8 @@ void SharedMemoryVirtualDeviceMojoAdapter::RequestFrameBuffer( } } - // No buffer available. - if (buffer_id == media::VideoCaptureBufferPool::kInvalidId) { + if (reserve_result != + media::VideoCaptureDevice::Client::ReserveResult::kSucceeded) { std::move(callback).Run(mojom::kInvalidBufferId); return; } diff --git a/chromium/services/video_capture/shared_memory_virtual_device_mojo_adapter.h b/chromium/services/video_capture/shared_memory_virtual_device_mojo_adapter.h index 78e59af9702..013a31a01cb 100644 --- a/chromium/services/video_capture/shared_memory_virtual_device_mojo_adapter.h +++ b/chromium/services/video_capture/shared_memory_virtual_device_mojo_adapter.h @@ -28,6 +28,7 @@ class SharedMemoryVirtualDeviceMojoAdapter // mojom::SharedMemoryVirtualDevice implementation. void RequestFrameBuffer(const gfx::Size& dimension, media::VideoPixelFormat pixel_format, + media::mojom::PlaneStridesPtr strides, RequestFrameBufferCallback callback) override; void OnFrameReadyInBuffer( int32_t buffer_id, diff --git a/chromium/services/video_capture/texture_virtual_device_mojo_adapter_unittest.cc b/chromium/services/video_capture/texture_virtual_device_mojo_adapter_unittest.cc index 587223cccab..24d3b378892 100644 --- a/chromium/services/video_capture/texture_virtual_device_mojo_adapter_unittest.cc +++ b/chromium/services/video_capture/texture_virtual_device_mojo_adapter_unittest.cc @@ -6,7 +6,7 @@ #include "base/run_loop.h" #include "base/test/scoped_task_environment.h" -#include "services/video_capture/test/mock_receiver.h" +#include "services/video_capture/public/cpp/mock_receiver.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" diff --git a/chromium/services/video_capture/virtual_device_enabled_device_factory.cc b/chromium/services/video_capture/virtual_device_enabled_device_factory.cc index 03cb1b71642..4675957fb0a 100644 --- a/chromium/services/video_capture/virtual_device_enabled_device_factory.cc +++ b/chromium/services/video_capture/virtual_device_enabled_device_factory.cc @@ -6,6 +6,7 @@ #include "base/logging.h" #include "media/capture/video/video_capture_device_info.h" +#include "services/video_capture/device_factory_media_to_mojo_adapter.h" #include "services/video_capture/shared_memory_virtual_device_mojo_adapter.h" #include "services/video_capture/texture_virtual_device_mojo_adapter.h" @@ -84,15 +85,21 @@ class VirtualDeviceEnabledDeviceFactory::VirtualDeviceEntry { }; VirtualDeviceEnabledDeviceFactory::VirtualDeviceEnabledDeviceFactory( - std::unique_ptr<service_manager::ServiceContextRef> service_ref, - std::unique_ptr<mojom::DeviceFactory> device_factory) - : service_ref_(std::move(service_ref)), - device_factory_(std::move(device_factory)), - weak_factory_(this) {} + std::unique_ptr<DeviceFactoryMediaToMojoAdapter> device_factory) + : device_factory_(std::move(device_factory)), weak_factory_(this) {} VirtualDeviceEnabledDeviceFactory::~VirtualDeviceEnabledDeviceFactory() = default; +void VirtualDeviceEnabledDeviceFactory::SetServiceRef( + std::unique_ptr<service_manager::ServiceContextRef> service_ref) { + if (service_ref) + device_factory_->SetServiceRef(service_ref->Clone()); + else + device_factory_->SetServiceRef(nullptr); + service_ref_ = std::move(service_ref); +} + void VirtualDeviceEnabledDeviceFactory::GetDeviceInfos( GetDeviceInfosCallback callback) { device_factory_->GetDeviceInfos( @@ -157,6 +164,7 @@ void VirtualDeviceEnabledDeviceFactory::AddSharedMemoryVirtualDevice( std::move(producer_binding)); virtual_devices_by_id_.insert( std::make_pair(device_id, std::move(device_entry))); + EmitDevicesChangedEvent(); } void VirtualDeviceEnabledDeviceFactory::AddTextureVirtualDevice( @@ -183,6 +191,15 @@ void VirtualDeviceEnabledDeviceFactory::AddTextureVirtualDevice( std::move(producer_binding)); virtual_devices_by_id_.insert( std::make_pair(device_id, std::move(device_entry))); + EmitDevicesChangedEvent(); +} + +void VirtualDeviceEnabledDeviceFactory::RegisterVirtualDevicesChangedObserver( + mojom::DevicesChangedObserverPtr observer) { + observer.set_connection_error_handler(base::BindOnce( + &VirtualDeviceEnabledDeviceFactory::OnDevicesChangedObserverDisconnected, + weak_factory_.GetWeakPtr(), &observer)); + devices_changed_observers_.push_back(std::move(observer)); } void VirtualDeviceEnabledDeviceFactory::OnGetDeviceInfos( @@ -202,6 +219,7 @@ void VirtualDeviceEnabledDeviceFactory:: const std::string& device_id) { virtual_devices_by_id_.at(device_id).StopDevice(); virtual_devices_by_id_.erase(device_id); + EmitDevicesChangedEvent(); } void VirtualDeviceEnabledDeviceFactory:: @@ -210,4 +228,23 @@ void VirtualDeviceEnabledDeviceFactory:: virtual_devices_by_id_.at(device_id).StopDevice(); } +void VirtualDeviceEnabledDeviceFactory::EmitDevicesChangedEvent() { + for (auto& observer : devices_changed_observers_) + observer->OnDevicesChanged(); +} + +void VirtualDeviceEnabledDeviceFactory::OnDevicesChangedObserverDisconnected( + mojom::DevicesChangedObserverPtr* observer) { + auto iter = std::find_if( + devices_changed_observers_.begin(), devices_changed_observers_.end(), + [observer](const mojom::DevicesChangedObserverPtr& entry) { + return &entry == observer; + }); + if (iter == devices_changed_observers_.end()) { + DCHECK(false); + return; + } + devices_changed_observers_.erase(iter); +} + } // namespace video_capture diff --git a/chromium/services/video_capture/virtual_device_enabled_device_factory.h b/chromium/services/video_capture/virtual_device_enabled_device_factory.h index 8761ec068dd..b40f3a947f1 100644 --- a/chromium/services/video_capture/virtual_device_enabled_device_factory.h +++ b/chromium/services/video_capture/virtual_device_enabled_device_factory.h @@ -15,15 +15,19 @@ namespace video_capture { +class DeviceFactoryMediaToMojoAdapter; + // Decorator that adds support for virtual devices to a given // mojom::DeviceFactory. class VirtualDeviceEnabledDeviceFactory : public mojom::DeviceFactory { public: - VirtualDeviceEnabledDeviceFactory( - std::unique_ptr<service_manager::ServiceContextRef> service_ref, - std::unique_ptr<mojom::DeviceFactory> factory); + explicit VirtualDeviceEnabledDeviceFactory( + std::unique_ptr<DeviceFactoryMediaToMojoAdapter> factory); ~VirtualDeviceEnabledDeviceFactory() override; + void SetServiceRef( + std::unique_ptr<service_manager::ServiceContextRef> service_ref); + // mojom::DeviceFactory implementation. void GetDeviceInfos(GetDeviceInfosCallback callback) override; void CreateDevice(const std::string& device_id, @@ -37,6 +41,8 @@ class VirtualDeviceEnabledDeviceFactory : public mojom::DeviceFactory { void AddTextureVirtualDevice( const media::VideoCaptureDeviceInfo& device_info, mojom::TextureVirtualDeviceRequest virtual_device) override; + void RegisterVirtualDevicesChangedObserver( + mojom::DevicesChangedObserverPtr observer) override; private: class VirtualDeviceEntry; @@ -49,10 +55,14 @@ class VirtualDeviceEnabledDeviceFactory : public mojom::DeviceFactory { const std::string& device_id); void OnVirtualDeviceConsumerConnectionErrorOrClose( const std::string& device_id); + void EmitDevicesChangedEvent(); + void OnDevicesChangedObserverDisconnected( + mojom::DevicesChangedObserverPtr* observer); std::map<std::string, VirtualDeviceEntry> virtual_devices_by_id_; - const std::unique_ptr<service_manager::ServiceContextRef> service_ref_; - const std::unique_ptr<mojom::DeviceFactory> device_factory_; + const std::unique_ptr<DeviceFactoryMediaToMojoAdapter> device_factory_; + std::unique_ptr<service_manager::ServiceContextRef> service_ref_; + std::vector<mojom::DevicesChangedObserverPtr> devices_changed_observers_; base::WeakPtrFactory<VirtualDeviceEnabledDeviceFactory> weak_factory_; DISALLOW_COPY_AND_ASSIGN(VirtualDeviceEnabledDeviceFactory); diff --git a/chromium/services/viz/BUILD.gn b/chromium/services/viz/BUILD.gn index 95a991352fc..1bbac7e6ec8 100644 --- a/chromium/services/viz/BUILD.gn +++ b/chromium/services/viz/BUILD.gn @@ -10,6 +10,7 @@ service("viz") { sources = [ "main.cc", ] + configs = [ "//build/config/compiler:wexit_time_destructors" ] deps = [ ":lib", "//services/service_manager/public/cpp", @@ -23,6 +24,8 @@ source_set("lib") { "service.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ "//base", "//components/viz/service", diff --git a/chromium/services/viz/privileged/interfaces/compositing/DEPS b/chromium/services/viz/privileged/interfaces/compositing/DEPS new file mode 100644 index 00000000000..dbd2a0e036e --- /dev/null +++ b/chromium/services/viz/privileged/interfaces/compositing/DEPS @@ -0,0 +1,3 @@ +include_rules = [ + "+ui/gfx/geometry/mojo/geometry_struct_traits.h", +] diff --git a/chromium/services/viz/privileged/interfaces/compositing/display_private.mojom b/chromium/services/viz/privileged/interfaces/compositing/display_private.mojom index 888386d11c6..b7c138978ae 100644 --- a/chromium/services/viz/privileged/interfaces/compositing/display_private.mojom +++ b/chromium/services/viz/privileged/interfaces/compositing/display_private.mojom @@ -4,6 +4,7 @@ module viz.mojom; +import "gpu/ipc/common/context_result.mojom"; import "mojo/public/mojom/base/time.mojom"; import "ui/gfx/mojo/ca_layer_params.mojom"; import "ui/gfx/mojo/color_space.mojom"; @@ -33,11 +34,6 @@ interface DisplayPrivate { gfx.mojom.ColorSpace device_color_space); SetOutputIsSecure(bool secure); - // Locks the vsync interval used to generate BeginFrames for this display to - // |interval|. Changes to vsync interval from other sources will be ignored. - // This will do nothing if the display is using an external BeginFrame source. - SetAuthoritativeVSyncInterval(mojo_base.mojom.TimeDelta interval); - // Updates vsync parameters used to generate BeginFrames for this display. // This will do nothing if the display is using an external BeginFrame source. SetDisplayVSyncParameters( @@ -74,4 +70,9 @@ interface DisplayClient { // size of the swapped frame. [EnableIf=is_android] DidCompleteSwapWithSize(gfx.mojom.Size size); + + // Notifies that context creation failed. On Android we can't fall back to + // SW in these cases, so we need to handle this specifically. + [EnableIf=is_android] + OnFatalOrSurfaceContextCreationFailure(gpu.mojom.ContextResult result); }; diff --git a/chromium/services/viz/privileged/interfaces/compositing/frame_sink_manager.mojom b/chromium/services/viz/privileged/interfaces/compositing/frame_sink_manager.mojom index fa5611797b1..5d541e6f7ae 100644 --- a/chromium/services/viz/privileged/interfaces/compositing/frame_sink_manager.mojom +++ b/chromium/services/viz/privileged/interfaces/compositing/frame_sink_manager.mojom @@ -27,6 +27,8 @@ struct RootCompositorFrameSinkParams { bool gpu_compositing = true; RendererSettings renderer_settings; bool send_swap_size_notifications = false; + // Disables begin frame rate limiting for the display compositor. + bool disable_frame_rate_limit = false; associated CompositorFrameSink& compositor_frame_sink; CompositorFrameSinkClient compositor_frame_sink_client; @@ -111,16 +113,6 @@ interface FrameSinkManager { UnregisterFrameSinkHierarchy(FrameSinkId parent_frame_sink_id, FrameSinkId child_frame_sink_id); - // Assigns the temporary reference for |surface_id| to FrameSinkId |owner|. - // If |owner| is invalidated before it converts the temporary reference to a - // surface reference then the temporary reference will be dropped. - AssignTemporaryReference(SurfaceId surface_id, - FrameSinkId owner); - - // Drops the temporary reference for |surface_id|. This will get called when - // the FrameSinkManagerClient doesn't think |surface_id| will be embedded. - DropTemporaryReference(SurfaceId surface_id); - // Requests viz to notify |observer| whenever video activity is detected in // one of the clients. See viz::VideoDetector. AddVideoDetectorObserver(VideoDetectorObserver observer); @@ -142,9 +134,6 @@ interface FrameSinkManager { // compositor. The frame sink manager host is either the browser process in // Chrome or the window server process. interface FrameSinkManagerClient { - // Called by the frame sink manager when a new Surface is created. - OnSurfaceCreated(SurfaceId surface_id); - // Called by the frame sink manager when a CompositorFrame with a new // SurfaceId activates for the first time. OnFirstSurfaceActivation(SurfaceInfo surface_info); diff --git a/chromium/services/viz/privileged/interfaces/compositing/frame_sink_video_capture.mojom b/chromium/services/viz/privileged/interfaces/compositing/frame_sink_video_capture.mojom index 240ceefcbac..f65d3023576 100644 --- a/chromium/services/viz/privileged/interfaces/compositing/frame_sink_video_capture.mojom +++ b/chromium/services/viz/privileged/interfaces/compositing/frame_sink_video_capture.mojom @@ -11,6 +11,7 @@ import "mojo/public/mojom/base/shared_memory.mojom"; import "services/viz/public/interfaces/compositing/frame_sink_id.mojom"; import "skia/public/interfaces/bitmap.mojom"; import "ui/gfx/geometry/mojo/geometry.mojom"; +import "ui/gfx/mojo/color_space.mojom"; // Provided with each call to FrameSinkVideoConsumer::OnFrameCaptured() so that // the consumer can notify the capturer the instant it is done consuming the @@ -73,7 +74,7 @@ interface FrameSinkVideoCapturer { // // Default, if never called: PIXEL_FORMAT_I420, COLOR_SPACE_HD_REC709 SetFormat(media.mojom.VideoPixelFormat format, - media.mojom.ColorSpace color_space); + gfx.mojom.ColorSpace color_space); // Specifies the maximum rate of capture in terms of a minimum time period // (min_period = 1/max_frame_rate). diff --git a/chromium/services/viz/privileged/interfaces/compositing/renderer_settings.mojom b/chromium/services/viz/privileged/interfaces/compositing/renderer_settings.mojom index d685c6295e6..d7211df5fb3 100644 --- a/chromium/services/viz/privileged/interfaces/compositing/renderer_settings.mojom +++ b/chromium/services/viz/privileged/interfaces/compositing/renderer_settings.mojom @@ -4,6 +4,8 @@ module viz.mojom; +import "ui/gfx/geometry/mojo/geometry.mojom"; + struct RendererSettings { bool allow_antialiasing; bool finish_rendering_on_resize; @@ -21,4 +23,7 @@ struct RendererSettings { bool use_skia_deferred_display_list; bool allow_overlays; bool requires_alpha_channel; + + [EnableIf=is_android] + gfx.mojom.Size initial_screen_size; }; diff --git a/chromium/services/viz/privileged/interfaces/compositing/renderer_settings.typemap b/chromium/services/viz/privileged/interfaces/compositing/renderer_settings.typemap index 3614f119fc7..a906715e3ae 100644 --- a/chromium/services/viz/privileged/interfaces/compositing/renderer_settings.typemap +++ b/chromium/services/viz/privileged/interfaces/compositing/renderer_settings.typemap @@ -9,6 +9,9 @@ traits_headers = [ "//services/viz/privileged/interfaces/compositing/renderer_se deps = [ "//cc", ] +public_deps = [ + "//ui/gfx/geometry/mojo", +] sources = [ "renderer_settings_struct_traits.cc", ] diff --git a/chromium/services/viz/privileged/interfaces/compositing/renderer_settings_struct_traits.cc b/chromium/services/viz/privileged/interfaces/compositing/renderer_settings_struct_traits.cc index 880a8369c9b..0e3bcb2c2fb 100644 --- a/chromium/services/viz/privileged/interfaces/compositing/renderer_settings_struct_traits.cc +++ b/chromium/services/viz/privileged/interfaces/compositing/renderer_settings_struct_traits.cc @@ -11,6 +11,8 @@ namespace mojo { bool StructTraits<viz::mojom::RendererSettingsDataView, viz::RendererSettings>:: Read(viz::mojom::RendererSettingsDataView data, viz::RendererSettings* out) { + bool success = true; + out->allow_antialiasing = data.allow_antialiasing(); out->force_antialiasing = data.force_antialiasing(); out->force_blending_with_shaders = data.force_blending_with_shaders(); @@ -29,7 +31,12 @@ bool StructTraits<viz::mojom::RendererSettingsDataView, viz::RendererSettings>:: out->use_skia_deferred_display_list = data.use_skia_deferred_display_list(); out->allow_overlays = data.allow_overlays(); out->requires_alpha_channel = data.requires_alpha_channel(); - return true; + +#if defined(OS_ANDROID) + success = data.ReadInitialScreenSize(&out->initial_screen_size); +#endif + + return success; } } // namespace mojo diff --git a/chromium/services/viz/privileged/interfaces/compositing/renderer_settings_struct_traits.h b/chromium/services/viz/privileged/interfaces/compositing/renderer_settings_struct_traits.h index 88c763c9ab6..d363e9bd8d7 100644 --- a/chromium/services/viz/privileged/interfaces/compositing/renderer_settings_struct_traits.h +++ b/chromium/services/viz/privileged/interfaces/compositing/renderer_settings_struct_traits.h @@ -4,9 +4,11 @@ #ifndef SERVICES_VIZ_PRIVILEGED_INTERFACES_COMPOSITING_RENDERER_SETTINGS_STRUCT_TRAITS_H_ #define SERVICES_VIZ_PRIVILEGED_INTERFACES_COMPOSITING_RENDERER_SETTINGS_STRUCT_TRAITS_H_ +#include "build/build_config.h" #include "components/viz/common/display/renderer_settings.h" #include "services/viz/privileged/interfaces/compositing/renderer_settings.mojom.h" #include "services/viz/privileged/interfaces/compositing/renderer_settings_struct_traits.h" +#include "ui/gfx/geometry/mojo/geometry_struct_traits.h" namespace mojo { template <> @@ -80,6 +82,12 @@ struct StructTraits<viz::mojom::RendererSettingsDataView, return input.requires_alpha_channel; } +#if defined(OS_ANDROID) + static gfx::Size initial_screen_size(const viz::RendererSettings& input) { + return input.initial_screen_size; + } +#endif + static bool Read(viz::mojom::RendererSettingsDataView data, viz::RendererSettings* out); }; diff --git a/chromium/services/viz/privileged/interfaces/gl/BUILD.gn b/chromium/services/viz/privileged/interfaces/gl/BUILD.gn index 11b4d416feb..2e03b5d442e 100644 --- a/chromium/services/viz/privileged/interfaces/gl/BUILD.gn +++ b/chromium/services/viz/privileged/interfaces/gl/BUILD.gn @@ -12,11 +12,13 @@ mojom("gl") { ] public_deps = [ - "//components/arc/common:media", "//gpu/ipc/common:interfaces", "//media/mojo/interfaces", "//ui/gfx/geometry/mojo", "//ui/gfx/mojo", "//url/mojom:url_mojom_gurl", ] + if (is_chromeos) { + public_deps += [ "//components/arc/common:media" ] + } } diff --git a/chromium/services/viz/privileged/interfaces/gl/gpu_host.mojom b/chromium/services/viz/privileged/interfaces/gl/gpu_host.mojom index 11849c9ad90..2e081bb5cdf 100644 --- a/chromium/services/viz/privileged/interfaces/gl/gpu_host.mojom +++ b/chromium/services/viz/privileged/interfaces/gl/gpu_host.mojom @@ -35,8 +35,10 @@ interface GpuHost { // track of this decision in case the GPU process crashes. DisableGpuCompositing(); + [EnableIf=is_win] SetChildSurface(gpu.mojom.SurfaceHandle parent, gpu.mojom.SurfaceHandle child); + StoreShaderToDisk(int32 client_id, string key, string shader); RecordLogMessage(int32 severity, string header, string message); diff --git a/chromium/services/viz/privileged/interfaces/gl/gpu_service.mojom b/chromium/services/viz/privileged/interfaces/gl/gpu_service.mojom index 2541925da94..ad15dbc27b9 100644 --- a/chromium/services/viz/privileged/interfaces/gl/gpu_service.mojom +++ b/chromium/services/viz/privileged/interfaces/gl/gpu_service.mojom @@ -4,10 +4,15 @@ module viz.mojom; +[EnableIf=is_chromeos] import "components/arc/common/protected_buffer_manager.mojom"; +[EnableIf=is_chromeos] import "components/arc/common/video_decode_accelerator.mojom"; +[EnableIf=is_chromeos] import "components/arc/common/video_encode_accelerator.mojom"; +[EnableIf=is_chromeos] import "components/arc/common/video_protected_buffer_allocator.mojom"; +import "gpu/ipc/common/dx_diag_node.mojom"; import "gpu/ipc/common/gpu_info.mojom"; import "gpu/ipc/common/memory_stats.mojom"; import "gpu/ipc/common/surface_handle.mojom"; @@ -72,8 +77,11 @@ interface GpuService { GetVideoMemoryUsageStats() => (gpu.mojom.VideoMemoryUsageStats stats); - RequestCompleteGpuInfo() => (gpu.mojom.GpuInfo gpu_info); - GetGpuSupportedRuntimeVersion() => (gpu.mojom.GpuInfo gpu_info); + [EnableIf=is_win] + RequestCompleteGpuInfo() => (gpu.mojom.DxDiagNode dx_diagnostics); + [EnableIf=is_win] + GetGpuSupportedRuntimeVersion() + => (gpu.mojom.Dx12VulkanVersionInfo dx12_vulkan_version_info); // Requests that the GPU process query system availability of HDR output and // return it. diff --git a/chromium/services/viz/public/cpp/compositing/BUILD.gn b/chromium/services/viz/public/cpp/compositing/BUILD.gn index 4f61f0ad065..4b9c524f728 100644 --- a/chromium/services/viz/public/cpp/compositing/BUILD.gn +++ b/chromium/services/viz/public/cpp/compositing/BUILD.gn @@ -47,5 +47,6 @@ source_set("perftests") { "//ui/gfx", "//ui/gfx:test_support", "//ui/gfx/geometry", + "//ui/gfx/geometry/mojo:struct_traits", ] } diff --git a/chromium/services/viz/public/cpp/compositing/compositor_frame_metadata_struct_traits.cc b/chromium/services/viz/public/cpp/compositing/compositor_frame_metadata_struct_traits.cc index 87ac642aaae..781e20c3bc1 100644 --- a/chromium/services/viz/public/cpp/compositing/compositor_frame_metadata_struct_traits.cc +++ b/chromium/services/viz/public/cpp/compositing/compositor_frame_metadata_struct_traits.cc @@ -35,12 +35,12 @@ bool StructTraits<viz::mojom::CompositorFrameMetadataDataView, out->request_presentation_feedback = data.request_presentation_feedback(); out->root_background_color = data.root_background_color(); out->min_page_scale_factor = data.min_page_scale_factor(); + out->top_controls_height = data.top_controls_height(); + out->top_controls_shown_ratio = data.top_controls_shown_ratio(); #if defined(OS_ANDROID) out->max_page_scale_factor = data.max_page_scale_factor(); out->root_overflow_y_hidden = data.root_overflow_y_hidden(); - out->top_controls_height = data.top_controls_height(); - out->top_controls_shown_ratio = data.top_controls_shown_ratio(); out->bottom_controls_height = data.bottom_controls_height(); out->bottom_controls_shown_ratio = data.bottom_controls_shown_ratio(); #endif diff --git a/chromium/services/viz/public/cpp/compositing/compositor_frame_metadata_struct_traits.h b/chromium/services/viz/public/cpp/compositing/compositor_frame_metadata_struct_traits.h index 6dc5e4ccca7..5fc6ce96a62 100644 --- a/chromium/services/viz/public/cpp/compositing/compositor_frame_metadata_struct_traits.h +++ b/chromium/services/viz/public/cpp/compositing/compositor_frame_metadata_struct_traits.h @@ -101,6 +101,16 @@ struct StructTraits<viz::mojom::CompositorFrameMetadataDataView, return metadata.min_page_scale_factor; } + static float top_controls_height( + const viz::CompositorFrameMetadata& metadata) { + return metadata.top_controls_height; + } + + static float top_controls_shown_ratio( + const viz::CompositorFrameMetadata& metadata) { + return metadata.top_controls_shown_ratio; + } + #if defined(OS_ANDROID) static float max_page_scale_factor( const viz::CompositorFrameMetadata& metadata) { @@ -117,16 +127,6 @@ struct StructTraits<viz::mojom::CompositorFrameMetadataDataView, return metadata.root_overflow_y_hidden; } - static float top_controls_height( - const viz::CompositorFrameMetadata& metadata) { - return metadata.top_controls_height; - } - - static float top_controls_shown_ratio( - const viz::CompositorFrameMetadata& metadata) { - return metadata.top_controls_shown_ratio; - } - static float bottom_controls_height( const viz::CompositorFrameMetadata& metadata) { return metadata.bottom_controls_height; diff --git a/chromium/services/viz/public/cpp/compositing/quads_struct_traits.cc b/chromium/services/viz/public/cpp/compositing/quads_struct_traits.cc index 47121a2087d..af3742e9297 100644 --- a/chromium/services/viz/public/cpp/compositing/quads_struct_traits.cc +++ b/chromium/services/viz/public/cpp/compositing/quads_struct_traits.cc @@ -79,6 +79,7 @@ bool StructTraits<viz::mojom::RenderPassQuadStateDataView, viz::DrawQuad>::Read( return false; } quad->force_anti_aliasing_off = data.force_anti_aliasing_off(); + quad->backdrop_filter_quality = data.backdrop_filter_quality(); return true; } diff --git a/chromium/services/viz/public/cpp/compositing/quads_struct_traits.h b/chromium/services/viz/public/cpp/compositing/quads_struct_traits.h index ac28fea34eb..129fcd1c514 100644 --- a/chromium/services/viz/public/cpp/compositing/quads_struct_traits.h +++ b/chromium/services/viz/public/cpp/compositing/quads_struct_traits.h @@ -189,6 +189,12 @@ struct StructTraits<viz::mojom::RenderPassQuadStateDataView, viz::DrawQuad> { return quad->force_anti_aliasing_off; } + static float backdrop_filter_quality(const viz::DrawQuad& input) { + const viz::RenderPassDrawQuad* quad = + viz::RenderPassDrawQuad::MaterialCast(&input); + return quad->backdrop_filter_quality; + } + static bool Read(viz::mojom::RenderPassQuadStateDataView data, viz::DrawQuad* out); }; diff --git a/chromium/services/viz/public/cpp/compositing/struct_traits_unittest.cc b/chromium/services/viz/public/cpp/compositing/struct_traits_unittest.cc index 0993a63c0d6..18a5ffd7d37 100644 --- a/chromium/services/viz/public/cpp/compositing/struct_traits_unittest.cc +++ b/chromium/services/viz/public/cpp/compositing/struct_traits_unittest.cc @@ -620,13 +620,13 @@ TEST_F(StructTraitsTest, CompositorFrameMetadata) { uint64_t begin_frame_ack_sequence_number = 0xdeadbeef; FrameDeadline frame_deadline(base::TimeTicks(), 4u, base::TimeDelta(), true); const float min_page_scale_factor = 3.5f; + const float top_bar_height(1234.5f); + const float top_bar_shown_ratio(1.0f); #if defined(OS_ANDROID) const float max_page_scale_factor = 4.6f; const gfx::SizeF root_layer_size(1234.5f, 5432.1f); const bool root_overflow_y_hidden = true; - const float top_bar_height(1234.5f); - const float top_bar_shown_ratio(1.0f); const float bottom_bar_height(1234.5f); const float bottom_bar_shown_ratio(1.0f); Selection<gfx::SelectionBound> selection; @@ -656,13 +656,13 @@ TEST_F(StructTraitsTest, CompositorFrameMetadata) { input.frame_token = frame_token; input.begin_frame_ack.sequence_number = begin_frame_ack_sequence_number; input.min_page_scale_factor = min_page_scale_factor; + input.top_controls_height = top_bar_height; + input.top_controls_shown_ratio = top_bar_shown_ratio; #if defined(OS_ANDROID) input.max_page_scale_factor = max_page_scale_factor; input.root_layer_size = root_layer_size; input.root_overflow_y_hidden = root_overflow_y_hidden; - input.top_controls_height = top_bar_height; - input.top_controls_shown_ratio = top_bar_shown_ratio; input.bottom_controls_height = bottom_bar_height; input.bottom_controls_shown_ratio = bottom_bar_shown_ratio; input.selection = selection; @@ -694,13 +694,13 @@ TEST_F(StructTraitsTest, CompositorFrameMetadata) { EXPECT_EQ(begin_frame_ack_sequence_number, output.begin_frame_ack.sequence_number); EXPECT_EQ(min_page_scale_factor, output.min_page_scale_factor); + EXPECT_EQ(top_bar_height, output.top_controls_height); + EXPECT_EQ(top_bar_shown_ratio, output.top_controls_shown_ratio); #if defined(OS_ANDROID) EXPECT_EQ(max_page_scale_factor, output.max_page_scale_factor); EXPECT_EQ(root_layer_size, output.root_layer_size); EXPECT_EQ(root_overflow_y_hidden, output.root_overflow_y_hidden); - EXPECT_EQ(top_bar_height, output.top_controls_height); - EXPECT_EQ(top_bar_shown_ratio, output.top_controls_shown_ratio); EXPECT_EQ(bottom_bar_height, output.bottom_controls_height); EXPECT_EQ(bottom_bar_shown_ratio, output.bottom_controls_shown_ratio); EXPECT_EQ(selection, output.selection); @@ -903,6 +903,7 @@ TEST_F(StructTraitsTest, QuadListBasic) { const gfx::Rect rect2(2468, 8642, 4321, 1234); const uint32_t color2 = 0xffffffff; const bool force_anti_aliasing_off = true; + const float backdrop_filter_quality = 1.0f; SolidColorDrawQuad* solid_quad = render_pass->CreateAndAppendDrawQuad<SolidColorDrawQuad>(); solid_quad->SetNew(sqs, rect2, rect2, color2, force_anti_aliasing_off); @@ -934,7 +935,7 @@ TEST_F(StructTraitsTest, QuadListBasic) { render_pass_quad->SetNew(sqs, rect4, rect4, render_pass_id, resource_id4, mask_uv_rect, mask_texture_size, filters_scale, filters_origin, tex_coord_rect, - force_anti_aliasing_off); + force_anti_aliasing_off, backdrop_filter_quality); const gfx::Rect rect5(123, 567, 91011, 131415); const ResourceId resource_id5(1337); diff --git a/chromium/services/viz/public/cpp/compositing/transferable_resource.typemap b/chromium/services/viz/public/cpp/compositing/transferable_resource.typemap index d286183c89b..e41b2125ce7 100644 --- a/chromium/services/viz/public/cpp/compositing/transferable_resource.typemap +++ b/chromium/services/viz/public/cpp/compositing/transferable_resource.typemap @@ -14,4 +14,5 @@ sources = [ type_mappings = [ "viz.mojom.TransferableResource=viz::TransferableResource" ] deps = [ "//gpu/ipc/common:struct_traits", + "//ui/gfx/geometry/mojo:struct_traits", ] diff --git a/chromium/services/viz/public/interfaces/compositing/compositor_frame_metadata.mojom b/chromium/services/viz/public/interfaces/compositing/compositor_frame_metadata.mojom index d514e553794..ab78037e598 100644 --- a/chromium/services/viz/public/interfaces/compositing/compositor_frame_metadata.mojom +++ b/chromium/services/viz/public/interfaces/compositing/compositor_frame_metadata.mojom @@ -42,10 +42,8 @@ struct CompositorFrameMetadata { [EnableIf=is_android] bool root_overflow_y_hidden; - [EnableIf=is_android] float top_controls_height; - [EnableIf=is_android] float top_controls_shown_ratio; [EnableIf=is_android] diff --git a/chromium/services/viz/public/interfaces/compositing/quads.mojom b/chromium/services/viz/public/interfaces/compositing/quads.mojom index b1eaad0ab05..3f12a143d90 100644 --- a/chromium/services/viz/public/interfaces/compositing/quads.mojom +++ b/chromium/services/viz/public/interfaces/compositing/quads.mojom @@ -41,6 +41,7 @@ struct RenderPassQuadState { gfx.mojom.RectF tex_coord_rect; bool force_anti_aliasing_off; + float backdrop_filter_quality; }; struct SolidColorQuadState { diff --git a/chromium/services/viz/public/interfaces/hit_test/input_target_client.mojom b/chromium/services/viz/public/interfaces/hit_test/input_target_client.mojom index 1ec376afafc..c97749be0b7 100644 --- a/chromium/services/viz/public/interfaces/hit_test/input_target_client.mojom +++ b/chromium/services/viz/public/interfaces/hit_test/input_target_client.mojom @@ -18,6 +18,7 @@ interface InputTargetClient { // out-of-process iframe). // |local_point| is the point in the coordinate space of the RenderWidget // indicated by the FrameSinkId. - FrameSinkIdAt(gfx.mojom.Point point) => (FrameSinkId id, - gfx.mojom.PointF local_point); + // |trace_id| is used for trace events and does not change the functionality. + FrameSinkIdAt(gfx.mojom.Point point, uint64 trace_id) => + (FrameSinkId id, gfx.mojom.PointF local_point); }; diff --git a/chromium/services/ws/BUILD.gn b/chromium/services/ws/BUILD.gn index e22d19b2e54..ad1d532a0cd 100644 --- a/chromium/services/ws/BUILD.gn +++ b/chromium/services/ws/BUILD.gn @@ -3,6 +3,7 @@ # found in the LICENSE file. import("//build/config/ui.gni") +import("//mojo/public/tools/bindings/mojom.gni") import("//testing/test.gni") import("//services/catalog/public/tools/catalog.gni") import("//services/service_manager/public/cpp/service.gni") @@ -18,9 +19,9 @@ component("lib") { "//ash:ash_unittests", ] public = [ - "gpu_interface_provider.h", "ids.h", "window_delegate_impl.h", + "window_manager_interface.h", "window_properties.h", "window_service.h", "window_service_delegate.h", @@ -69,6 +70,8 @@ component("lib") { "window_tree_factory.cc", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + public_deps = [ "//components/discardable_memory/public/interfaces", "//components/viz/host", @@ -83,6 +86,11 @@ component("lib") { "//ui/wm/public", ] + deps = [ + "//services/ws/public/cpp", + "//services/ws/public/cpp/host", + ] + defines = [ "IS_WINDOW_SERVICE_IMPL" ] } @@ -94,6 +102,8 @@ source_set("host") { "host_context_factory.cc", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ ":lib", "//cc/mojo_embedder", @@ -113,6 +123,8 @@ static_library("test_support") { testonly = true sources = [ + "client_root_test_helper.cc", + "client_root_test_helper.h", "event_test_utils.cc", "event_test_utils.h", "server_window_test_helper.cc", @@ -143,6 +155,7 @@ static_library("test_support") { "//services/service_manager/public/cpp:service_test_support", "//services/ws/common", "//services/ws/public/cpp", + "//services/ws/public/cpp/host", "//services/ws/public/mojom", "//testing/gtest", "//ui/aura", @@ -175,12 +188,14 @@ source_set("tests") { testonly = true sources = [ + "client_root_unittest.cc", "drag_drop_delegate_unittest.cc", "embedding_unittest.cc", "focus_handler_unittest.cc", "injected_event_handler_unittest.cc", "screen_provider_unittest.cc", "server_window_unittest.cc", + "topmost_window_observer_unittest.cc", "user_activity_monitor_unittest.cc", "window_delegate_impl_unittest.cc", "window_service_observer_unittest.cc", @@ -191,10 +206,12 @@ source_set("tests") { deps = [ ":lib", + ":test_mojom", ":test_support", "//base", "//base/test:test_support", "//components/viz/common", + "//components/viz/test:test_support", "//mojo/public/cpp/bindings", "//services/service_manager/public/cpp:service_test_support", "//services/service_manager/public/cpp/test:test_support", @@ -202,6 +219,7 @@ source_set("tests") { "//services/ws/common:task_runner_test_base", "//services/ws/gpu_host:tests", "//services/ws/public/cpp", + "//services/ws/public/cpp/host", "//services/ws/public/mojom", "//testing/gtest", "//third_party/mesa_headers", @@ -224,3 +242,11 @@ service_manifest("manifest") { name = "ui" source = "manifest.json" } + +mojom("test_mojom") { + testonly = true + + sources = [ + "test_wm.mojom", + ] +} diff --git a/chromium/services/ws/DEPS b/chromium/services/ws/DEPS index d28697247be..68e667664fa 100644 --- a/chromium/services/ws/DEPS +++ b/chromium/services/ws/DEPS @@ -4,6 +4,7 @@ include_rules = [ "+components/discardable_memory/public", "+components/viz/common", "+components/viz/host", + "+components/viz/test", "+services/viz/public/interfaces", "+third_party/skia/include", "+ui", diff --git a/chromium/services/ws/OWNERS b/chromium/services/ws/OWNERS index 94028e3c531..48ad9796eac 100644 --- a/chromium/services/ws/OWNERS +++ b/chromium/services/ws/OWNERS @@ -12,3 +12,6 @@ per-file manifest.json=file://ipc/SECURITY_OWNERS per-file test_manifest.json=set noparent per-file test_manifest.json=file://ipc/SECURITY_OWNERS + +per-file *.mojom=set noparent +per-file *.mojom=file://ipc/SECURITY_OWNERS diff --git a/chromium/services/ws/client_change.cc b/chromium/services/ws/client_change.cc index d79a4e23875..3b91cc6a610 100644 --- a/chromium/services/ws/client_change.cc +++ b/chromium/services/ws/client_change.cc @@ -13,8 +13,9 @@ namespace ws { ClientChange::ClientChange(ClientChangeTracker* tracker, aura::Window* window, - ClientChangeType type) - : tracker_(tracker), type_(type) { + ClientChangeType type, + const void* property_key) + : tracker_(tracker), type_(type), property_key_(property_key) { DCHECK(!tracker_->current_change_); tracker_->current_change_ = this; if (window) diff --git a/chromium/services/ws/client_change.h b/chromium/services/ws/client_change.h index a4e0d27ff2b..fbbb07f4a30 100644 --- a/chromium/services/ws/client_change.h +++ b/chromium/services/ws/client_change.h @@ -39,24 +39,33 @@ enum class ClientChangeType { // the window. class COMPONENT_EXPORT(WINDOW_SERVICE) ClientChange { public: + // |property_key| is only used for changes of type kProperty. ClientChange(ClientChangeTracker* tracker, aura::Window* window, - ClientChangeType type); + ClientChangeType type, + const void* property_key = nullptr); ~ClientChange(); // The window the changes associated with. Is null if the window has been // destroyed during processing. aura::Window* window() { + return const_cast<aura::Window*>( + const_cast<const ClientChange*>(this)->window()); + } + + const aura::Window* window() const { return !window_tracker_.windows().empty() ? window_tracker_.windows()[0] : nullptr; } ClientChangeType type() const { return type_; } + const void* property_key() const { return property_key_; } private: ClientChangeTracker* tracker_; aura::WindowTracker window_tracker_; const ClientChangeType type_; + const void* property_key_; DISALLOW_COPY_AND_ASSIGN(ClientChange); }; diff --git a/chromium/services/ws/client_change_tracker.cc b/chromium/services/ws/client_change_tracker.cc index f2ff40f9188..c100a1fa941 100644 --- a/chromium/services/ws/client_change_tracker.cc +++ b/chromium/services/ws/client_change_tracker.cc @@ -12,10 +12,26 @@ ClientChangeTracker::ClientChangeTracker() = default; ClientChangeTracker::~ClientChangeTracker() = default; -bool ClientChangeTracker::IsProcessingChangeForWindow(aura::Window* window, - ClientChangeType type) { +bool ClientChangeTracker::IsProcessingChangeForWindow( + aura::Window* window, + ClientChangeType type) const { + return DoesCurrentChangeEqual(window, type, nullptr); +} + +bool ClientChangeTracker::IsProcessingPropertyChangeForWindow( + aura::Window* window, + const void* property_key) const { + return DoesCurrentChangeEqual(window, ClientChangeType::kProperty, + property_key); +} + +bool ClientChangeTracker::DoesCurrentChangeEqual( + aura::Window* window, + ClientChangeType type, + const void* property_key) const { return current_change_ && current_change_->window() == window && - current_change_->type() == type; + current_change_->type() == type && + current_change_->property_key() == property_key; } } // namespace ws diff --git a/chromium/services/ws/client_change_tracker.h b/chromium/services/ws/client_change_tracker.h index 2768acb28eb..c510ee44a85 100644 --- a/chromium/services/ws/client_change_tracker.h +++ b/chromium/services/ws/client_change_tracker.h @@ -29,11 +29,18 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) ClientChangeTracker { ClientChangeTracker(); ~ClientChangeTracker(); - bool IsProcessingChangeForWindow(aura::Window* window, ClientChangeType type); + bool IsProcessingChangeForWindow(aura::Window* window, + ClientChangeType type) const; + bool IsProcessingPropertyChangeForWindow(aura::Window* window, + const void* property_key) const; private: friend class ClientChange; + bool DoesCurrentChangeEqual(aura::Window* window, + ClientChangeType type, + const void* property_key) const; + // Owned by the caller that created the ClientChange. This is set in // ClientChange's constructor and reset in the destructor. ClientChange* current_change_ = nullptr; diff --git a/chromium/services/ws/client_root.cc b/chromium/services/ws/client_root.cc index 18b8421f3a7..808f1042380 100644 --- a/chromium/services/ws/client_root.cc +++ b/chromium/services/ws/client_root.cc @@ -29,9 +29,6 @@ ClientRoot::ClientRoot(WindowTree* window_tree, window_->AddObserver(this); if (window_->GetHost()) window->GetHost()->AddObserver(this); - // TODO: wire up gfx::Insets() correctly below. See usage in - // aura::ClientSurfaceEmbedder for details. Insets here are used for - // guttering. client_surface_embedder_ = std::make_unique<aura::ClientSurfaceEmbedder>( window_, is_top_level, gfx::Insets()); // Ensure there is a valid LocalSurfaceId (if necessary). @@ -50,6 +47,13 @@ ClientRoot::~ClientRoot() { server_window->frame_sink_id()); } +void ClientRoot::SetClientAreaInsets(const gfx::Insets& client_area_insets) { + if (!is_top_level_) + return; + + client_surface_embedder_->SetClientAreaInsets(client_area_insets); +} + void ClientRoot::RegisterVizEmbeddingSupport() { // This function should only be called once. viz::HostFrameSinkManager* host_frame_sink_manager = @@ -90,6 +94,63 @@ void ClientRoot::UpdateLocalSurfaceIdIfNecessary() { } } +void ClientRoot::OnLocalSurfaceIdChanged() { + if (!ShouldAssignLocalSurfaceId()) + HandleBoundsOrScaleFactorChange(window_->bounds()); +} + +void ClientRoot::AttachChildFrameSinkId(ServerWindow* server_window) { + DCHECK(server_window->attached_frame_sink_id().is_valid()); + DCHECK(ServerWindow::GetMayBeNull(window_)->frame_sink_id().is_valid()); + viz::HostFrameSinkManager* host_frame_sink_manager = + window_->env()->context_factory_private()->GetHostFrameSinkManager(); + const viz::FrameSinkId& frame_sink_id = + server_window->attached_frame_sink_id(); + if (host_frame_sink_manager->IsFrameSinkIdRegistered(frame_sink_id)) { + host_frame_sink_manager->RegisterFrameSinkHierarchy( + ServerWindow::GetMayBeNull(window_)->frame_sink_id(), frame_sink_id); + } +} + +void ClientRoot::UnattachChildFrameSinkId(ServerWindow* server_window) { + DCHECK(server_window->attached_frame_sink_id().is_valid()); + DCHECK(ServerWindow::GetMayBeNull(window_)->frame_sink_id().is_valid()); + viz::HostFrameSinkManager* host_frame_sink_manager = + window_->env()->context_factory_private()->GetHostFrameSinkManager(); + const viz::FrameSinkId& root_frame_sink_id = + ServerWindow::GetMayBeNull(window_)->frame_sink_id(); + const viz::FrameSinkId& window_frame_sink_id = + server_window->attached_frame_sink_id(); + if (host_frame_sink_manager->IsFrameSinkHierarchyRegistered( + root_frame_sink_id, window_frame_sink_id)) { + host_frame_sink_manager->UnregisterFrameSinkHierarchy(root_frame_sink_id, + window_frame_sink_id); + } +} + +void ClientRoot::AttachChildFrameSinkIdRecursive(ServerWindow* server_window) { + if (server_window->attached_frame_sink_id().is_valid()) + AttachChildFrameSinkId(server_window); + + for (aura::Window* child : server_window->window()->children()) { + ServerWindow* child_server_window = ServerWindow::GetMayBeNull(child); + if (child_server_window->owning_window_tree() == window_tree_) + AttachChildFrameSinkIdRecursive(child_server_window); + } +} + +void ClientRoot::UnattachChildFrameSinkIdRecursive( + ServerWindow* server_window) { + if (server_window->attached_frame_sink_id().is_valid()) + UnattachChildFrameSinkId(server_window); + + for (aura::Window* child : server_window->window()->children()) { + ServerWindow* child_server_window = ServerWindow::GetMayBeNull(child); + if (child_server_window->owning_window_tree() == window_tree_) + UnattachChildFrameSinkIdRecursive(child_server_window); + } +} + void ClientRoot::UpdatePrimarySurfaceId() { UpdateLocalSurfaceIdIfNecessary(); ServerWindow* server_window = ServerWindow::GetMayBeNull(window_); @@ -127,8 +188,8 @@ void ClientRoot::HandleBoundsOrScaleFactorChange(const gfx::Rect& old_bounds) { void ClientRoot::OnWindowPropertyChanged(aura::Window* window, const void* key, intptr_t old) { - if (window_tree_->property_change_tracker_->IsProcessingChangeForWindow( - window, ClientChangeType::kProperty)) { + if (window_tree_->property_change_tracker_ + ->IsProcessingPropertyChangeForWindow(window, key)) { // Do not send notifications for changes intiated by the client. return; } @@ -159,6 +220,9 @@ void ClientRoot::OnWindowAddedToRootWindow(aura::Window* window) { DCHECK(window->GetHost()); window->GetHost()->AddObserver(this); CheckForScaleFactorChange(); + window_tree_->window_tree_client_->OnWindowDisplayChanged( + window_tree_->TransportIdForWindow(window), + window->GetHost()->GetDisplayId()); } void ClientRoot::OnWindowRemovingFromRootWindow(aura::Window* window, @@ -178,6 +242,8 @@ void ClientRoot::OnFirstSurfaceActivation( ServerWindow* server_window = ServerWindow::GetMayBeNull(window_); if (server_window->local_surface_id().has_value()) { DCHECK(!fallback_surface_info_); + if (!client_surface_embedder_->HasPrimarySurfaceId()) + UpdatePrimarySurfaceId(); client_surface_embedder_->SetFallbackSurfaceInfo(surface_info); } else { fallback_surface_info_ = std::make_unique<viz::SurfaceInfo>(surface_info); diff --git a/chromium/services/ws/client_root.h b/chromium/services/ws/client_root.h index 476a97b5eb0..44e630dbeeb 100644 --- a/chromium/services/ws/client_root.h +++ b/chromium/services/ws/client_root.h @@ -21,12 +21,17 @@ class ClientSurfaceEmbedder; class Window; } // namespace aura +namespace gfx { +class Insets; +} + namespace viz { class SurfaceInfo; } namespace ws { +class ServerWindow; class WindowTree; // WindowTree creates a ClientRoot for each window the client is embedded in. A @@ -42,6 +47,11 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) ClientRoot ClientRoot(WindowTree* window_tree, aura::Window* window, bool is_top_level); ~ClientRoot() override; + // Called when the client area of the window changes. If the window is a + // top-level window, then this propagates the insets to the + // ClientSurfaceEmbedder. + void SetClientAreaInsets(const gfx::Insets& client_area_insets); + // Registers the necessary state needed for embedding in viz. void RegisterVizEmbeddingSupport(); @@ -49,7 +59,22 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) ClientRoot bool is_top_level() const { return is_top_level_; } + // Called when the LocalSurfaceId of the embedder changes. + void OnLocalSurfaceIdChanged(); + + // Attaches/unattaches server_window->attached_frame_sink_id() to the + // HostFrameSinkManager. + void AttachChildFrameSinkId(ServerWindow* server_window); + void UnattachChildFrameSinkId(ServerWindow* server_window); + + // Recurses through all descendants with the same WindowTree calling + // AttachChildFrameSinkId()/UnattachChildFrameSinkId(). + void AttachChildFrameSinkIdRecursive(ServerWindow* server_window); + void UnattachChildFrameSinkIdRecursive(ServerWindow* server_window); + private: + friend class ClientRootTestHelper; + void UpdatePrimarySurfaceId(); // Returns true if the WindowService should assign the LocalSurfaceId. A value diff --git a/chromium/services/ws/client_root_test_helper.cc b/chromium/services/ws/client_root_test_helper.cc new file mode 100644 index 00000000000..7e660b0760a --- /dev/null +++ b/chromium/services/ws/client_root_test_helper.cc @@ -0,0 +1,20 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/ws/client_root_test_helper.h" + +#include "services/ws/client_root.h" + +namespace ws { + +ClientRootTestHelper::ClientRootTestHelper(ClientRoot* client_root) + : client_root_(client_root) {} + +ClientRootTestHelper::~ClientRootTestHelper() = default; + +aura::ClientSurfaceEmbedder* ClientRootTestHelper::GetClientSurfaceEmbedder() { + return client_root_->client_surface_embedder_.get(); +} + +} // namespace ws diff --git a/chromium/services/ws/client_root_test_helper.h b/chromium/services/ws/client_root_test_helper.h new file mode 100644 index 00000000000..7674df3b28b --- /dev/null +++ b/chromium/services/ws/client_root_test_helper.h @@ -0,0 +1,35 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_WS_CLIENT_ROOT_TEST_HELPER_H_ +#define SERVICES_WS_CLIENT_ROOT_TEST_HELPER_H_ + +#include "base/macros.h" +#include "ui/events/event.h" + +namespace aura { +class ClientSurfaceEmbedder; +} + +namespace ws { + +class ClientRoot; + +// Used for accessing private members of ServerWindow in tests. +class ClientRootTestHelper { + public: + explicit ClientRootTestHelper(ClientRoot* client_root); + ~ClientRootTestHelper(); + + aura::ClientSurfaceEmbedder* GetClientSurfaceEmbedder(); + + private: + ClientRoot* client_root_; + + DISALLOW_COPY_AND_ASSIGN(ClientRootTestHelper); +}; + +} // namespace ws + +#endif // SERVICES_WS_CLIENT_ROOT_TEST_HELPER_H_ diff --git a/chromium/services/ws/client_root_unittest.cc b/chromium/services/ws/client_root_unittest.cc new file mode 100644 index 00000000000..f1007498f73 --- /dev/null +++ b/chromium/services/ws/client_root_unittest.cc @@ -0,0 +1,89 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/ws/client_root.h" + +#include <string> + +#include "services/ws/public/cpp/property_type_converters.h" +#include "services/ws/public/mojom/window_manager.mojom.h" +#include "services/ws/window_service.h" +#include "services/ws/window_service_test_setup.h" +#include "services/ws/window_tree_test_helper.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "ui/aura/client/aura_constants.h" +#include "ui/aura/mus/property_converter.h" +#include "ui/aura/window.h" +#include "ui/aura/window_observer.h" +#include "ui/aura/window_tracker.h" + +namespace ws { +namespace { + +// WindowObserver that changes a property (|aura::client::kNameKey|) from +// OnWindowPropertyChanged(). This mirrors ash changing a property when applying +// a property change from a client. +class CascadingPropertyTestHelper : public aura::WindowObserver { + public: + explicit CascadingPropertyTestHelper(aura::Window* window) : window_(window) { + window_->AddObserver(this); + } + ~CascadingPropertyTestHelper() override { window_->RemoveObserver(this); } + + bool did_set_property() const { return did_set_property_; } + + // WindowObserver: + void OnWindowPropertyChanged(aura::Window* window, + const void* key, + intptr_t old) override { + if (!did_set_property_) { + did_set_property_ = true; + window->SetProperty(aura::client::kNameKey, new std::string("TEST")); + } + } + + private: + aura::Window* window_; + bool did_set_property_ = false; + + DISALLOW_COPY_AND_ASSIGN(CascadingPropertyTestHelper); +}; + +// Verifies a property change that occurs while servicing a property change from +// the client results in notifying the client of the new property. +TEST(ClientRoot, CascadingPropertyChange) { + WindowServiceTestSetup setup; + aura::Window* top_level = + setup.window_tree_test_helper()->NewTopLevelWindow(); + ASSERT_TRUE(top_level); + setup.changes()->clear(); + CascadingPropertyTestHelper property_helper(top_level); + + // Apply a change from a client. + aura::PropertyConverter::PrimitiveType client_value = true; + std::vector<uint8_t> client_transport_value = + mojo::ConvertTo<std::vector<uint8_t>>(client_value); + setup.window_tree_test_helper()->SetWindowProperty( + top_level, mojom::WindowManager::kAlwaysOnTop_Property, + client_transport_value, 2); + + // CascadingPropertyTestHelper should have gotten the change *and* changed + // another property. + EXPECT_TRUE(property_helper.did_set_property()); + ASSERT_FALSE(setup.changes()->empty()); + + // The client should be notified of the new value. + EXPECT_EQ(CHANGE_TYPE_PROPERTY_CHANGED, (*setup.changes())[0].type); + EXPECT_EQ(mojom::WindowManager::kName_Property, + (*setup.changes())[0].property_key); + setup.changes()->erase(setup.changes()->begin()); + + // And the initial change should be acked with completed. + EXPECT_EQ("ChangeCompleted id=2 success=true", + SingleChangeToDescription(*setup.changes())); + EXPECT_TRUE(top_level->GetProperty(aura::client::kAlwaysOnTopKey)); +} + +} // namespace +} // namespace ws diff --git a/chromium/services/ws/drag_drop_delegate_unittest.cc b/chromium/services/ws/drag_drop_delegate_unittest.cc index f8d108403b2..1c52bf2d0ca 100644 --- a/chromium/services/ws/drag_drop_delegate_unittest.cc +++ b/chromium/services/ws/drag_drop_delegate_unittest.cc @@ -26,15 +26,6 @@ #include "ui/wm/core/default_screen_position_client.h" #include "url/gurl.h" -namespace ui { - -// An equal-to operator to make EXPECT_EQ happy. -bool operator==(const FileInfo& info1, const FileInfo& info2) { - return info1.path == info2.path && info1.display_name == info2.display_name; -} - -} // namespace ui - namespace ws { class DragDropDelegateTest : public testing::Test { diff --git a/chromium/services/ws/event_injector.cc b/chromium/services/ws/event_injector.cc index fcb79d8a672..194c048d185 100644 --- a/chromium/services/ws/event_injector.cc +++ b/chromium/services/ws/event_injector.cc @@ -79,18 +79,7 @@ EventInjector::EventAndHost EventInjector::DetermineEventAndHost( } event_and_host.window_tree_host = window_tree_host; - - // Map PointerEvents to Mouse/Touch event. This should be unnecessary. - // TODO: https://crbug.com/865781 - if (event->IsMousePointerEvent()) { - event_and_host.event = - std::make_unique<ui::MouseEvent>(*event->AsPointerEvent()); - } else if (event->IsTouchPointerEvent()) { - event_and_host.event = - std::make_unique<ui::TouchEvent>(*event->AsPointerEvent()); - } else { - event_and_host.event = std::move(event); - } + event_and_host.event = std::move(event); return event_and_host; } diff --git a/chromium/services/ws/event_test_utils.cc b/chromium/services/ws/event_test_utils.cc index 7d3173569d3..05c1300e689 100644 --- a/chromium/services/ws/event_test_utils.cc +++ b/chromium/services/ws/event_test_utils.cc @@ -6,6 +6,7 @@ #include "ui/events/event.h" #include "ui/events/event_constants.h" +#include "ui/events/event_utils.h" #include "ui/gfx/geometry/point.h" namespace ws { @@ -44,7 +45,7 @@ std::string EventToEventType(const ui::Event* event) { default: break; } - return "<unexpected-type>"; + return std::string(EventTypeName(event->type())); } std::string LocatedEventToEventTypeAndLocation(const ui::Event* event) { diff --git a/chromium/services/ws/focus_handler.cc b/chromium/services/ws/focus_handler.cc index cdebcb411e6..c7a2eda6c1f 100644 --- a/chromium/services/ws/focus_handler.cc +++ b/chromium/services/ws/focus_handler.cc @@ -12,6 +12,7 @@ #include "services/ws/window_service_delegate.h" #include "services/ws/window_tree.h" #include "ui/aura/client/focus_client.h" +#include "ui/wm/public/activation_client.h" namespace ws { @@ -72,9 +73,31 @@ bool FocusHandler::SetFocus(aura::Window* window) { ClientChange change(window_tree_->property_change_tracker_.get(), window, ClientChangeType::kFocus); + + // FocusController has a special API to reset focus inside the active window, + // which happens when a view requests focus (e.g. the find bar). + // https://crbug.com/880533 + wm::ActivationClient* activation_client = + wm::GetActivationClient(window->GetRootWindow()); + if (activation_client) { + aura::Window* active_window = activation_client->GetActiveWindow(); + if (active_window && active_window->Contains(window)) { + focus_client->ResetFocusWithinActiveWindow(window); + if (focus_client->GetFocusedWindow() != window) { + DVLOG(1) << "SetFocus failed (FocusClient::ResetFocusWithinActiveWindow" + << " failed for " << window->GetName() << ")"; + return false; + } + if (server_window) + server_window->set_focus_owner(window_tree_); + return true; + } + } + focus_client->FocusWindow(window); if (focus_client->GetFocusedWindow() != window) { - DVLOG(1) << "SetFocus failed (FocusClient::FocusWindow call failed)"; + DVLOG(1) << "SetFocus failed (FocusClient::FocusWindow call failed for " + << window->GetName() << ")"; return false; } diff --git a/chromium/services/ws/focus_handler_unittest.cc b/chromium/services/ws/focus_handler_unittest.cc index c7f33260962..266873aed16 100644 --- a/chromium/services/ws/focus_handler_unittest.cc +++ b/chromium/services/ws/focus_handler_unittest.cc @@ -105,6 +105,22 @@ TEST(FocusHandlerTest, FocusChild) { EXPECT_TRUE(setup.window_tree_test_helper()->SetFocus(window)); } +// Regression test for https://crbug.com/880533 +TEST(FocusHandlerTest, FocusChildOfActiveWindow) { + WindowServiceTestSetup setup; + aura::Window* top_level = + setup.window_tree_test_helper()->NewTopLevelWindow(); + top_level->Show(); + setup.focus_controller()->ActivateWindow(top_level); + EXPECT_EQ(top_level, setup.focus_controller()->GetActiveWindow()); + + aura::Window* child = setup.window_tree_test_helper()->NewWindow(); + top_level->AddChild(child); + child->Show(); + EXPECT_TRUE(setup.window_tree_test_helper()->SetFocus(child)); + EXPECT_TRUE(child->HasFocus()); +} + TEST(FocusHandlerTest, NotifyOnFocusChange) { WindowServiceTestSetup setup; aura::Window* top_level = diff --git a/chromium/services/ws/gpu_host/BUILD.gn b/chromium/services/ws/gpu_host/BUILD.gn index e55b6a20221..a7cb7eb1543 100644 --- a/chromium/services/ws/gpu_host/BUILD.gn +++ b/chromium/services/ws/gpu_host/BUILD.gn @@ -4,8 +4,6 @@ source_set("gpu_host") { sources = [ - "gpu_client.cc", - "gpu_client.h", "gpu_host.cc", "gpu_host.h", "gpu_host_delegate.h", @@ -14,12 +12,15 @@ source_set("gpu_host") { deps = [ "//base", "//components/discardable_memory/service", + "//components/viz/common", "//components/viz/host", "//components/viz/service/main", # TODO(sad): Temporary until GPU process split. "//gpu/command_buffer/client", "//gpu/command_buffer/client:gles2_interface", + "//gpu/command_buffer/service", "//gpu/ipc/client", "//gpu/ipc/common", + "//gpu/ipc/host", "//mojo/public/cpp/bindings", "//mojo/public/cpp/system", "//services/service_manager/public/cpp", @@ -50,18 +51,20 @@ source_set("gpu_host") { ] } -static_library("test_support") { +source_set("test_support") { testonly = true sources = [ - "test_gpu_host.cc", - "test_gpu_host.h", + "gpu_host_test_api.cc", + "gpu_host_test_api.h", ] deps = [ ":gpu_host", "//base", + "//components/viz/host", "//components/viz/test:test_support", + "//services/viz/privileged/interfaces", ] } @@ -74,10 +77,12 @@ source_set("tests") { deps = [ ":gpu_host", + ":test_support", "//base", "//base/test:test_config", "//base/test:test_support", "//components/discardable_memory/service", + "//components/viz/host", "//components/viz/service", "//components/viz/service/main", "//components/viz/test:test_support", diff --git a/chromium/services/ws/gpu_host/DEPS b/chromium/services/ws/gpu_host/DEPS index 9c86b463b29..988a58334aa 100644 --- a/chromium/services/ws/gpu_host/DEPS +++ b/chromium/services/ws/gpu_host/DEPS @@ -2,10 +2,13 @@ include_rules = [ "+base", "+components/viz/common", "+components/viz/host", + "+components/viz/test", "+gpu/command_buffer/client", + "+gpu/command_buffer/service/gpu_switches.h", "+gpu/config", "+gpu/ipc/client", "+gpu/ipc/common", + "+gpu/ipc/host", "+mojo/public", "+services/viz/privileged/interfaces", "+services/viz/public/interfaces", @@ -22,9 +25,6 @@ specific_include_rules = { "gpu_host_unittest.cc": [ "+components/viz/service/gl/gpu_service_impl.h", ], - "test_gpu_host.h": [ - "+components/viz/test", - ], ".*_(unit|pixel|perf)test.*\.cc": [ "+components/viz/test", ], diff --git a/chromium/services/ws/gpu_host/gpu_client.cc b/chromium/services/ws/gpu_host/gpu_client.cc deleted file mode 100644 index 0c9d9190fea..00000000000 --- a/chromium/services/ws/gpu_host/gpu_client.cc +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2017 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "services/ws/gpu_host/gpu_client.h" - -#include "components/viz/host/host_gpu_memory_buffer_manager.h" -#include "services/viz/privileged/interfaces/gl/gpu_service.mojom.h" - -namespace ws { -namespace gpu_host { - -GpuClient::GpuClient(int client_id, - gpu::GPUInfo* gpu_info, - gpu::GpuFeatureInfo* gpu_feature_info, - viz::HostGpuMemoryBufferManager* gpu_memory_buffer_manager, - viz::mojom::GpuService* gpu_service) - : client_id_(client_id), - gpu_info_(gpu_info), - gpu_feature_info_(gpu_feature_info), - gpu_memory_buffer_manager_(gpu_memory_buffer_manager), - gpu_service_(gpu_service), - weak_factory_(this) { - DCHECK(gpu_memory_buffer_manager_); - DCHECK(gpu_service_); -} - -GpuClient::~GpuClient() { - gpu_memory_buffer_manager_->DestroyAllGpuMemoryBufferForClient(client_id_); - if (!establish_callback_.is_null()) { - std::move(establish_callback_) - .Run(client_id_, mojo::ScopedMessagePipeHandle(), gpu::GPUInfo(), - gpu::GpuFeatureInfo()); - } -} - -void GpuClient::OnGpuChannelEstablished( - mojo::ScopedMessagePipeHandle channel_handle) { - base::ResetAndReturn(&establish_callback_) - .Run(client_id_, std::move(channel_handle), *gpu_info_, - *gpu_feature_info_); -} - -// mojom::Gpu overrides: -void GpuClient::EstablishGpuChannel(EstablishGpuChannelCallback callback) { - // TODO(sad): https://crbug.com/617415 figure out how to generate a meaningful - // tracing id. - const uint64_t client_tracing_id = 0; - constexpr bool is_gpu_host = false; - if (!establish_callback_.is_null()) { - std::move(establish_callback_) - .Run(client_id_, mojo::ScopedMessagePipeHandle(), gpu::GPUInfo(), - gpu::GpuFeatureInfo()); - } - establish_callback_ = std::move(callback); - const bool cache_shaders_on_disk = true; - gpu_service_->EstablishGpuChannel( - client_id_, client_tracing_id, is_gpu_host, cache_shaders_on_disk, - base::Bind(&GpuClient::OnGpuChannelEstablished, - weak_factory_.GetWeakPtr())); -} - -void GpuClient::CreateJpegDecodeAccelerator( - media::mojom::JpegDecodeAcceleratorRequest jda_request) { - gpu_service_->CreateJpegDecodeAccelerator(std::move(jda_request)); -} - -void GpuClient::CreateVideoEncodeAcceleratorProvider( - media::mojom::VideoEncodeAcceleratorProviderRequest request) { - gpu_service_->CreateVideoEncodeAcceleratorProvider(std::move(request)); -} - -void GpuClient::CreateGpuMemoryBuffer( - gfx::GpuMemoryBufferId id, - const gfx::Size& size, - gfx::BufferFormat format, - gfx::BufferUsage usage, - mojom::GpuMemoryBufferFactory::CreateGpuMemoryBufferCallback callback) { - gpu_memory_buffer_manager_->AllocateGpuMemoryBuffer( - id, client_id_, size, format, usage, gpu::kNullSurfaceHandle, - std::move(callback)); -} - -void GpuClient::DestroyGpuMemoryBuffer(gfx::GpuMemoryBufferId id, - const gpu::SyncToken& sync_token) { - gpu_memory_buffer_manager_->DestroyGpuMemoryBuffer(id, client_id_, - sync_token); -} - -void GpuClient::CreateGpuMemoryBufferFactory( - mojom::GpuMemoryBufferFactoryRequest request) { - gpu_memory_buffer_factory_bindings_.AddBinding(this, std::move(request)); -} - -} // namespace gpu_host -} // namespace ws diff --git a/chromium/services/ws/gpu_host/gpu_client.h b/chromium/services/ws/gpu_host/gpu_client.h deleted file mode 100644 index a4507d08b87..00000000000 --- a/chromium/services/ws/gpu_host/gpu_client.h +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2017 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef SERVICES_WS_GPU_HOST_GPU_CLIENT_H_ -#define SERVICES_WS_GPU_HOST_GPU_CLIENT_H_ - -#include "base/memory/weak_ptr.h" -#include "gpu/config/gpu_feature_info.h" -#include "gpu/config/gpu_info.h" -#include "mojo/public/cpp/bindings/binding_set.h" -#include "services/ws/public/mojom/gpu.mojom.h" - -namespace viz { -namespace mojom { -class GpuService; -} // namespace mojom - -class HostGpuMemoryBufferManager; -} // namespace viz - -namespace ws { -namespace gpu_host { - -namespace test { -class GpuHostTest; -} // namespace test - -// The implementation that relays requests from clients to the real -// service implementation in the GPU process over mojom.GpuService. -class GpuClient : public mojom::GpuMemoryBufferFactory, public mojom::Gpu { - public: - GpuClient(int client_id, - gpu::GPUInfo* gpu_info, - gpu::GpuFeatureInfo* gpu_feature_info, - viz::HostGpuMemoryBufferManager* gpu_memory_buffer_manager, - viz::mojom::GpuService* gpu_service); - ~GpuClient() override; - - private: - friend class test::GpuHostTest; - - // EstablishGpuChannelCallback: - void OnGpuChannelEstablished(mojo::ScopedMessagePipeHandle channel_handle); - - // mojom::GpuMemoryBufferFactory overrides: - void CreateGpuMemoryBuffer( - gfx::GpuMemoryBufferId id, - const gfx::Size& size, - gfx::BufferFormat format, - gfx::BufferUsage usage, - mojom::GpuMemoryBufferFactory::CreateGpuMemoryBufferCallback callback) - override; - void DestroyGpuMemoryBuffer(gfx::GpuMemoryBufferId id, - const gpu::SyncToken& sync_token) override; - - // mojom::Gpu overrides: - void CreateGpuMemoryBufferFactory( - mojom::GpuMemoryBufferFactoryRequest request) override; - void EstablishGpuChannel(EstablishGpuChannelCallback callback) override; - void CreateJpegDecodeAccelerator( - media::mojom::JpegDecodeAcceleratorRequest jda_request) override; - void CreateVideoEncodeAcceleratorProvider( - media::mojom::VideoEncodeAcceleratorProviderRequest vea_provider_request) - override; - - const int client_id_; - mojo::BindingSet<mojom::GpuMemoryBufferFactory> - gpu_memory_buffer_factory_bindings_; - - // The objects these pointers refer to are owned by the GpuHost object. - const gpu::GPUInfo* gpu_info_; - const gpu::GpuFeatureInfo* gpu_feature_info_; - viz::HostGpuMemoryBufferManager* gpu_memory_buffer_manager_; - viz::mojom::GpuService* gpu_service_; - EstablishGpuChannelCallback establish_callback_; - - base::WeakPtrFactory<GpuClient> weak_factory_; - - DISALLOW_COPY_AND_ASSIGN(GpuClient); -}; - -} // namespace gpu_host -} // namespace ws - -#endif // SERVICES_WS_GPU_HOST_GPU_CLIENT_H_ diff --git a/chromium/services/ws/gpu_host/gpu_host.cc b/chromium/services/ws/gpu_host/gpu_host.cc index 3cf44130f77..a95e96720a9 100644 --- a/chromium/services/ws/gpu_host/gpu_host.cc +++ b/chromium/services/ws/gpu_host/gpu_host.cc @@ -9,16 +9,22 @@ #include "base/run_loop.h" #include "base/threading/thread_task_runner_handle.h" #include "components/discardable_memory/service/discardable_shared_memory_manager.h" +#include "components/viz/common/frame_sinks/begin_frame_source.h" +#include "components/viz/common/switches.h" +#include "components/viz/host/gpu_client.h" +#include "components/viz/host/gpu_client_delegate.h" #include "components/viz/host/host_gpu_memory_buffer_manager.h" +#include "gpu/command_buffer/service/gpu_switches.h" #include "gpu/ipc/client/gpu_channel_host.h" #include "gpu/ipc/common/gpu_memory_buffer_impl_shared_memory.h" #include "gpu/ipc/common/gpu_memory_buffer_support.h" +#include "gpu/ipc/host/shader_disk_cache.h" #include "mojo/public/cpp/bindings/strong_binding.h" #include "mojo/public/cpp/system/buffer.h" #include "mojo/public/cpp/system/platform_handle.h" #include "services/service_manager/public/cpp/connector.h" +#include "services/viz/privileged/interfaces/viz_main.mojom.h" #include "services/viz/public/interfaces/constants.mojom.h" -#include "services/ws/gpu_host/gpu_client.h" #include "services/ws/gpu_host/gpu_host_delegate.h" #include "ui/gfx/buffer_format_util.h" @@ -44,45 +50,91 @@ bool HasSplitVizProcess() { return base::CommandLine::ForCurrentProcess()->HasSwitch(kEnableViz); } +class GpuClientDelegate : public viz::GpuClientDelegate { + public: + GpuClientDelegate(viz::GpuHostImpl* gpu_host_impl, + viz::HostGpuMemoryBufferManager* gpu_memory_buffer_manager); + ~GpuClientDelegate() override; + + // viz::GpuClientDelegate: + viz::GpuHostImpl* EnsureGpuHost() override; + viz::HostGpuMemoryBufferManager* GetGpuMemoryBufferManager() override; + + private: + viz::GpuHostImpl* gpu_host_impl_; + viz::HostGpuMemoryBufferManager* gpu_memory_buffer_manager_; + + DISALLOW_COPY_AND_ASSIGN(GpuClientDelegate); +}; + +GpuClientDelegate::GpuClientDelegate( + viz::GpuHostImpl* gpu_host_impl, + viz::HostGpuMemoryBufferManager* gpu_memory_buffer_manager) + : gpu_host_impl_(gpu_host_impl), + gpu_memory_buffer_manager_(gpu_memory_buffer_manager) {} + +GpuClientDelegate::~GpuClientDelegate() = default; + +viz::GpuHostImpl* GpuClientDelegate::EnsureGpuHost() { + return gpu_host_impl_; +} + +viz::HostGpuMemoryBufferManager* +GpuClientDelegate::GetGpuMemoryBufferManager() { + return gpu_memory_buffer_manager_; +} + } // namespace -DefaultGpuHost::DefaultGpuHost( - GpuHostDelegate* delegate, - service_manager::Connector* connector, - discardable_memory::DiscardableSharedMemoryManager* - discardable_shared_memory_manager) +GpuHost::GpuHost(GpuHostDelegate* delegate, + service_manager::Connector* connector, + discardable_memory::DiscardableSharedMemoryManager* + discardable_shared_memory_manager) : delegate_(delegate), + discardable_shared_memory_manager_(discardable_shared_memory_manager), next_client_id_(kInternalGpuChannelClientId + 1), main_thread_task_runner_(base::ThreadTaskRunnerHandle::Get()), - gpu_host_binding_(this), gpu_thread_("GpuThread") { - DCHECK(discardable_shared_memory_manager); + DCHECK(discardable_shared_memory_manager_); - auto request = MakeRequest(&viz_main_); - if (connector && HasSplitVizProcess()) { - connector->BindInterface(viz::mojom::kVizServiceName, std::move(request)); - } else { + viz::GpuHostImpl::InitFontRenderParams( + gfx::GetFontRenderParams(gfx::FontRenderParamsQuery(), nullptr)); + + bool in_process = !connector || !HasSplitVizProcess(); + + viz::mojom::VizMainPtr viz_main_ptr; + if (in_process) { // TODO(crbug.com/620927): This should be removed once ozone-mojo is done. gpu_thread_.Start(); gpu_thread_.task_runner()->PostTask( - FROM_HERE, base::BindOnce(&DefaultGpuHost::InitializeVizMain, - base::Unretained(this), - base::Passed(MakeRequest(&viz_main_)))); + FROM_HERE, + base::BindOnce(&GpuHost::InitializeVizMain, base::Unretained(this), + base::Passed(MakeRequest(&viz_main_ptr)))); + } else { + // Currently, GPU is only run in process in OOP-Ash. + NOTREACHED(); } - discardable_memory::mojom::DiscardableSharedMemoryManagerPtr - discardable_manager_ptr; - service_manager::BindSourceInfo source_info; - discardable_shared_memory_manager->Bind( - mojo::MakeRequest(&discardable_manager_ptr), source_info); - - viz::mojom::GpuHostPtr gpu_host_proxy; - gpu_host_binding_.Bind(mojo::MakeRequest(&gpu_host_proxy)); - viz_main_->CreateGpuService( - MakeRequest(&gpu_service_), std::move(gpu_host_proxy), - std::move(discardable_manager_ptr), mojo::ScopedSharedBufferHandle(), - gfx::GetFontRenderParams(gfx::FontRenderParamsQuery(), nullptr) - .subpixel_rendering); + viz::GpuHostImpl::InitParams params; + params.restart_id = viz::BeginFrameSource::kNotRestartableId + 1; + params.in_process = in_process; + params.disable_gpu_shader_disk_cache = + base::CommandLine::ForCurrentProcess()->HasSwitch( + switches::kDisableGpuShaderDiskCache); + params.deadline_to_synchronize_surfaces = + switches::GetDeadlineToSynchronizeSurfaces(); + params.main_thread_task_runner = main_thread_task_runner_; + gpu_host_impl_ = std::make_unique<viz::GpuHostImpl>( + this, std::make_unique<viz::VizMainWrapper>(std::move(viz_main_ptr)), + std::move(params)); + +#if defined(OS_WIN) + // For OS_WIN the process id for GPU is needed. Using GetCurrentProcessId() + // only works with in-process GPU, which is fine because GpuHost isn't used + // outside of tests. + gpu_host_impl_->OnProcessLaunched(::GetCurrentProcessId()); +#endif + gpu_memory_buffer_manager_ = std::make_unique<viz::HostGpuMemoryBufferManager>( base::BindRepeating( @@ -90,128 +142,131 @@ DefaultGpuHost::DefaultGpuHost( base::OnceClosure connection_error_handler) { return gpu_service; }, - gpu_service_.get()), + gpu_host_impl_->gpu_service()), next_client_id_++, std::make_unique<gpu::GpuMemoryBufferSupport>(), main_thread_task_runner_); + + shader_cache_factory_ = std::make_unique<gpu::ShaderCacheFactory>(); } -DefaultGpuHost::~DefaultGpuHost() { +GpuHost::~GpuHost() { // TODO(crbug.com/620927): This should be removed once ozone-mojo is done. if (gpu_thread_.IsRunning()) { // Stop() will return after |viz_main_impl_| has been destroyed. gpu_thread_.task_runner()->PostTask( - FROM_HERE, base::BindOnce(&DefaultGpuHost::DestroyVizMain, - base::Unretained(this))); + FROM_HERE, + base::BindOnce(&GpuHost::DestroyVizMain, base::Unretained(this))); gpu_thread_.Stop(); } -} -void DefaultGpuHost::Shutdown() { - gpu_service_.reset(); - gpu_bindings_.CloseAllBindings(); + viz::GpuHostImpl::ResetFontRenderParams(); } -void DefaultGpuHost::Add(mojom::GpuRequest request) { - AddInternal(std::move(request)); +void GpuHost::CreateFrameSinkManager( + viz::mojom::FrameSinkManagerRequest request, + viz::mojom::FrameSinkManagerClientPtrInfo client) { + gpu_host_impl_->ConnectFrameSinkManager(std::move(request), + std::move(client)); } -void DefaultGpuHost::OnAcceleratedWidgetAvailable( - gfx::AcceleratedWidget widget) { -#if defined(OS_WIN) - gfx::RenderingWindowManager::GetInstance()->RegisterParent(widget); -#endif -} +void GpuHost::Shutdown() { + gpu_host_impl_.reset(); -void DefaultGpuHost::OnAcceleratedWidgetDestroyed( - gfx::AcceleratedWidget widget) { -#if defined(OS_WIN) - gfx::RenderingWindowManager::GetInstance()->UnregisterParent(widget); -#endif + gpu_clients_.clear(); } -void DefaultGpuHost::CreateFrameSinkManager( - viz::mojom::FrameSinkManagerParamsPtr params) { - viz_main_->CreateFrameSinkManager(std::move(params)); +void GpuHost::Add(mojom::GpuRequest request) { + const int client_id = next_client_id_++; + const uint64_t client_tracing_id = 0; + auto client = std::make_unique<viz::GpuClient>( + std::make_unique<GpuClientDelegate>(gpu_host_impl_.get(), + gpu_memory_buffer_manager_.get()), + client_id, client_tracing_id, main_thread_task_runner_); + client->Add(std::move(request)); + gpu_clients_.push_back(std::move(client)); } #if defined(OS_CHROMEOS) -void DefaultGpuHost::AddArc(mojom::ArcRequest request) { - arc_bindings_.AddBinding(std::make_unique<ArcClient>(gpu_service_.get()), - std::move(request)); +void GpuHost::AddArc(mojom::ArcRequest request) { + arc_bindings_.AddBinding( + std::make_unique<ArcClient>(gpu_host_impl_->gpu_service()), + std::move(request)); } #endif // defined(OS_CHROMEOS) -GpuClient* DefaultGpuHost::AddInternal(mojom::GpuRequest request) { - auto client(std::make_unique<GpuClient>( - next_client_id_++, &gpu_info_, &gpu_feature_info_, - gpu_memory_buffer_manager_.get(), gpu_service_.get())); - GpuClient* client_ref = client.get(); - gpu_bindings_.AddBinding(std::move(client), std::move(request)); - return client_ref; -} - -void DefaultGpuHost::OnBadMessageFromGpu() { +void GpuHost::OnBadMessageFromGpu() { // TODO(sad): Received some unexpected message from the gpu process. We // should kill the process and restart it. NOTIMPLEMENTED(); } -void DefaultGpuHost::InitializeVizMain(viz::mojom::VizMainRequest request) { +void GpuHost::InitializeVizMain(viz::mojom::VizMainRequest request) { viz::VizMainImpl::ExternalDependencies deps; deps.create_display_compositor = true; viz_main_impl_ = std::make_unique<viz::VizMainImpl>(nullptr, std::move(deps)); viz_main_impl_->Bind(std::move(request)); } -void DefaultGpuHost::DestroyVizMain() { +void GpuHost::DestroyVizMain() { DCHECK(viz_main_impl_); viz_main_impl_.reset(); } -void DefaultGpuHost::DidInitialize(const gpu::GPUInfo& gpu_info, - const gpu::GpuFeatureInfo& gpu_feature_info, - const base::Optional<gpu::GPUInfo>&, - const base::Optional<gpu::GpuFeatureInfo>&) { +gpu::GPUInfo GpuHost::GetGPUInfo() const { + return gpu_info_; +} + +gpu::GpuFeatureInfo GpuHost::GetGpuFeatureInfo() const { + return gpu_feature_info_; +} + +void GpuHost::DidInitialize( + const gpu::GPUInfo& gpu_info, + const gpu::GpuFeatureInfo& gpu_feature_info, + const base::Optional<gpu::GPUInfo>& gpu_info_for_hardware_gpu, + const base::Optional<gpu::GpuFeatureInfo>& + gpu_feature_info_for_hardware_gpu) { gpu_info_ = gpu_info; gpu_feature_info_ = gpu_feature_info; delegate_->OnGpuServiceInitialized(); } -void DefaultGpuHost::DidFailInitialize() {} +void GpuHost::DidFailInitialize() {} -void DefaultGpuHost::DidCreateContextSuccessfully() {} +void GpuHost::DidCreateContextSuccessfully() {} -void DefaultGpuHost::DidCreateOffscreenContext(const GURL& url) {} +void GpuHost::BlockDomainFrom3DAPIs(const GURL& url, gpu::DomainGuilt guilt) {} -void DefaultGpuHost::DidDestroyOffscreenContext(const GURL& url) {} +void GpuHost::DisableGpuCompositing() {} -void DefaultGpuHost::DidDestroyChannel(int32_t client_id) {} +bool GpuHost::GpuAccessAllowed() const { + return true; +} -void DefaultGpuHost::DidLoseContext(bool offscreen, - gpu::error::ContextLostReason reason, - const GURL& active_url) {} +gpu::ShaderCacheFactory* GpuHost::GetShaderCacheFactory() { + return shader_cache_factory_.get(); +} -void DefaultGpuHost::DisableGpuCompositing() {} +void GpuHost::RecordLogMessage(int32_t severity, + const std::string& header, + const std::string& message) {} -void DefaultGpuHost::SetChildSurface(gpu::SurfaceHandle parent, - gpu::SurfaceHandle child) { -#if defined(OS_WIN) - // Using GetCurrentProcessId() only works with in-process GPU, which is fine - // because DefaultGpuHost isn't used outside of tests. - gfx::RenderingWindowManager::GetInstance()->RegisterChild( - parent, child, /*expected_child_process_id=*/::GetCurrentProcessId()); -#else +void GpuHost::BindDiscardableMemoryRequest( + discardable_memory::mojom::DiscardableSharedMemoryManagerRequest request) { + service_manager::BindSourceInfo source_info; + discardable_shared_memory_manager_->Bind(std::move(request), source_info); +} + +void GpuHost::BindInterface(const std::string& interface_name, + mojo::ScopedMessagePipeHandle interface_pipe) { NOTREACHED(); -#endif } -void DefaultGpuHost::StoreShaderToDisk(int32_t client_id, - const std::string& key, - const std::string& shader) {} +#if defined(USE_OZONE) +void GpuHost::TerminateGpuProcess(const std::string& message) {} -void DefaultGpuHost::RecordLogMessage(int32_t severity, - const std::string& header, - const std::string& message) {} +void GpuHost::SendGpuProcessMessage(IPC::Message* message) {} +#endif } // namespace gpu_host } // namespace ws diff --git a/chromium/services/ws/gpu_host/gpu_host.h b/chromium/services/ws/gpu_host/gpu_host.h index 0ff47fe662b..d87738e8da7 100644 --- a/chromium/services/ws/gpu_host/gpu_host.h +++ b/chromium/services/ws/gpu_host/gpu_host.h @@ -5,142 +5,127 @@ #ifndef SERVICES_WS_GPU_HOST_GPU_HOST_H_ #define SERVICES_WS_GPU_HOST_GPU_HOST_H_ -#include "base/single_thread_task_runner.h" #include "base/threading/thread.h" #include "build/build_config.h" +#include "components/viz/host/gpu_host_impl.h" #include "components/viz/service/main/viz_main_impl.h" #include "gpu/config/gpu_feature_info.h" #include "gpu/config/gpu_info.h" -#include "gpu/ipc/client/gpu_channel_host.h" -#include "mojo/public/cpp/bindings/binding_set.h" -#include "mojo/public/cpp/bindings/interface_request.h" #include "mojo/public/cpp/bindings/strong_binding_set.h" -#include "services/viz/privileged/interfaces/gl/gpu_host.mojom.h" -#include "services/viz/privileged/interfaces/gl/gpu_service.mojom.h" #include "services/ws/public/mojom/gpu.mojom.h" #if defined(OS_CHROMEOS) #include "services/ws/public/mojom/arc.mojom.h" #endif // defined(OS_CHROMEOS) +namespace base { +class SingleThreadTaskRunner; +} + namespace discardable_memory { class DiscardableSharedMemoryManager; } +namespace gpu { +class ShaderCacheFactory; +} + namespace service_manager { class Connector; } namespace viz { +class GpuClient; +class GpuHostImpl; class HostGpuMemoryBufferManager; } namespace ws { namespace gpu_host { - -class GpuClient; - -namespace test { -class GpuHostTest; -} // namespace test - class GpuHostDelegate; // GpuHost sets up connection from clients to the real service implementation in // the GPU process. -class GpuHost { +class GpuHost : public viz::GpuHostImpl::Delegate { public: - GpuHost() = default; - virtual ~GpuHost() = default; + GpuHost(GpuHostDelegate* delegate, + service_manager::Connector* connector, + discardable_memory::DiscardableSharedMemoryManager* + discardable_shared_memory_manager); + ~GpuHost() override; - virtual void Add(mojom::GpuRequest request) = 0; - virtual void OnAcceleratedWidgetAvailable(gfx::AcceleratedWidget widget) = 0; - virtual void OnAcceleratedWidgetDestroyed(gfx::AcceleratedWidget widget) = 0; - - // Requests a viz::mojom::FrameSinkManager interface from viz. - virtual void CreateFrameSinkManager( - viz::mojom::FrameSinkManagerParamsPtr params) = 0; - -#if defined(OS_CHROMEOS) - virtual void AddArc(mojom::ArcRequest request) = 0; -#endif // defined(OS_CHROMEOS) -}; - -class DefaultGpuHost : public GpuHost, public viz::mojom::GpuHost { - public: - DefaultGpuHost(GpuHostDelegate* delegate, - service_manager::Connector* connector, - discardable_memory::DiscardableSharedMemoryManager* - discardable_shared_memory_manager); - ~DefaultGpuHost() override; + void CreateFrameSinkManager(viz::mojom::FrameSinkManagerRequest request, + viz::mojom::FrameSinkManagerClientPtrInfo client); void Shutdown(); - // GpuHost: - void Add(mojom::GpuRequest request) override; - void OnAcceleratedWidgetAvailable(gfx::AcceleratedWidget widget) override; - void OnAcceleratedWidgetDestroyed(gfx::AcceleratedWidget widget) override; - void CreateFrameSinkManager( - viz::mojom::FrameSinkManagerParamsPtr params) override; + void Add(mojom::GpuRequest request); + #if defined(OS_CHROMEOS) - void AddArc(mojom::ArcRequest request) override; + void AddArc(mojom::ArcRequest request); #endif // defined(OS_CHROMEOS) private: - friend class test::GpuHostTest; + friend class GpuHostTestApi; - GpuClient* AddInternal(mojom::GpuRequest request); void OnBadMessageFromGpu(); // TODO(crbug.com/611505): this goes away after the gpu process split in mus. void InitializeVizMain(viz::mojom::VizMainRequest request); void DestroyVizMain(); - // viz::mojom::GpuHost: - void DidInitialize(const gpu::GPUInfo& gpu_info, - const gpu::GpuFeatureInfo& gpu_feature_info, - const base::Optional<gpu::GPUInfo>&, - const base::Optional<gpu::GpuFeatureInfo>&) override; + // viz::GpuHostImpl::Delegate: + gpu::GPUInfo GetGPUInfo() const override; + gpu::GpuFeatureInfo GetGpuFeatureInfo() const override; + void DidInitialize( + const gpu::GPUInfo& gpu_info, + const gpu::GpuFeatureInfo& gpu_feature_info, + const base::Optional<gpu::GPUInfo>& gpu_info_for_hardware_gpu, + const base::Optional<gpu::GpuFeatureInfo>& + gpu_feature_info_for_hardware_gpu) override; void DidFailInitialize() override; void DidCreateContextSuccessfully() override; - void DidCreateOffscreenContext(const GURL& url) override; - void DidDestroyOffscreenContext(const GURL& url) override; - void DidDestroyChannel(int32_t client_id) override; - void DidLoseContext(bool offscreen, - gpu::error::ContextLostReason reason, - const GURL& active_url) override; + void BlockDomainFrom3DAPIs(const GURL& url, gpu::DomainGuilt guilt) override; void DisableGpuCompositing() override; - void SetChildSurface(gpu::SurfaceHandle parent, - gpu::SurfaceHandle child) override; - void StoreShaderToDisk(int32_t client_id, - const std::string& key, - const std::string& shader) override; + bool GpuAccessAllowed() const override; + gpu::ShaderCacheFactory* GetShaderCacheFactory() override; void RecordLogMessage(int32_t severity, const std::string& header, const std::string& message) override; + void BindDiscardableMemoryRequest( + discardable_memory::mojom::DiscardableSharedMemoryManagerRequest request) + override; + void BindInterface(const std::string& interface_name, + mojo::ScopedMessagePipeHandle interface_pipe) override; +#if defined(USE_OZONE) + void TerminateGpuProcess(const std::string& message) override; + void SendGpuProcessMessage(IPC::Message* message) override; +#endif GpuHostDelegate* const delegate_; + discardable_memory::DiscardableSharedMemoryManager* + discardable_shared_memory_manager_; int32_t next_client_id_; scoped_refptr<base::SingleThreadTaskRunner> main_thread_task_runner_; - viz::mojom::GpuServicePtr gpu_service_; - mojo::Binding<viz::mojom::GpuHost> gpu_host_binding_; + std::unique_ptr<viz::GpuHostImpl> gpu_host_impl_; gpu::GPUInfo gpu_info_; gpu::GpuFeatureInfo gpu_feature_info_; std::unique_ptr<viz::HostGpuMemoryBufferManager> gpu_memory_buffer_manager_; - viz::mojom::VizMainPtr viz_main_; + std::unique_ptr<gpu::ShaderCacheFactory> shader_cache_factory_; + + std::vector<std::unique_ptr<viz::GpuClient>> gpu_clients_; // TODO(crbug.com/620927): This should be removed once ozone-mojo is done. base::Thread gpu_thread_; std::unique_ptr<viz::VizMainImpl> viz_main_impl_; - mojo::StrongBindingSet<mojom::Gpu> gpu_bindings_; #if defined(OS_CHROMEOS) mojo::StrongBindingSet<mojom::Arc> arc_bindings_; #endif // defined(OS_CHROMEOS) - DISALLOW_COPY_AND_ASSIGN(DefaultGpuHost); + DISALLOW_COPY_AND_ASSIGN(GpuHost); }; } // namespace gpu_host diff --git a/chromium/services/ws/gpu_host/gpu_host_test_api.cc b/chromium/services/ws/gpu_host/gpu_host_test_api.cc new file mode 100644 index 00000000000..a89f9bf026d --- /dev/null +++ b/chromium/services/ws/gpu_host/gpu_host_test_api.cc @@ -0,0 +1,32 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/ws/gpu_host/gpu_host_test_api.h" + +#include <algorithm> + +#include "components/viz/host/gpu_client.h" +#include "components/viz/test/gpu_host_impl_test_api.h" +#include "services/ws/gpu_host/gpu_host.h" + +namespace ws { +namespace gpu_host { + +GpuHostTestApi::GpuHostTestApi(GpuHost* gpu_host) : gpu_host_(gpu_host) {} + +GpuHostTestApi::~GpuHostTestApi() = default; + +void GpuHostTestApi::SetGpuService(viz::mojom::GpuServicePtr gpu_service) { + return viz::GpuHostImplTestApi(gpu_host_->gpu_host_impl_.get()) + .SetGpuService(std::move(gpu_service)); +} + +base::WeakPtr<viz::GpuClient> GpuHostTestApi::GetLastGpuClient() { + if (gpu_host_->gpu_clients_.empty()) + return nullptr; + return gpu_host_->gpu_clients_.back()->GetWeakPtr(); +} + +} // namespace gpu_host +} // namespace ws diff --git a/chromium/services/ws/gpu_host/gpu_host_test_api.h b/chromium/services/ws/gpu_host/gpu_host_test_api.h new file mode 100644 index 00000000000..c4a5f609403 --- /dev/null +++ b/chromium/services/ws/gpu_host/gpu_host_test_api.h @@ -0,0 +1,36 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_WS_GPU_HOST_GPU_HOST_TEST_API_H_ +#define SERVICES_WS_GPU_HOST_GPU_HOST_TEST_API_H_ + +#include "base/memory/weak_ptr.h" +#include "services/viz/privileged/interfaces/gl/gpu_service.mojom.h" + +namespace viz { +class GpuClient; +} + +namespace ws { +namespace gpu_host { +class GpuHost; + +class GpuHostTestApi { + public: + GpuHostTestApi(GpuHost* gpu_host); + ~GpuHostTestApi(); + + void SetGpuService(viz::mojom::GpuServicePtr gpu_service); + base::WeakPtr<viz::GpuClient> GetLastGpuClient(); + + private: + GpuHost* gpu_host_; + + DISALLOW_COPY_AND_ASSIGN(GpuHostTestApi); +}; + +} // namespace gpu_host +} // namespace ws + +#endif // SERVICES_WS_GPU_HOST_GPU_HOST_TEST_API_H_ diff --git a/chromium/services/ws/gpu_host/gpu_host_unittest.cc b/chromium/services/ws/gpu_host/gpu_host_unittest.cc index c2849fa8a88..a7583996ab1 100644 --- a/chromium/services/ws/gpu_host/gpu_host_unittest.cc +++ b/chromium/services/ws/gpu_host/gpu_host_unittest.cc @@ -9,10 +9,12 @@ #include "base/message_loop/message_loop.h" #include "base/single_thread_task_runner.h" #include "components/discardable_memory/service/discardable_shared_memory_manager.h" +#include "components/viz/host/gpu_client.h" #include "components/viz/service/gl/gpu_service_impl.h" +#include "components/viz/test/gpu_host_impl_test_api.h" #include "gpu/config/gpu_info.h" -#include "services/ws/gpu_host/gpu_client.h" #include "services/ws/gpu_host/gpu_host_delegate.h" +#include "services/ws/gpu_host/gpu_host_test_api.h" #include "services/ws/public/mojom/gpu.mojom.h" #include "testing/gtest/include/gtest/gtest.h" #include "ui/gl/init/gl_factory.h" @@ -36,7 +38,7 @@ class TestGpuHostDelegate : public GpuHostDelegate { }; // Test implementation of GpuService. For testing behaviour of calls made by -// GpuClient +// viz::GpuClient. class TestGpuService : public viz::GpuServiceImpl { public: explicit TestGpuService( @@ -72,9 +74,7 @@ class GpuHostTest : public testing::Test { io_thread_.Stop(); } - GpuHost* gpu_host() { return gpu_host_.get(); } - - base::WeakPtr<GpuClient> AddGpuClient(); + base::WeakPtr<viz::GpuClient> AddGpuClient(); void DestroyHost(); // testing::Test @@ -84,22 +84,20 @@ class GpuHostTest : public testing::Test { private: base::MessageLoop message_loop_; - base::WeakPtr<GpuClient> client_ref_; - base::Thread io_thread_; TestGpuHostDelegate gpu_host_delegate_; discardable_memory::DiscardableSharedMemoryManager discardable_memory_manager_; std::unique_ptr<TestGpuService> gpu_service_; viz::mojom::GpuServicePtr gpu_service_ptr_; - std::unique_ptr<DefaultGpuHost> gpu_host_; + std::unique_ptr<GpuHost> gpu_host_; DISALLOW_COPY_AND_ASSIGN(GpuHostTest); }; -base::WeakPtr<GpuClient> GpuHostTest::AddGpuClient() { - GpuClient* client = gpu_host_->AddInternal(mojom::GpuRequest()); - return client->weak_factory_.GetWeakPtr(); +base::WeakPtr<viz::GpuClient> GpuHostTest::AddGpuClient() { + gpu_host_->Add(mojom::GpuRequest()); + return GpuHostTestApi(gpu_host_.get()).GetLastGpuClient(); } void GpuHostTest::DestroyHost() { @@ -108,10 +106,10 @@ void GpuHostTest::DestroyHost() { void GpuHostTest::SetUp() { testing::Test::SetUp(); - gpu_host_ = std::make_unique<DefaultGpuHost>(&gpu_host_delegate_, nullptr, - &discardable_memory_manager_); + gpu_host_ = std::make_unique<GpuHost>(&gpu_host_delegate_, nullptr, + &discardable_memory_manager_); gpu_service_->Bind(mojo::MakeRequest(&gpu_service_ptr_)); - gpu_host_->gpu_service_ = std::move(gpu_service_ptr_); + GpuHostTestApi(gpu_host_.get()).SetGpuService(std::move(gpu_service_ptr_)); } void GpuHostTest::TearDown() { @@ -120,18 +118,18 @@ void GpuHostTest::TearDown() { testing::Test::TearDown(); } -// Tests to verify, that if a GpuHost is deleted before GpuClient receives a -// callback, that GpuClient is torn down and does not attempt to use GpuInfo -// after deletion. This should not crash on asan-builds. +// Tests to verify, that if a GpuHost is deleted before viz::GpuClient receives +// a callback, that viz::GpuClient is torn down and does not attempt to use +// GpuInfo after deletion. This should not crash on asan-builds. TEST_F(GpuHostTest, GpuClientDestructionOrder) { - base::WeakPtr<GpuClient> client_ref = AddGpuClient(); + base::WeakPtr<viz::GpuClient> client_ref = AddGpuClient(); EXPECT_NE(nullptr, client_ref); DestroyHost(); EXPECT_EQ(nullptr, client_ref); } TEST_F(GpuHostTest, GpuClientDestroyedWhileChannelRequestInFlight) { - base::WeakPtr<GpuClient> client_ref = AddGpuClient(); + base::WeakPtr<viz::GpuClient> client_ref = AddGpuClient(); mojom::Gpu* gpu = client_ref.get(); bool callback_called = false; gpu->EstablishGpuChannel( diff --git a/chromium/services/ws/gpu_host/test_gpu_host.cc b/chromium/services/ws/gpu_host/test_gpu_host.cc deleted file mode 100644 index c825c3618be..00000000000 --- a/chromium/services/ws/gpu_host/test_gpu_host.cc +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2017 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "services/ws/gpu_host/test_gpu_host.h" - -namespace ws { -namespace gpu_host { - -TestGpuHost::TestGpuHost() = default; - -TestGpuHost::~TestGpuHost() = default; - -void TestGpuHost::CreateFrameSinkManager( - viz::mojom::FrameSinkManagerParamsPtr params) { - frame_sink_manager_ = std::make_unique<viz::TestFrameSinkManagerImpl>(); - viz::mojom::FrameSinkManagerClientPtr client( - std::move(params->frame_sink_manager_client)); - frame_sink_manager_->BindRequest(std::move(params->frame_sink_manager), - std::move(client)); -} - -} // namespace gpu_host -} // namespace ws diff --git a/chromium/services/ws/gpu_host/test_gpu_host.h b/chromium/services/ws/gpu_host/test_gpu_host.h deleted file mode 100644 index eb04e1ba186..00000000000 --- a/chromium/services/ws/gpu_host/test_gpu_host.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2017 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef SERVICES_WS_GPU_HOST_TEST_GPU_HOST_H_ -#define SERVICES_WS_GPU_HOST_TEST_GPU_HOST_H_ - -#include "base/macros.h" -#include "components/viz/test/test_frame_sink_manager.h" -#include "services/ws/gpu_host/gpu_host.h" - -namespace ws { -namespace gpu_host { - -class TestGpuHost : public gpu_host::GpuHost { - public: - TestGpuHost(); - ~TestGpuHost() override; - - private: - void Add(mojom::GpuRequest request) override {} - void OnAcceleratedWidgetAvailable(gfx::AcceleratedWidget widget) override {} - void OnAcceleratedWidgetDestroyed(gfx::AcceleratedWidget widget) override {} - void CreateFrameSinkManager( - viz::mojom::FrameSinkManagerParamsPtr params) override; -#if defined(OS_CHROMEOS) - void AddArc(mojom::ArcRequest request) override {} -#endif // defined(OS_CHROMEOS) - - std::unique_ptr<viz::TestFrameSinkManagerImpl> frame_sink_manager_; - - DISALLOW_COPY_AND_ASSIGN(TestGpuHost); -}; - -} // namespace gpu_host -} // namespace ws - -#endif // SERVICES_WS_GPU_HOST_TEST_GPU_HOST_H_ diff --git a/chromium/services/ws/host_context_factory.cc b/chromium/services/ws/host_context_factory.cc index d9035483c74..433ad7524e4 100644 --- a/chromium/services/ws/host_context_factory.cc +++ b/chromium/services/ws/host_context_factory.cc @@ -46,7 +46,7 @@ void HostContextFactory::OnEstablishedGpuChannel( return; } context_factory_private_->ConfigureCompositor( - compositor, std::move(context_provider), nullptr); + compositor.get(), std::move(context_provider), nullptr); } void HostContextFactory::CreateLayerTreeFrameSink( diff --git a/chromium/services/ws/injected_event_handler_unittest.cc b/chromium/services/ws/injected_event_handler_unittest.cc index 30b8ca78fe1..9f5f9430577 100644 --- a/chromium/services/ws/injected_event_handler_unittest.cc +++ b/chromium/services/ws/injected_event_handler_unittest.cc @@ -9,7 +9,7 @@ #include "base/bind.h" #include "services/service_manager/public/cpp/connector.h" #include "services/service_manager/public/cpp/test/test_connector_factory.h" -#include "services/ws/gpu_interface_provider.h" +#include "services/ws/public/cpp/host/gpu_interface_provider.h" #include "services/ws/public/mojom/constants.mojom.h" #include "services/ws/public/mojom/window_tree.mojom.h" #include "services/ws/window_service_test_setup.h" diff --git a/chromium/services/ws/input_devices/input_device_unittests.cc b/chromium/services/ws/input_devices/input_device_unittests.cc index ad2c3f5cee5..a45f806ac3d 100644 --- a/chromium/services/ws/input_devices/input_device_unittests.cc +++ b/chromium/services/ws/input_devices/input_device_unittests.cc @@ -128,8 +128,8 @@ TEST_F(InputDeviceTest, AddDevices) { TEST_F(InputDeviceTest, AddDeviceAfterComplete) { const ui::InputDevice keyboard1(100, ui::INPUT_DEVICE_INTERNAL, "Keyboard1"); - const ui::InputDevice keyboard2(200, ui::INPUT_DEVICE_EXTERNAL, "Keyboard2"); - const ui::InputDevice mouse(300, ui::INPUT_DEVICE_EXTERNAL, "Mouse"); + const ui::InputDevice keyboard2(200, ui::INPUT_DEVICE_USB, "Keyboard2"); + const ui::InputDevice mouse(300, ui::INPUT_DEVICE_USB, "Mouse"); TestInputDeviceClient client; AddClientAsObserver(&client); diff --git a/chromium/services/ws/public/cpp/BUILD.gn b/chromium/services/ws/public/cpp/BUILD.gn index 663fc03ff9c..8e96d80e3bb 100644 --- a/chromium/services/ws/public/cpp/BUILD.gn +++ b/chromium/services/ws/public/cpp/BUILD.gn @@ -12,6 +12,8 @@ source_set("cpp") { "raster_thread_helper.h", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + public_deps = [ "//base", "//cc", @@ -49,6 +51,8 @@ source_set("internal") { "raster_thread_helper.cc", ] + configs += [ "//build/config/compiler:wexit_time_destructors" ] + deps = [ "//base", "//cc", diff --git a/chromium/services/ws/public/cpp/gpu/OWNERS b/chromium/services/ws/public/cpp/gpu/OWNERS new file mode 100644 index 00000000000..50daca70bdc --- /dev/null +++ b/chromium/services/ws/public/cpp/gpu/OWNERS @@ -0,0 +1,3 @@ +file://gpu/OWNERS + +# COMPONENT: Internals>GPU>Internals diff --git a/chromium/services/ws/public/cpp/gpu/command_buffer_metrics.cc b/chromium/services/ws/public/cpp/gpu/command_buffer_metrics.cc index 7b8f3dde8db..2671ff8cb99 100644 --- a/chromium/services/ws/public/cpp/gpu/command_buffer_metrics.cc +++ b/chromium/services/ws/public/cpp/gpu/command_buffer_metrics.cc @@ -56,6 +56,9 @@ void RecordContextLost(ContextType type, viz::ContextLostReason reason) { case ContextType::FOR_TESTING: // Don't record UMA, this is just for tests. break; + case ContextType::XR_COMPOSITING: + UMA_HISTOGRAM_ENUMERATION("GPU.ContextLost.XRCompositing", reason); + break; } } @@ -91,6 +94,8 @@ std::string ContextTypeToString(ContextType type) { return "Unknown"; case ContextType::FOR_TESTING: return "ForTesting"; + case ContextType::XR_COMPOSITING: + return "XRCompositing"; } } diff --git a/chromium/services/ws/public/cpp/gpu/command_buffer_metrics.h b/chromium/services/ws/public/cpp/gpu/command_buffer_metrics.h index 7c0cc5b0b5b..8bb757cae2a 100644 --- a/chromium/services/ws/public/cpp/gpu/command_buffer_metrics.h +++ b/chromium/services/ws/public/cpp/gpu/command_buffer_metrics.h @@ -31,6 +31,7 @@ enum class ContextType { MUS_CLIENT, UNKNOWN, FOR_TESTING, + XR_COMPOSITING, }; std::string ContextTypeToString(ContextType type); diff --git a/chromium/services/ws/public/cpp/gpu/context_provider_command_buffer.cc b/chromium/services/ws/public/cpp/gpu/context_provider_command_buffer.cc index 353b5502f12..8a382e5d5bb 100644 --- a/chromium/services/ws/public/cpp/gpu/context_provider_command_buffer.cc +++ b/chromium/services/ws/public/cpp/gpu/context_provider_command_buffer.cc @@ -14,6 +14,7 @@ #include "base/callback_helpers.h" #include "base/command_line.h" +#include "base/no_destructor.h" #include "base/optional.h" #include "base/strings/stringprintf.h" #include "base/threading/thread_task_runner_handle.h" @@ -396,6 +397,11 @@ class GrContext* ContextProviderCommandBuffer::GrContext() { return gr_context_->get(); } +gpu::SharedImageInterface* +ContextProviderCommandBuffer::SharedImageInterface() { + return command_buffer_->channel()->shared_image_interface(); +} + viz::ContextCacheController* ContextProviderCommandBuffer::CacheController() { CheckValidThreadOrLockAcquired(); return cache_controller_.get(); @@ -428,8 +434,9 @@ const gpu::GpuFeatureInfo& ContextProviderCommandBuffer::GetGpuFeatureInfo() DCHECK_EQ(bind_result_, gpu::ContextResult::kSuccess); CheckValidThreadOrLockAcquired(); if (!command_buffer_ || !command_buffer_->channel()) { - static const gpu::GpuFeatureInfo default_gpu_feature_info; - return default_gpu_feature_info; + static const base::NoDestructor<gpu::GpuFeatureInfo> + default_gpu_feature_info; + return *default_gpu_feature_info; } return command_buffer_->channel()->gpu_feature_info(); } diff --git a/chromium/services/ws/public/cpp/gpu/context_provider_command_buffer.h b/chromium/services/ws/public/cpp/gpu/context_provider_command_buffer.h index 2f0df9e3a9a..7598338ef92 100644 --- a/chromium/services/ws/public/cpp/gpu/context_provider_command_buffer.h +++ b/chromium/services/ws/public/cpp/gpu/context_provider_command_buffer.h @@ -90,6 +90,7 @@ class ContextProviderCommandBuffer gpu::raster::RasterInterface* RasterInterface() override; gpu::ContextSupport* ContextSupport() override; class GrContext* GrContext() override; + gpu::SharedImageInterface* SharedImageInterface() override; viz::ContextCacheController* CacheController() override; base::Lock* GetLock() override; const gpu::Capabilities& ContextCapabilities() const override; diff --git a/chromium/services/ws/public/cpp/host/BUILD.gn b/chromium/services/ws/public/cpp/host/BUILD.gn new file mode 100644 index 00000000000..113afe0b3c7 --- /dev/null +++ b/chromium/services/ws/public/cpp/host/BUILD.gn @@ -0,0 +1,13 @@ +# Copyright 2018 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +source_set("host") { + sources = [ + "gpu_interface_provider.h", + ] + + deps = [ + "//services/service_manager/public/cpp", + ] +} diff --git a/chromium/services/ws/gpu_interface_provider.h b/chromium/services/ws/public/cpp/host/gpu_interface_provider.h index 51c670f8e56..1f81f1525be 100644 --- a/chromium/services/ws/gpu_interface_provider.h +++ b/chromium/services/ws/public/cpp/host/gpu_interface_provider.h @@ -2,10 +2,9 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef SERVICES_WS_GPU_INTERFACE_PROVIDER_H_ -#define SERVICES_WS_GPU_INTERFACE_PROVIDER_H_ +#ifndef SERVICES_WS_PUBLIC_CPP_HOST_GPU_INTERFACE_PROVIDER_H_ +#define SERVICES_WS_PUBLIC_CPP_HOST_GPU_INTERFACE_PROVIDER_H_ -#include "base/component_export.h" #include "services/service_manager/public/cpp/binder_registry.h" namespace ws { @@ -13,7 +12,7 @@ namespace ws { // GpuInterfaceProvider is responsible for providing the Gpu related interfaces. // The implementation of these varies depending upon where the WindowService is // hosted. -class COMPONENT_EXPORT(WINDOW_SERVICE) GpuInterfaceProvider { +class GpuInterfaceProvider { public: virtual ~GpuInterfaceProvider() {} @@ -32,4 +31,4 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) GpuInterfaceProvider { } // namespace ws -#endif // SERVICES_WS_GPU_INTERFACE_PROVIDER_H_ +#endif // SERVICES_WS_PUBLIC_CPP_HOST_GPU_INTERFACE_PROVIDER_H_ diff --git a/chromium/services/ws/public/cpp/input_devices/input_device_client.cc b/chromium/services/ws/public/cpp/input_devices/input_device_client.cc index d1fdc5a38dc..1c0b122ad32 100644 --- a/chromium/services/ws/public/cpp/input_devices/input_device_client.cc +++ b/chromium/services/ws/public/cpp/input_devices/input_device_client.cc @@ -158,4 +158,9 @@ void InputDeviceClient::NotifyObserversTouchscreenDeviceConfigurationChanged() { observer.OnTouchscreenDeviceConfigurationChanged(); } +void InputDeviceClient::NotifyObserversTouchpadDeviceConfigurationChanged() { + for (auto& observer : observers_) + observer.OnTouchpadDeviceConfigurationChanged(); +} + } // namespace ws diff --git a/chromium/services/ws/public/cpp/input_devices/input_device_client.h b/chromium/services/ws/public/cpp/input_devices/input_device_client.h index e3c2f447e6d..8b8b8c60bbd 100644 --- a/chromium/services/ws/public/cpp/input_devices/input_device_client.h +++ b/chromium/services/ws/public/cpp/input_devices/input_device_client.h @@ -75,6 +75,7 @@ class InputDeviceClient : public mojom::InputDeviceObserverMojo, void NotifyObserversDeviceListsComplete(); void NotifyObserversKeyboardDeviceConfigurationChanged(); void NotifyObserversTouchscreenDeviceConfigurationChanged(); + void NotifyObserversTouchpadDeviceConfigurationChanged(); mojo::Binding<mojom::InputDeviceObserverMojo> binding_; diff --git a/chromium/services/ws/public/cpp/input_devices/input_device_client_test_api.cc b/chromium/services/ws/public/cpp/input_devices/input_device_client_test_api.cc index 9e60a197d7a..20f630db48f 100644 --- a/chromium/services/ws/public/cpp/input_devices/input_device_client_test_api.cc +++ b/chromium/services/ws/public/cpp/input_devices/input_device_client_test_api.cc @@ -52,6 +52,16 @@ void InputDeviceClientTestApi:: } } +void InputDeviceClientTestApi:: + NotifyObserversTouchpadDeviceConfigurationChanged() { + if (ui::DeviceDataManager::instance_) { + ui::DeviceDataManager::instance_ + ->NotifyObserversTouchpadDeviceConfigurationChanged(); + } else { + GetInputDeviceClient()->NotifyObserversTouchpadDeviceConfigurationChanged(); + } +} + void InputDeviceClientTestApi::OnDeviceListsComplete() { if (ui::DeviceDataManager::instance_) ui::DeviceDataManager::instance_->OnDeviceListsComplete(); @@ -88,6 +98,15 @@ void InputDeviceClientTestApi::SetTouchscreenDevices( } } +void InputDeviceClientTestApi::SetTouchpadDevices( + const std::vector<ui::InputDevice>& devices) { + if (ui::DeviceDataManager::instance_) { + ui::DeviceDataManager::instance_->OnTouchpadDevicesUpdated(devices); + } else { + GetInputDeviceClient()->OnTouchpadDeviceConfigurationChanged(devices); + } +} + InputDeviceClient* InputDeviceClientTestApi::GetInputDeviceClient() { if (ui::DeviceDataManager::instance_ || !ui::InputDeviceManager::HasInstance()) diff --git a/chromium/services/ws/public/cpp/input_devices/input_device_client_test_api.h b/chromium/services/ws/public/cpp/input_devices/input_device_client_test_api.h index cda4f61b08d..9b24c83e582 100644 --- a/chromium/services/ws/public/cpp/input_devices/input_device_client_test_api.h +++ b/chromium/services/ws/public/cpp/input_devices/input_device_client_test_api.h @@ -36,10 +36,12 @@ class InputDeviceClientTestApi { void NotifyObserversKeyboardDeviceConfigurationChanged(); void NotifyObserversStylusStateChanged(ui::StylusState stylus_state); void NotifyObserversTouchscreenDeviceConfigurationChanged(); + void NotifyObserversTouchpadDeviceConfigurationChanged(); void OnDeviceListsComplete(); void SetKeyboardDevices(const std::vector<ui::InputDevice>& devices); void SetMouseDevices(const std::vector<ui::InputDevice>& devices); + void SetTouchpadDevices(const std::vector<ui::InputDevice>& devices); // |are_touchscreen_target_displays_valid| is only applicable to // InputDeviceClient. See diff --git a/chromium/services/ws/public/cpp/property_type_converters.cc b/chromium/services/ws/public/cpp/property_type_converters.cc index 8e0300de4af..aa61e57ce37 100644 --- a/chromium/services/ws/public/cpp/property_type_converters.cc +++ b/chromium/services/ws/public/cpp/property_type_converters.cc @@ -206,4 +206,18 @@ bool TypeConverter<bool, std::vector<uint8_t>>::Convert( return !input.empty() && (input[0] == 1); } +// static +std::vector<uint8_t> TypeConverter<std::vector<uint8_t>, uint64_t>::Convert( + uint64_t input) { + return TypeConverter<std::vector<uint8_t>, int64_t>::Convert( + static_cast<int64_t>(input)); +} + +// static +ws::Id TypeConverter<uint64_t, std::vector<uint8_t>>::Convert( + const std::vector<uint8_t>& input) { + return static_cast<uint64_t>( + TypeConverter<int64_t, std::vector<uint8_t>>::Convert(input)); +} + } // namespace mojo diff --git a/chromium/services/ws/public/cpp/property_type_converters.h b/chromium/services/ws/public/cpp/property_type_converters.h index 2ed9b542a75..ad9f92389e5 100644 --- a/chromium/services/ws/public/cpp/property_type_converters.h +++ b/chromium/services/ws/public/cpp/property_type_converters.h @@ -10,6 +10,7 @@ #include "base/strings/string16.h" #include "mojo/public/cpp/bindings/type_converter.h" +#include "services/ws/common/types.h" class SkBitmap; @@ -109,6 +110,15 @@ struct TypeConverter<base::UnguessableToken, std::vector<uint8_t>> { static base::UnguessableToken Convert(const std::vector<uint8_t>& input); }; +template <> +struct TypeConverter<std::vector<uint8_t>, uint64_t> { + static std::vector<uint8_t> Convert(uint64_t input); +}; +template <> +struct TypeConverter<uint64_t, std::vector<uint8_t>> { + static uint64_t Convert(const std::vector<uint8_t>& input); +}; + } // namespace mojo #endif // SERVICES_WS_PUBLIC_CPP_PROPERTY_TYPE_CONVERTERS_H_ diff --git a/chromium/services/ws/public/mojom/BUILD.gn b/chromium/services/ws/public/mojom/BUILD.gn index 665702abfeb..3ab531b765e 100644 --- a/chromium/services/ws/public/mojom/BUILD.gn +++ b/chromium/services/ws/public/mojom/BUILD.gn @@ -70,6 +70,7 @@ source_set("tests") { "//testing/gtest", "//ui/display/types", "//ui/gfx:test_support", + "//ui/gfx/geometry/mojo:struct_traits", "//ui/gfx/range/mojo:struct_traits", ] } diff --git a/chromium/services/ws/public/mojom/window_manager.mojom b/chromium/services/ws/public/mojom/window_manager.mojom index 91e433d912f..e0b2c8d6e26 100644 --- a/chromium/services/ws/public/mojom/window_manager.mojom +++ b/chromium/services/ws/public/mojom/window_manager.mojom @@ -34,11 +34,6 @@ interface WindowManager { // Type: int32_t. const string kContainerId_InitProperty = "init:container_id"; - // Disables the window manager from handling immersive fullscreen for the - // window. This is typically done if the client wants to handle immersive - // themselves. Type: bool. - const string kDisableImmersive_InitProperty = "init:disable_immersive"; - // The id of the display (display::Display::id()) to create the window on. // Type: int64. const string kDisplayId_InitProperty = "init:display_id"; @@ -78,6 +73,14 @@ interface WindowManager { // "com.google.Photos". Type: mojom::String. const string kArcPackageName_Property = "prop:arc-package-name"; + // The accessibility ui::AXTreeID for a views::Widget. + // TODO(dmazzoni): Convert to base::UnguessableToken. https://crbug.com/881986 + // Type: mojom::String + const string kChildAXTreeID_Property = "prop:child-ax-tree-id"; + + // The modal parent of a child modal window. Type: window Id. + const string kChildModalParent_Property = "prop:child-modal-parent"; + // Whether the window is trying to draw attention to itself (e.g. pulsing its // shelf icon). Type: bool. const string kDrawAttention_Property = "prop:draw-attention"; @@ -158,5 +161,26 @@ interface WindowManager { // A boolean determining whether to show the window's title. const string kWindowTitleShown_Property = "prop:window-title-shown"; + // Duration of an animation, as a TimeDelta. This maps to + // kWindowVisibilityAnimationDuration. + const string kWindowVisibilityAnimationDuration_Property = + "prop:window-visibility-animation-duration"; + + // When the animation should be used. This is an int, and maps to + // WindowVisibilityAnimationTransition (which is a bitmask). + const string kWindowVisibilityAnimationTransition_Property = + "prop:window-visibility-animation-transition"; + + // Type of animation for the window. This is an int, and maps to + // WindowVisibilityAnimationType (which allows for any int value as well). + const string kWindowVisibilityAnimationType_Property = + "prop:window-visibility-animation-type"; + + // Distance to translate a window during an animation of type + // WINDOW_VISIBILITY_ANIMATION_TYPE_VERTICAL. This is a float, and maps to + // kWindowVisibilityAnimationVerticalPositionKey. + const string kWindowVisibilityAnimationVerticalPosition_Property = + "prop:window-visibility-animation-vertical-position"; + // End long lived properties. ------------------------------------------------ }; diff --git a/chromium/services/ws/public/mojom/window_tree.mojom b/chromium/services/ws/public/mojom/window_tree.mojom index 6fc6e0b803a..f7713783b56 100644 --- a/chromium/services/ws/public/mojom/window_tree.mojom +++ b/chromium/services/ws/public/mojom/window_tree.mojom @@ -10,6 +10,7 @@ import "services/viz/public/interfaces/compositing/frame_sink_id.mojom"; import "services/viz/public/interfaces/compositing/local_surface_id.mojom"; import "services/viz/public/interfaces/compositing/surface_info.mojom"; import "services/ws/public/mojom/cursor/cursor.mojom"; +import "services/ws/public/mojom/window_manager.mojom"; import "services/ws/public/mojom/screen_provider_observer.mojom"; import "services/ws/public/mojom/window_tree_constants.mojom"; import "ui/base/mojo/ui_base_types.mojom"; @@ -126,10 +127,12 @@ interface WindowTree { gfx.mojom.Insets insets, array<gfx.mojom.Rect>? additional_client_areas); - // Mouse events outside a hit test mask do not hit the window. The |mask| is - // in window local coordinates. Pass null to clear the mask. - // TODO(jamescook): Convert |mask| to a path. http://crbug.com/613210 - SetHitTestMask(uint64 window_id, gfx.mojom.Rect? mask); + // Insets the hit test of a window by the specified values. The insets must be + // positive (or zero). |mouse| applies to events originating from the mouse, + // and |touch| from a non-mouse pointer device (such as tap). + SetHitTestInsets(uint64 window_id, + gfx.mojom.Insets mouse, + gfx.mojom.Insets touch); // Called by clients that want to accept drag and drops. Windows default to // this being disabled; a window must actively opt-in to receiving OnDrag*() @@ -197,12 +200,6 @@ interface WindowTree { // . Client does not have a valid user id (i.e., it is an embedded app). SetModalType(uint32 change_id, uint64 window_id, ui.mojom.ModalType type); - // Sets the modal parent of a CHILD_MODAL window. This is the modal parent of - // the window, which is not necessarily the same as the parent of the window. - SetChildModalParent(uint32 change_id, - uint64 window_id, - uint64 parent_window_id); - // Reorders a window in its parent, relative to |relative_window_id| according // to |direction|. Only the connection that created the window's parent can // reorder its children. @@ -285,6 +282,12 @@ interface WindowTree { uint32 embed_flags) => (bool success); + // Attaches/unattaches a FrameSinkId to this window. A window can only have + // a single frame-sink-id attached to it. + AttachFrameSinkId(uint64 window_id, + viz.mojom.FrameSinkId frame_sink_id); + UnattachFrameSinkId(uint64 window_id); + // Sets focus to the specified window, use 0 to clear focus. For a window to // get focus the following has to happen: the window is drawn, the window has // been marked as focusable (see SetCanFocus()) and the window is in a @@ -327,8 +330,14 @@ interface WindowTree { // Stacks the window above all sibling windows. StackAtTop(uint32 change_id, uint64 window_id); - // Tells the window manager to perform |string_action| for |window_id|. - PerformWmAction(uint64 window_id, string action); + // Requests a window manager specific interface. |name| is the name of the + // interface. This function is typed to WindowManager, but that's purely + // by necessity. This function is used to request *any* interface known to + // the environment hosting the window service. If |name| is not the name of + // an interface known to the environment hosting the window service, + // |window_manager| is closed. + BindWindowManagerInterface(string name, + associated WindowManager& window_manager); // Returns a shared memory segment that contains two 16-bit ints packed into a // single Atomic32, which represent the current location of the mouse cursor @@ -376,6 +385,22 @@ interface WindowTree { // Called by the client to request stopping the ongoing session of observing // the topmost window under the cursor. StopObservingTopmostWindow(); + + // Called by the client to cancel active touch events. not_cancelled_window_id + // is a window ID, and that window is excluded from cancelling. When + // not_cancelled_window_id is invalid, active touch events should be cancelled + // on all windows. + CancelActiveTouchesExcept(uint64 not_cancelled_window_id); + + // Called by the client to cancel active touches on |window_id|. + CancelActiveTouches(uint64 window_id); + + // Called by the client to transfer the gesture stream from the window of + // |current_id| to the window of |new_id|. If |should_cancel| is set, then + // cancel events are also dispatched to |current_id|. Both |current_id| and + // |new_id| need to be valid window ID created by the client. This operation + // is not allowed for embedded clients. + TransferGestureEventsTo(uint64 current_id, uint64 new_id, bool should_cancel); }; // Changes to windows are not sent to the connection that originated the @@ -480,6 +505,10 @@ interface WindowTreeClient { // Invoked when the opacity of the specified window has changed. OnWindowOpacityChanged(uint64 window, float old_opacity, float new_opacity); + // Invoked when the window moves to a new display. This is only called on + // a top-level window or an embedded root. + OnWindowDisplayChanged(uint64 window, int64 display_id); + // Invoked when the drawn state of |window|'s parent changes. The drawn state // is determined by the visibility of a Window and the Windows ancestors. A // Window is drawn if all ancestors are visible, not drawn if any ancestor is @@ -530,15 +559,6 @@ interface WindowTreeClient { OnWindowCursorChanged(uint64 window_id, CursorData cursor); - // Invoked when a client window submits a new surface ID. The surface ID and - // associated information is propagated to the parent connection. The parent - // compositor can take ownership of this surface ID and embed it along with - // frame_size and device_scale_factor in a layer. - // TODO(fsamuel): Surface IDs should be passed to parents directly instead of - // going through the window server. http://crbug.com/655231 - OnWindowSurfaceChanged(uint64 window_id, - viz.mojom.SurfaceInfo surface_info); - // Called when the mouse cursor enters a window on this connection for the // first time, providing a list of available mime types. We want to send this // set of data only one time, so this isn't part of OnDragEnter(), which diff --git a/chromium/services/ws/remote_view_host/server_remote_view_host.cc b/chromium/services/ws/remote_view_host/server_remote_view_host.cc index b01ab143089..1448b0951e5 100644 --- a/chromium/services/ws/remote_view_host/server_remote_view_host.cc +++ b/chromium/services/ws/remote_view_host/server_remote_view_host.cc @@ -14,7 +14,16 @@ namespace ws { ServerRemoteViewHost::ServerRemoteViewHost(WindowService* window_service) - : window_service_(window_service) {} + : window_service_(window_service), + embedding_root_( + std::make_unique<aura::Window>(nullptr, + aura::client::WINDOW_TYPE_UNKNOWN, + window_service_->env())) { + embedding_root_->set_owned_by_parent(false); + embedding_root_->SetName("ServerRemoteViewHostWindow"); + embedding_root_->SetType(aura::client::WINDOW_TYPE_CONTROL); + embedding_root_->Init(ui::LAYER_NOT_DRAWN); +} ServerRemoteViewHost::~ServerRemoteViewHost() = default; @@ -26,13 +35,6 @@ void ServerRemoteViewHost::EmbedUsingToken( embed_flags_ = embed_flags; embed_callback_ = std::move(callback); - embedding_root_ = std::make_unique<aura::Window>( - nullptr, aura::client::WINDOW_TYPE_UNKNOWN, window_service_->env()); - embedding_root_->set_owned_by_parent(false); - embedding_root_->SetName("ServerRemoteViewHostWindow"); - embedding_root_->SetType(aura::client::WINDOW_TYPE_CONTROL); - embedding_root_->Init(ui::LAYER_NOT_DRAWN); - // TODO(sky): having to wait for being parented is a bit of a hassle. Fix // this. if (GetWidget()) @@ -40,28 +42,17 @@ void ServerRemoteViewHost::EmbedUsingToken( } void ServerRemoteViewHost::EmbedImpl() { - // Should not be attached to anything. - DCHECK(!native_view()); - - // There is a pending embed request. - DCHECK(!embed_token_.is_empty()); - - // TODO(sky): fix this, only necessary because of restrictions in WindowTree. - // Must happen before EmbedUsingToken call for window server to figure out - // the relevant display. - Attach(embedding_root_.get()); - + DCHECK(IsEmbedPending()); const bool result = window_service_->CompleteScheduleEmbedForExistingClient( embedding_root_.get(), embed_token_, embed_flags_); - - if (!result) - embedding_root_.reset(); - std::move(embed_callback_).Run(result); } void ServerRemoteViewHost::AddedToWidget() { - if (!native_view()) + if (native_view()) + return; + Attach(embedding_root_.get()); + if (IsEmbedPending()) EmbedImpl(); } diff --git a/chromium/services/ws/remote_view_host/server_remote_view_host.h b/chromium/services/ws/remote_view_host/server_remote_view_host.h index 0c10e5f6443..f5515add869 100644 --- a/chromium/services/ws/remote_view_host/server_remote_view_host.h +++ b/chromium/services/ws/remote_view_host/server_remote_view_host.h @@ -52,6 +52,8 @@ class ServerRemoteViewHost : public views::NativeViewHost { aura::Window* embedding_root() { return embedding_root_.get(); } private: + bool IsEmbedPending() const { return !embed_token_.is_empty(); } + void EmbedImpl(); // views::NativeViewHost: @@ -61,7 +63,7 @@ class ServerRemoteViewHost : public views::NativeViewHost { base::UnguessableToken embed_token_; int embed_flags_ = 0; EmbedCallback embed_callback_; - std::unique_ptr<aura::Window> embedding_root_; + const std::unique_ptr<aura::Window> embedding_root_; DISALLOW_COPY_AND_ASSIGN(ServerRemoteViewHost); }; diff --git a/chromium/services/ws/screen_provider_unittest.cc b/chromium/services/ws/screen_provider_unittest.cc index 9318283c387..2b2cd2dc4f2 100644 --- a/chromium/services/ws/screen_provider_unittest.cc +++ b/chromium/services/ws/screen_provider_unittest.cc @@ -8,7 +8,7 @@ #include "base/run_loop.h" #include "services/service_manager/public/cpp/test/test_connector_factory.h" -#include "services/ws/gpu_interface_provider.h" +#include "services/ws/public/cpp/host/gpu_interface_provider.h" #include "services/ws/public/mojom/constants.mojom.h" #include "services/ws/public/mojom/screen_provider_observer.mojom.h" #include "services/ws/public/mojom/window_tree.mojom.h" diff --git a/chromium/services/ws/server_window.cc b/chromium/services/ws/server_window.cc index 72f86d977a1..865254b6279 100644 --- a/chromium/services/ws/server_window.cc +++ b/chromium/services/ws/server_window.cc @@ -8,9 +8,12 @@ #include "base/containers/flat_map.h" #include "components/viz/host/host_frame_sink_manager.h" +#include "services/ws/client_root.h" #include "services/ws/drag_drop_delegate.h" #include "services/ws/embedding.h" +#include "services/ws/public/mojom/window_tree_constants.mojom.h" #include "services/ws/window_tree.h" +#include "ui/aura/client/aura_constants.h" #include "ui/aura/client/capture_client_observer.h" #include "ui/aura/env.h" #include "ui/aura/window.h" @@ -36,10 +39,22 @@ bool IsLocationInNonClientArea(const aura::Window* window, if (!server_window || !server_window->IsTopLevel()) return false; - // Locations outside the bounds, assume it's in extended hit test area, which - // is non-client area. - if (!gfx::Rect(window->bounds().size()).Contains(location)) - return true; + // Locations inside bounds but within the resize insets count as non-client + // area. Locations outside the bounds, assume it's in extended hit test area, + // which is non-client area. + ui::WindowShowState window_state = + window->GetProperty(aura::client::kShowStateKey); + if ((window->GetProperty(aura::client::kResizeBehaviorKey) & + ws::mojom::kResizeBehaviorCanResize) && + (window_state != ui::WindowShowState::SHOW_STATE_MAXIMIZED) && + (window_state != ui::WindowShowState::SHOW_STATE_FULLSCREEN)) { + int resize_handle_size = + window->GetProperty(aura::client::kResizeHandleInset); + gfx::Rect non_handle_area(window->bounds().size()); + non_handle_area.Inset(gfx::Insets(resize_handle_size)); + if (!non_handle_area.Contains(location)) + return true; + } gfx::Rect client_area(window->bounds().size()); client_area.Inset(server_window->client_area()); @@ -434,7 +449,11 @@ void PointerPressHandler::OnWindowVisibilityChanged(aura::Window* window, } // namespace -ServerWindow::~ServerWindow() = default; +ServerWindow::~ServerWindow() { + // WindowTree/ClientRoot should have reset |attached_frame_sink_id_| before + // the Window is destroyed. + DCHECK(!attached_frame_sink_id_.is_valid()); +} // static ServerWindow* ServerWindow::Create(aura::Window* window, @@ -480,13 +499,16 @@ void ServerWindow::SetClientArea( additional_client_areas_ = additional_client_areas; client_area_ = insets; + ClientRoot* client_root = + owning_window_tree_ ? owning_window_tree_->GetClientRootForWindow(window_) + : nullptr; + if (client_root) + client_root->SetClientAreaInsets(insets); } -void ServerWindow::SetHitTestMask(const base::Optional<gfx::Rect>& mask) { - gfx::Insets insets; - if (mask) - insets = gfx::Rect(window_->bounds().size()).InsetsFrom(mask.value()); - window_targeter_->SetInsets(insets); +void ServerWindow::SetHitTestInsets(const gfx::Insets& mouse, + const gfx::Insets& touch) { + window_targeter_->SetInsets(mouse, touch); } void ServerWindow::SetCaptureOwner(WindowTree* owner) { diff --git a/chromium/services/ws/server_window.h b/chromium/services/ws/server_window.h index 3d193aef09e..2ba55629424 100644 --- a/chromium/services/ws/server_window.h +++ b/chromium/services/ws/server_window.h @@ -82,7 +82,14 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) ServerWindow { void SetClientArea(const gfx::Insets& insets, const std::vector<gfx::Rect>& additional_client_areas); - void SetHitTestMask(const base::Optional<gfx::Rect>& mask); + void SetHitTestInsets(const gfx::Insets& mouse, const gfx::Insets& touch); + + void set_attached_frame_sink_id(const viz::FrameSinkId& id) { + attached_frame_sink_id_ = id; + } + const viz::FrameSinkId& attached_frame_sink_id() const { + return attached_frame_sink_id_; + } void SetCaptureOwner(WindowTree* owner); WindowTree* capture_owner() const { return capture_owner_; } @@ -199,6 +206,9 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) ServerWindow { // Set to true once AttachCompositorFrameSink() has been called. bool attached_compositor_frame_sink_ = false; + // FrameSinkId set by way of mojom::WindowTree::AttachFrameSinkId(). + viz::FrameSinkId attached_frame_sink_id_; + DISALLOW_COPY_AND_ASSIGN(ServerWindow); }; diff --git a/chromium/services/ws/server_window_unittest.cc b/chromium/services/ws/server_window_unittest.cc index 2522f7a8c37..47bd9cdbbd2 100644 --- a/chromium/services/ws/server_window_unittest.cc +++ b/chromium/services/ws/server_window_unittest.cc @@ -7,10 +7,13 @@ #include <memory> #include "base/run_loop.h" +#include "services/ws/client_root_test_helper.h" #include "services/ws/window_service_test_setup.h" #include "services/ws/window_tree.h" #include "services/ws/window_tree_test_helper.h" #include "testing/gtest/include/gtest/gtest.h" +#include "ui/aura/client/aura_constants.h" +#include "ui/aura/mus/client_surface_embedder.h" #include "ui/events/event.h" #include "ui/events/event_constants.h" #include "ui/wm/core/easy_resize_window_targeter.h" @@ -21,8 +24,7 @@ TEST(ServerWindow, FindTargetForWindowWithEasyResizeTargeter) { WindowServiceTestSetup setup; std::unique_ptr<wm::EasyResizeWindowTargeter> easy_resize_window_targeter = std::make_unique<wm::EasyResizeWindowTargeter>( - setup.root(), gfx::Insets(-10, -10, -10, -10), - gfx::Insets(-10, -10, -10, -10)); + gfx::Insets(-10, -10, -10, -10), gfx::Insets(-10, -10, -10, -10)); setup.root()->SetEventTargeter(std::move(easy_resize_window_targeter)); aura::Window* top_level = setup.window_tree_test_helper()->NewTopLevelWindow(); @@ -49,4 +51,60 @@ TEST(ServerWindow, FindTargetForWindowWithEasyResizeTargeter) { setup.root(), &mouse_event2)); } +TEST(ServerWindow, FindTargetForWindowWithResizeInset) { + WindowServiceTestSetup setup; + + aura::Window* top_level = + setup.window_tree_test_helper()->NewTopLevelWindow(); + ASSERT_TRUE(top_level); + const gfx::Rect top_level_bounds(100, 200, 200, 200); + top_level->SetBounds(top_level_bounds); + top_level->Show(); + + aura::Window* child_window = setup.window_tree_test_helper()->NewWindow(); + ASSERT_TRUE(child_window); + top_level->AddChild(child_window); + child_window->SetBounds(gfx::Rect(top_level_bounds.size())); + child_window->Show(); + + const int kInset = 4; + // Target an event at the resize inset area. + gfx::Point click_point = + top_level_bounds.left_center() + gfx::Vector2d(kInset / 2, 0); + // With no resize inset set yet, the event should go to the child window. + ui::MouseEvent mouse_event(ui::ET_MOUSE_PRESSED, click_point, click_point, + base::TimeTicks::Now(), + /* flags */ 0, + /* changed_button_flags_ */ 0); + EXPECT_EQ(child_window, setup.root()->targeter()->FindTargetForEvent( + setup.root(), &mouse_event)); + + // With the resize inset, the event should go to the toplevel. + top_level->SetProperty(aura::client::kResizeHandleInset, kInset); + ui::MouseEvent mouse_event_2(ui::ET_MOUSE_PRESSED, click_point, click_point, + base::TimeTicks::Now(), + /* flags */ 0, + /* changed_button_flags_ */ 0); + EXPECT_EQ(top_level, setup.root()->targeter()->FindTargetForEvent( + setup.root(), &mouse_event_2)); +} + +TEST(ServerWindow, SetClientAreaPropagatesToClientSurfaceEmbedder) { + WindowServiceTestSetup setup; + + aura::Window* top_level = + setup.window_tree_test_helper()->NewTopLevelWindow(); + ASSERT_TRUE(top_level); + const gfx::Rect top_level_bounds(100, 200, 200, 200); + top_level->SetBounds(top_level_bounds); + const gfx::Insets top_level_insets(1, 2, 11, 12); + setup.window_tree_test_helper()->SetClientArea(top_level, top_level_insets); + aura::ClientSurfaceEmbedder* client_surface_embedder = + ClientRootTestHelper( + setup.window_tree()->GetClientRootForWindow(top_level)) + .GetClientSurfaceEmbedder(); + ASSERT_TRUE(client_surface_embedder); + EXPECT_EQ(top_level_insets, client_surface_embedder->client_area_insets()); +} + } // namespace ws diff --git a/chromium/services/ws/test_change_tracker.cc b/chromium/services/ws/test_change_tracker.cc index fa14b2024fa..24cc3f680b0 100644 --- a/chromium/services/ws/test_change_tracker.cc +++ b/chromium/services/ws/test_change_tracker.cc @@ -153,13 +153,14 @@ std::string ChangeToDescription(const Change& change, change.float_value); case CHANGE_TYPE_REQUEST_CLOSE: return "RequestClose"; - case CHANGE_TYPE_SURFACE_CHANGED: - return base::StringPrintf("SurfaceCreated window_id=%s surface_id=%s", - WindowIdToString(change.window_id).c_str(), - change.surface_id.ToString().c_str()); case CHANGE_TYPE_TRANSFORM_CHANGED: return base::StringPrintf("TransformChanged window_id=%s", WindowIdToString(change.window_id).c_str()); + case CHANGE_TYPE_DISPLAY_CHANGED: + return base::StringPrintf( + "DisplayChanged window_id=%s display_id=%s", + WindowIdToString(change.window_id).c_str(), + base::NumberToString(change.display_id).c_str()); case CHANGE_TYPE_DRAG_DROP_START: return "DragDropStart"; case CHANGE_TYPE_DRAG_ENTER: @@ -427,6 +428,15 @@ void TestChangeTracker::OnWindowOpacityChanged(Id window_id, float opacity) { AddChange(change); } +void TestChangeTracker::OnWindowDisplayChanged(Id window_id, + int64_t display_id) { + Change change; + change.type = CHANGE_TYPE_DISPLAY_CHANGED; + change.window_id = window_id; + change.display_id = display_id; + AddChange(change); +} + void TestChangeTracker::OnWindowParentDrawnStateChanged(Id window_id, bool drawn) { Change change; @@ -512,18 +522,6 @@ void TestChangeTracker::OnTopLevelCreated(uint32_t change_id, AddChange(change); } -void TestChangeTracker::OnWindowSurfaceChanged( - Id window_id, - const viz::SurfaceInfo& surface_info) { - Change change; - change.type = CHANGE_TYPE_SURFACE_CHANGED; - change.window_id = window_id; - change.surface_id = surface_info.id(); - change.frame_size = surface_info.size_in_pixels(); - change.device_scale_factor = surface_info.device_scale_factor(); - AddChange(change); -} - void TestChangeTracker::OnDragDropStart( const base::flat_map<std::string, std::vector<uint8_t>>& drag_data) { Change change; diff --git a/chromium/services/ws/test_change_tracker.h b/chromium/services/ws/test_change_tracker.h index fc17883473b..a8da456efc5 100644 --- a/chromium/services/ws/test_change_tracker.h +++ b/chromium/services/ws/test_change_tracker.h @@ -45,8 +45,8 @@ enum ChangeType { CHANGE_TYPE_ON_TOP_LEVEL_CREATED, CHANGE_TYPE_OPACITY, CHANGE_TYPE_REQUEST_CLOSE, - CHANGE_TYPE_SURFACE_CHANGED, CHANGE_TYPE_TRANSFORM_CHANGED, + CHANGE_TYPE_DISPLAY_CHANGED, CHANGE_TYPE_DRAG_DROP_START, CHANGE_TYPE_DRAG_ENTER, CHANGE_TYPE_DRAG_OVER, @@ -102,9 +102,6 @@ struct Change { std::string property_value; ui::CursorType cursor_type; uint32_t change_id; - viz::SurfaceId surface_id; - gfx::Size frame_size; - float device_scale_factor; gfx::Transform transform; // Set in OnWindowInputEvent() if the event is a KeyEvent. base::flat_map<std::string, std::vector<uint8_t>> key_event_properties; @@ -200,6 +197,7 @@ class TestChangeTracker { void OnWindowDeleted(Id window_id); void OnWindowVisibilityChanged(Id window_id, bool visible); void OnWindowOpacityChanged(Id window_id, float opacity); + void OnWindowDisplayChanged(Id window_id, int64_t display_id); void OnWindowParentDrawnStateChanged(Id window_id, bool drawn); void OnWindowInputEvent(Id window_id, const ui::Event& event, @@ -216,8 +214,6 @@ class TestChangeTracker { void OnTopLevelCreated(uint32_t change_id, mojom::WindowDataPtr window_data, bool drawn); - void OnWindowSurfaceChanged(Id window_id, - const viz::SurfaceInfo& surface_info); void OnDragDropStart( const base::flat_map<std::string, std::vector<uint8_t>>& drag_data); void OnDragEnter(Id window_id); diff --git a/chromium/services/ws/test_window_service_delegate.cc b/chromium/services/ws/test_window_service_delegate.cc index 19601f29f93..699ab5b4e59 100644 --- a/chromium/services/ws/test_window_service_delegate.cc +++ b/chromium/services/ws/test_window_service_delegate.cc @@ -71,4 +71,13 @@ void TestWindowServiceDelegate::CancelDragLoop(aura::Window* window) { cancel_drag_loop_called_ = true; } +aura::Window* TestWindowServiceDelegate::GetTopmostWindowAtPoint( + const gfx::Point& location_in_screen, + const std::set<aura::Window*>& ignore, + aura::Window** real_topmost) { + if (real_topmost) + *real_topmost = real_topmost_; + return topmost_; +} + } // namespace ws diff --git a/chromium/services/ws/test_window_service_delegate.h b/chromium/services/ws/test_window_service_delegate.h index 3a7e41b89fd..4721b59efdf 100644 --- a/chromium/services/ws/test_window_service_delegate.h +++ b/chromium/services/ws/test_window_service_delegate.h @@ -44,6 +44,9 @@ class TestWindowServiceDelegate : public WindowServiceDelegate { bool cancel_drag_loop_called() const { return cancel_drag_loop_called_; } + void set_topmost(aura::Window* window) { topmost_ = window; } + void set_real_topmost(aura::Window* window) { real_topmost_ = window; } + DragDropCompletedCallback TakeDragLoopCallback(); // WindowServiceDelegate: @@ -64,6 +67,9 @@ class TestWindowServiceDelegate : public WindowServiceDelegate { ui::DragDropTypes::DragEventSource source, DragDropCompletedCallback callback) override; void CancelDragLoop(aura::Window* window) override; + aura::Window* GetTopmostWindowAtPoint(const gfx::Point& location_in_screen, + const std::set<aura::Window*>& ignore, + aura::Window** real_topmost) override; private: aura::Window* top_level_parent_; @@ -81,6 +87,9 @@ class TestWindowServiceDelegate : public WindowServiceDelegate { bool cancel_window_move_loop_called_ = false; bool cancel_drag_loop_called_ = false; + aura::Window* topmost_ = nullptr; + aura::Window* real_topmost_ = nullptr; + DISALLOW_COPY_AND_ASSIGN(TestWindowServiceDelegate); }; diff --git a/chromium/services/ws/test_window_tree_client.cc b/chromium/services/ws/test_window_tree_client.cc index 6d697930aa9..2b917e7c9be 100644 --- a/chromium/services/ws/test_window_tree_client.cc +++ b/chromium/services/ws/test_window_tree_client.cc @@ -180,6 +180,11 @@ void TestWindowTreeClient::OnWindowOpacityChanged(Id window, tracker_.OnWindowOpacityChanged(window, new_opacity); } +void TestWindowTreeClient::OnWindowDisplayChanged(Id window_id, + int64_t display_id) { + tracker_.OnWindowDisplayChanged(window_id, display_id); +} + void TestWindowTreeClient::OnWindowParentDrawnStateChanged(Id window, bool drawn) { tracker_.OnWindowParentDrawnStateChanged(window, drawn); @@ -232,12 +237,6 @@ void TestWindowTreeClient::OnWindowCursorChanged(Id window_id, tracker_.OnWindowCursorChanged(window_id, cursor); } -void TestWindowTreeClient::OnWindowSurfaceChanged( - Id window_id, - const viz::SurfaceInfo& surface_info) { - tracker_.OnWindowSurfaceChanged(window_id, surface_info); -} - void TestWindowTreeClient::OnDragDropStart( const base::flat_map<std::string, std::vector<uint8_t>>& drag_data) { tracker_.OnDragDropStart(drag_data); diff --git a/chromium/services/ws/test_window_tree_client.h b/chromium/services/ws/test_window_tree_client.h index a3e2d72dd33..ed67b4ec281 100644 --- a/chromium/services/ws/test_window_tree_client.h +++ b/chromium/services/ws/test_window_tree_client.h @@ -145,6 +145,7 @@ class TestWindowTreeClient : public mojom::WindowTreeClient, void OnWindowOpacityChanged(Id window, float old_opacity, float new_opacity) override; + void OnWindowDisplayChanged(Id window_id, int64_t display_id) override; void OnWindowParentDrawnStateChanged(Id window, bool drawn) override; void OnWindowInputEvent(uint32_t event_id, Id window_id, @@ -160,8 +161,6 @@ class TestWindowTreeClient : public mojom::WindowTreeClient, const base::Optional<std::vector<uint8_t>>& new_data) override; void OnWindowFocused(Id focused_window_id) override; void OnWindowCursorChanged(Id window_id, ui::CursorData cursor) override; - void OnWindowSurfaceChanged(Id window_id, - const viz::SurfaceInfo& surface_info) override; void OnDragDropStart(const base::flat_map<std::string, std::vector<uint8_t>>& drag_data) override; void OnDragEnter(Id window, diff --git a/chromium/services/ws/test_wm.mojom b/chromium/services/ws/test_wm.mojom new file mode 100644 index 00000000000..8c6868a6b61 --- /dev/null +++ b/chromium/services/ws/test_wm.mojom @@ -0,0 +1,9 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +module ws.test.mojom; + +interface TestWm { + DoIt(); +}; diff --git a/chromium/services/ws/test_ws/BUILD.gn b/chromium/services/ws/test_ws/BUILD.gn index 2555796f869..dd32b550a96 100644 --- a/chromium/services/ws/test_ws/BUILD.gn +++ b/chromium/services/ws/test_ws/BUILD.gn @@ -2,6 +2,7 @@ # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. +import("//mojo/public/tools/bindings/mojom.gni") import("//services/service_manager/public/cpp/service.gni") import("//services/service_manager/public/service_manifest.gni") @@ -9,21 +10,42 @@ service("test_ws") { testonly = true sources = [ + "test_ws.cc", + ] + + deps = [ + ":lib", + "//base", + "//services/service_manager/public/cpp", + "//ui/base", + ] +} + +source_set("lib") { + testonly = true + + sources = [ "test_drag_drop_client.cc", "test_drag_drop_client.h", "test_gpu_interface_provider.cc", "test_gpu_interface_provider.h", - "test_ws.cc", + "test_window_service.cc", + "test_window_service.h", + "test_window_service_factory.cc", + "test_window_service_factory.h", ] deps = [ + ":mojom", "//base", "//components/discardable_memory/service", + "//mojo/public/cpp/bindings", "//services/service_manager/public/cpp", "//services/service_manager/public/mojom", "//services/ws:lib", "//services/ws/gpu_host", "//services/ws/public/cpp", + "//services/ws/public/cpp/host", "//services/ws/public/mojom", "//ui/aura", "//ui/aura:test_support", @@ -33,7 +55,17 @@ service("test_ws") { } service_manifest("manifest") { + testonly = true + name = "test_ws" source = "manifest.json" packaged_services = [ "//services/ws:manifest" ] } + +mojom("mojom") { + testonly = true + + sources = [ + "test_ws.mojom", + ] +} diff --git a/chromium/services/ws/test_ws/OWNERS b/chromium/services/ws/test_ws/OWNERS index 59dfd4b3677..045c0ba06e0 100644 --- a/chromium/services/ws/test_ws/OWNERS +++ b/chromium/services/ws/test_ws/OWNERS @@ -1,2 +1,5 @@ per-file manifest.json=set noparent per-file manifest.json=file://ipc/SECURITY_OWNERS + +per-file *.mojom=set noparent +per-file *.mojom=file://ipc/SECURITY_OWNERS diff --git a/chromium/services/ws/test_ws/manifest.json b/chromium/services/ws/test_ws/manifest.json index 2db422b0475..b1ea0bd917e 100644 --- a/chromium/services/ws/test_ws/manifest.json +++ b/chromium/services/ws/test_ws/manifest.json @@ -7,6 +7,9 @@ "provides": { "service_manager:service_factory": [ "service_manager.mojom.ServiceFactory" + ], + "test": [ + "test_ws.mojom.TestWs" ] } } diff --git a/chromium/services/ws/test_ws/test_gpu_interface_provider.h b/chromium/services/ws/test_ws/test_gpu_interface_provider.h index da284f1cbcb..2c01ca1550d 100644 --- a/chromium/services/ws/test_ws/test_gpu_interface_provider.h +++ b/chromium/services/ws/test_ws/test_gpu_interface_provider.h @@ -6,7 +6,7 @@ #define SERVICES_WS_TEST_WS_TEST_GPU_INTERFACE_PROVIDER_H_ #include "components/discardable_memory/public/interfaces/discardable_shared_memory_manager.mojom.h" -#include "services/ws/gpu_interface_provider.h" +#include "services/ws/public/cpp/host/gpu_interface_provider.h" #include "services/ws/public/mojom/gpu.mojom.h" namespace discardable_memory { diff --git a/chromium/services/ws/test_ws/test_window_service.cc b/chromium/services/ws/test_ws/test_window_service.cc new file mode 100644 index 00000000000..651066c0f7f --- /dev/null +++ b/chromium/services/ws/test_ws/test_window_service.cc @@ -0,0 +1,216 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/ws/test_ws/test_window_service.h" + +#include <utility> + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "mojo/public/cpp/bindings/map.h" +#include "services/service_manager/public/cpp/connector.h" +#include "services/ws/public/mojom/constants.mojom.h" +#include "services/ws/test_ws/test_gpu_interface_provider.h" +#include "services/ws/window_service.h" +#include "ui/aura/env.h" +#include "ui/aura/mus/property_utils.h" +#include "ui/compositor/test/context_factories_for_test.h" +#include "ui/gl/test/gl_surface_test_support.h" + +namespace ws { +namespace test { + +TestWindowService::TestWindowService() = default; + +TestWindowService::~TestWindowService() { + Shutdown(base::NullCallback()); +} + +void TestWindowService::InitForInProcess( + ui::ContextFactory* context_factory, + ui::ContextFactoryPrivate* context_factory_private, + std::unique_ptr<GpuInterfaceProvider> gpu_interface_provider) { + is_in_process_ = true; + aura_test_helper_ = std::make_unique<aura::test::AuraTestHelper>( + aura::Env::CreateLocalInstanceForInProcess()); + SetupAuraTestHelper(context_factory, context_factory_private); + + gpu_interface_provider_ = std::move(gpu_interface_provider); +} + +void TestWindowService::InitForOutOfProcess() { +#if defined(OS_CHROMEOS) + // Use gpu service only for ChromeOS to run content_browsertests in mash. + // + // To use this code path for all platforms, we need to fix the following + // flaky failure on Win7 bot: + // gl_surface_egl.cc: + // EGL Driver message (Critical) eglInitialize: No available renderers + // gl_initializer_win.cc: + // GLSurfaceEGL::InitializeOneOff failed. + CreateGpuHost(); +#else + gl::GLSurfaceTestSupport::InitializeOneOff(); + CreateAuraTestHelper(); +#endif // defined(OS_CHROMEOS) +} + +std::unique_ptr<aura::Window> TestWindowService::NewTopLevel( + aura::PropertyConverter* property_converter, + const base::flat_map<std::string, std::vector<uint8_t>>& properties) { + std::unique_ptr<aura::Window> top_level = std::make_unique<aura::Window>( + nullptr, aura::client::WINDOW_TYPE_UNKNOWN, aura_test_helper_->GetEnv()); + aura::SetWindowType(top_level.get(), aura::GetWindowTypeFromProperties( + mojo::FlatMapToMap(properties))); + top_level->Init(ui::LAYER_NOT_DRAWN); + aura_test_helper_->root_window()->AddChild(top_level.get()); + for (auto property : properties) { + property_converter->SetPropertyFromTransportValue( + top_level.get(), property.first, &property.second); + } + return top_level; +} + +void TestWindowService::RunDragLoop(aura::Window* window, + const ui::OSExchangeData& data, + const gfx::Point& screen_location, + uint32_t drag_operation, + ui::DragDropTypes::DragEventSource source, + DragDropCompletedCallback callback) { + std::move(callback).Run(drag_drop_client_.StartDragAndDrop( + data, window->GetRootWindow(), window, screen_location, drag_operation, + source)); +} + +void TestWindowService::CancelDragLoop(aura::Window* window) { + drag_drop_client_.DragCancel(); +} + +aura::WindowTreeHost* TestWindowService::GetWindowTreeHostForDisplayId( + int64_t display_id) { + return aura_test_helper_->host(); +} + +void TestWindowService::OnStart() { + CHECK(!started_); + started_ = true; + + registry_.AddInterface(base::BindRepeating( + &TestWindowService::BindServiceFactory, base::Unretained(this))); + registry_.AddInterface(base::BindRepeating(&TestWindowService::BindTestWs, + base::Unretained(this))); + + if (!is_in_process_) { + DCHECK(!aura_test_helper_); + InitForOutOfProcess(); + } +} + +void TestWindowService::OnBindInterface( + const service_manager::BindSourceInfo& source_info, + const std::string& interface_name, + mojo::ScopedMessagePipeHandle interface_pipe) { + registry_.BindInterface(interface_name, std::move(interface_pipe)); +} + +void TestWindowService::CreateService( + service_manager::mojom::ServiceRequest request, + const std::string& name, + service_manager::mojom::PIDReceiverPtr pid_receiver) { + DCHECK_EQ(name, mojom::kServiceName); + + // Defer CreateService if |aura_test_helper_| is not created. + if (!aura_test_helper_) { + DCHECK(!pending_create_service_); + + pending_create_service_ = base::BindOnce( + &TestWindowService::CreateService, base::Unretained(this), + std::move(request), name, std::move(pid_receiver)); + return; + } + + DCHECK(!ui_service_created_); + ui_service_created_ = true; + + auto window_service = std::make_unique<WindowService>( + this, std::move(gpu_interface_provider_), + aura_test_helper_->focus_client(), /*decrement_client_ids=*/false, + aura_test_helper_->GetEnv()); + service_context_ = std::make_unique<service_manager::ServiceContext>( + std::move(window_service), std::move(request)); + pid_receiver->SetPID(base::GetCurrentProcId()); +} + +void TestWindowService::OnGpuServiceInitialized() { + CreateAuraTestHelper(); + + if (pending_create_service_) + std::move(pending_create_service_).Run(); +} + +void TestWindowService::Shutdown( + test_ws::mojom::TestWs::ShutdownCallback callback) { + // WindowService depends upon Screen, which is owned by AuraTestHelper. + service_context_.reset(); + + // |aura_test_helper_| could be null when exiting before fully initialized. + if (aura_test_helper_) { + aura::client::SetScreenPositionClient(aura_test_helper_->root_window(), + nullptr); + // AuraTestHelper expects TearDown() to be called. + aura_test_helper_->TearDown(); + aura_test_helper_.reset(); + } + + ui::TerminateContextFactoryForTests(); + + if (callback) + std::move(callback).Run(); +} + +void TestWindowService::BindServiceFactory( + service_manager::mojom::ServiceFactoryRequest request) { + service_factory_bindings_.AddBinding(this, std::move(request)); +} + +void TestWindowService::BindTestWs(test_ws::mojom::TestWsRequest request) { + test_ws_bindings_.AddBinding(this, std::move(request)); +} + +void TestWindowService::CreateGpuHost() { + discardable_shared_memory_manager_ = + std::make_unique<discardable_memory::DiscardableSharedMemoryManager>(); + + gpu_host_ = std::make_unique<gpu_host::GpuHost>( + this, context()->connector(), discardable_shared_memory_manager_.get()); + + gpu_interface_provider_ = std::make_unique<TestGpuInterfaceProvider>( + gpu_host_.get(), discardable_shared_memory_manager_.get()); + + // |aura_test_helper_| is created later in OnGpuServiceInitialized. +} + +void TestWindowService::CreateAuraTestHelper() { + DCHECK(!aura_test_helper_); + + ui::ContextFactory* context_factory = nullptr; + ui::ContextFactoryPrivate* context_factory_private = nullptr; + ui::InitializeContextFactoryForTests(false /* enable_pixel_output */, + &context_factory, + &context_factory_private); + aura_test_helper_ = std::make_unique<aura::test::AuraTestHelper>(); + SetupAuraTestHelper(context_factory, context_factory_private); +} + +void TestWindowService::SetupAuraTestHelper( + ui::ContextFactory* context_factory, + ui::ContextFactoryPrivate* context_factory_private) { + aura_test_helper_->SetUp(context_factory, context_factory_private); + + aura::client::SetScreenPositionClient(aura_test_helper_->root_window(), + &screen_position_client_); +} + +} // namespace test +} // namespace ws diff --git a/chromium/services/ws/test_ws/test_window_service.h b/chromium/services/ws/test_ws/test_window_service.h new file mode 100644 index 00000000000..28d8bac866c --- /dev/null +++ b/chromium/services/ws/test_ws/test_window_service.h @@ -0,0 +1,137 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_WS_TEST_WS_TEST_WINDOW_SERVICE_H_ +#define SERVICES_WS_TEST_WS_TEST_WINDOW_SERVICE_H_ + +#include <memory> + +#include "base/callback.h" +#include "base/macros.h" +#include "components/discardable_memory/service/discardable_shared_memory_manager.h" +#include "mojo/public/cpp/bindings/binding_set.h" +#include "services/service_manager/public/cpp/binder_registry.h" +#include "services/service_manager/public/cpp/service.h" +#include "services/service_manager/public/cpp/service_context.h" +#include "services/service_manager/public/mojom/service_factory.mojom.h" +#include "services/ws/gpu_host/gpu_host.h" +#include "services/ws/gpu_host/gpu_host_delegate.h" +#include "services/ws/public/cpp/host/gpu_interface_provider.h" +#include "services/ws/test_ws/test_drag_drop_client.h" +#include "services/ws/test_ws/test_ws.mojom.h" +#include "services/ws/window_service_delegate.h" +#include "ui/aura/test/aura_test_helper.h" +#include "ui/aura/window.h" +#include "ui/aura/window_tree_host.h" +#include "ui/wm/core/default_screen_position_client.h" + +namespace ui { +class ContextFactory; +class ContextFactoryPrivate; +} // namespace ui + +namespace ws { +namespace test { + +// Service implementation that brings up the Window Service on top of aura. +// Uses ws::WindowService to provide the Window Service. +class TestWindowService : public service_manager::Service, + public service_manager::mojom::ServiceFactory, + public gpu_host::GpuHostDelegate, + public WindowServiceDelegate, + public test_ws::mojom::TestWs { + public: + TestWindowService(); + ~TestWindowService() override; + + void InitForInProcess( + ui::ContextFactory* context_factory, + ui::ContextFactoryPrivate* context_factory_private, + std::unique_ptr<GpuInterfaceProvider> gpu_interface_provider); + + private: + void InitForOutOfProcess(); + + // WindowServiceDelegate: + std::unique_ptr<aura::Window> NewTopLevel( + aura::PropertyConverter* property_converter, + const base::flat_map<std::string, std::vector<uint8_t>>& properties) + override; + void RunDragLoop(aura::Window* window, + const ui::OSExchangeData& data, + const gfx::Point& screen_location, + uint32_t drag_operation, + ui::DragDropTypes::DragEventSource source, + DragDropCompletedCallback callback) override; + void CancelDragLoop(aura::Window* window) override; + aura::WindowTreeHost* GetWindowTreeHostForDisplayId( + int64_t display_id) override; + + // service_manager::Service: + void OnStart() override; + void OnBindInterface(const service_manager::BindSourceInfo& source_info, + const std::string& interface_name, + mojo::ScopedMessagePipeHandle interface_pipe) override; + + // service_manager::mojom::ServiceFactory: + void CreateService( + service_manager::mojom::ServiceRequest request, + const std::string& name, + service_manager::mojom::PIDReceiverPtr pid_receiver) override; + + // gpu_host::GpuHostDelegate: + void OnGpuServiceInitialized() override; + + // test_ws::mojom::TestWs: + void Shutdown(test_ws::mojom::TestWs::ShutdownCallback callback) override; + + void BindServiceFactory( + service_manager::mojom::ServiceFactoryRequest request); + void BindTestWs(test_ws::mojom::TestWsRequest request); + + void CreateGpuHost(); + + void CreateAuraTestHelper(); + void SetupAuraTestHelper(ui::ContextFactory* context_factory, + ui::ContextFactoryPrivate* context_factory_private); + + service_manager::BinderRegistry registry_; + + mojo::BindingSet<service_manager::mojom::ServiceFactory> + service_factory_bindings_; + mojo::BindingSet<test_ws::mojom::TestWs> test_ws_bindings_; + + // Handles the ServiceRequest. Owns the WindowService instance. + std::unique_ptr<service_manager::ServiceContext> service_context_; + + std::unique_ptr<aura::test::AuraTestHelper> aura_test_helper_; + + std::unique_ptr<discardable_memory::DiscardableSharedMemoryManager> + discardable_shared_memory_manager_; + std::unique_ptr<gpu_host::GpuHost> gpu_host_; + + // For drag and drop code to convert to/from screen coordinates. + wm::DefaultScreenPositionClient screen_position_client_; + + TestDragDropClient drag_drop_client_; + + bool started_ = false; + bool ui_service_created_ = false; + + base::OnceClosure pending_create_service_; + + // GpuInterfaceProvider that is passed to WindowService when creating it. + std::unique_ptr<GpuInterfaceProvider> gpu_interface_provider_; + + // Whether the service is used in process. Not using features because it + // is used in service_unittests where ui features is not used there. + bool is_in_process_ = false; + + DISALLOW_COPY_AND_ASSIGN(TestWindowService); +}; + +} // namespace test +} // namespace ws + +#endif // SERVICES_WS_TEST_WS_TEST_WINDOW_SERVICE_H_ diff --git a/chromium/services/ws/test_ws/test_window_service_factory.cc b/chromium/services/ws/test_ws/test_window_service_factory.cc new file mode 100644 index 00000000000..456f704de78 --- /dev/null +++ b/chromium/services/ws/test_ws/test_window_service_factory.cc @@ -0,0 +1,32 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/ws/test_ws/test_window_service_factory.h" + +#include <utility> + +#include "services/service_manager/public/cpp/service.h" +#include "services/ws/public/cpp/host/gpu_interface_provider.h" +#include "services/ws/test_ws/test_window_service.h" + +namespace ws { +namespace test { + +std::unique_ptr<service_manager::Service> CreateInProcessWindowService( + ui::ContextFactory* context_factory, + ui::ContextFactoryPrivate* context_factory_private, + std::unique_ptr<GpuInterfaceProvider> gpu_interface_provider) { + auto window_service = std::make_unique<TestWindowService>(); + window_service->InitForInProcess(context_factory, context_factory_private, + std::move(gpu_interface_provider)); + return window_service; +} + +std::unique_ptr<service_manager::Service> CreateOutOfProcessWindowService() { + auto window_service = std::make_unique<TestWindowService>(); + return window_service; +} + +} // namespace test +} // namespace ws diff --git a/chromium/services/ws/test_ws/test_window_service_factory.h b/chromium/services/ws/test_ws/test_window_service_factory.h new file mode 100644 index 00000000000..553ce9849e3 --- /dev/null +++ b/chromium/services/ws/test_ws/test_window_service_factory.h @@ -0,0 +1,36 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_WS_TEST_WS_TEST_WINDOW_SERVICE_FACTORY_H_ +#define SERVICES_WS_TEST_WS_TEST_WINDOW_SERVICE_FACTORY_H_ + +#include <memory> + +namespace service_manager { +class Service; +} // namespace service_manager + +namespace ui { +class ContextFactory; +class ContextFactoryPrivate; +} // namespace ui + +namespace ws { +class GpuInterfaceProvider; +} // namespace ws + +namespace ws { +namespace test { + +std::unique_ptr<service_manager::Service> CreateInProcessWindowService( + ui::ContextFactory* context_factory, + ui::ContextFactoryPrivate* context_factory_private, + std::unique_ptr<GpuInterfaceProvider> gpu_interface_provider); + +std::unique_ptr<service_manager::Service> CreateOutOfProcessWindowService(); + +} // namespace test +} // namespace ws + +#endif // SERVICES_WS_TEST_WS_TEST_WINDOW_SERVICE_FACTORY_H_ diff --git a/chromium/services/ws/test_ws/test_ws.cc b/chromium/services/ws/test_ws/test_ws.cc index 1bc67af0e98..286e589358b 100644 --- a/chromium/services/ws/test_ws/test_ws.cc +++ b/chromium/services/ws/test_ws/test_ws.cc @@ -2,233 +2,21 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include <memory> -#include <utility> - -#include "base/bind.h" -#include "base/callback.h" #include "base/message_loop/message_loop.h" -#include "components/discardable_memory/service/discardable_shared_memory_manager.h" -#include "mojo/public/cpp/bindings/binding_set.h" #include "services/service_manager/public/c/main.h" -#include "services/service_manager/public/cpp/binder_registry.h" -#include "services/service_manager/public/cpp/connector.h" #include "services/service_manager/public/cpp/service.h" -#include "services/service_manager/public/cpp/service_context.h" #include "services/service_manager/public/cpp/service_runner.h" -#include "services/service_manager/public/mojom/service_factory.mojom.h" -#include "services/ws/gpu_host/gpu_host.h" -#include "services/ws/gpu_host/gpu_host_delegate.h" -#include "services/ws/public/mojom/constants.mojom.h" -#include "services/ws/test_ws/test_drag_drop_client.h" -#include "services/ws/test_ws/test_gpu_interface_provider.h" -#include "services/ws/window_service.h" -#include "services/ws/window_service_delegate.h" -#include "ui/aura/test/aura_test_helper.h" -#include "ui/aura/window.h" -#include "ui/aura/window_tree_host.h" +#include "services/ws/test_ws/test_window_service_factory.h" #include "ui/base/ui_base_paths.h" -#include "ui/compositor/test/context_factories_for_test.h" #include "ui/gfx/gfx_paths.h" -#include "ui/gl/test/gl_surface_test_support.h" -#include "ui/wm/core/default_screen_position_client.h" - -namespace ws { -namespace test { - -// Service implementation that brings up the Window Service on top of aura. -// Uses ws2::WindowService to provide the Window Service and -// WindowTreeHostFactory to service requests for connections to the Window -// Service. -class TestWindowService : public service_manager::Service, - public service_manager::mojom::ServiceFactory, - public gpu_host::GpuHostDelegate, - public WindowServiceDelegate { - public: - TestWindowService() = default; - - ~TestWindowService() override { - // WindowService depends upon Screen, which is owned by AuraTestHelper. - service_context_.reset(); - - // |aura_test_helper_| could be null when exiting before fully initialized. - if (aura_test_helper_) { - aura::client::SetScreenPositionClient(aura_test_helper_->root_window(), - nullptr); - // AuraTestHelper expects TearDown() to be called. - aura_test_helper_->TearDown(); - aura_test_helper_.reset(); - } - - ui::TerminateContextFactoryForTests(); - } - - private: - // WindowServiceDelegate: - std::unique_ptr<aura::Window> NewTopLevel( - aura::PropertyConverter* property_converter, - const base::flat_map<std::string, std::vector<uint8_t>>& properties) - override { - std::unique_ptr<aura::Window> top_level = - std::make_unique<aura::Window>(nullptr); - top_level->Init(ui::LAYER_NOT_DRAWN); - aura_test_helper_->root_window()->AddChild(top_level.get()); - for (auto property : properties) { - property_converter->SetPropertyFromTransportValue( - top_level.get(), property.first, &property.second); - } - return top_level; - } - void RunDragLoop(aura::Window* window, - const ui::OSExchangeData& data, - const gfx::Point& screen_location, - uint32_t drag_operation, - ui::DragDropTypes::DragEventSource source, - DragDropCompletedCallback callback) override { - std::move(callback).Run(drag_drop_client_.StartDragAndDrop( - data, window->GetRootWindow(), window, screen_location, drag_operation, - source)); - } - void CancelDragLoop(aura::Window* window) override { - drag_drop_client_.DragCancel(); - } - aura::WindowTreeHost* GetWindowTreeHostForDisplayId( - int64_t display_id) override { - return aura_test_helper_->host(); - } - - // service_manager::Service: - void OnStart() override { - CHECK(!started_); - started_ = true; - - gfx::RegisterPathProvider(); - ui::RegisterPathProvider(); - - registry_.AddInterface(base::BindRepeating( - &TestWindowService::BindServiceFactory, base::Unretained(this))); - -#if defined(OS_CHROMEOS) - // Use gpu service only for ChromeOS to run content_browsertests in mash. - // - // To use this code path for all platforms, we need to fix the following - // flaky failure on Win7 bot: - // gl_surface_egl.cc: - // EGL Driver message (Critical) eglInitialize: No available renderers - // gl_initializer_win.cc: - // GLSurfaceEGL::InitializeOneOff failed. - CreateGpuHost(); -#else - gl::GLSurfaceTestSupport::InitializeOneOff(); - CreateAuraTestHelper(); -#endif // defined(OS_CHROMEOS) - } - void OnBindInterface(const service_manager::BindSourceInfo& source_info, - const std::string& interface_name, - mojo::ScopedMessagePipeHandle interface_pipe) override { - registry_.BindInterface(interface_name, std::move(interface_pipe)); - } - - // service_manager::mojom::ServiceFactory: - void CreateService( - service_manager::mojom::ServiceRequest request, - const std::string& name, - service_manager::mojom::PIDReceiverPtr pid_receiver) override { - DCHECK_EQ(name, mojom::kServiceName); - - // Defer CreateService if |aura_test_helper_| is not created. - if (!aura_test_helper_) { - DCHECK(!pending_create_service_); - - pending_create_service_ = base::BindOnce( - &TestWindowService::CreateService, base::Unretained(this), - std::move(request), name, std::move(pid_receiver)); - return; - } - - DCHECK(!ui_service_created_); - ui_service_created_ = true; - - auto window_service = std::make_unique<WindowService>( - this, - std::make_unique<TestGpuInterfaceProvider>( - gpu_host_.get(), discardable_shared_memory_manager_.get()), - aura_test_helper_->focus_client()); - service_context_ = std::make_unique<service_manager::ServiceContext>( - std::move(window_service), std::move(request)); - pid_receiver->SetPID(base::GetCurrentProcId()); - } - - // gpu_host::GpuHostDelegate: - void OnGpuServiceInitialized() override { - CreateAuraTestHelper(); - - if (pending_create_service_) - std::move(pending_create_service_).Run(); - } - - void BindServiceFactory( - service_manager::mojom::ServiceFactoryRequest request) { - service_factory_bindings_.AddBinding(this, std::move(request)); - } - - void CreateGpuHost() { - discardable_shared_memory_manager_ = - std::make_unique<discardable_memory::DiscardableSharedMemoryManager>(); - - gpu_host_ = std::make_unique<gpu_host::DefaultGpuHost>( - this, context()->connector(), discardable_shared_memory_manager_.get()); - - // |aura_test_helper_| is created later in OnGpuServiceInitialized. - } - - void CreateAuraTestHelper() { - DCHECK(!aura_test_helper_); - - ui::ContextFactory* context_factory = nullptr; - ui::ContextFactoryPrivate* context_factory_private = nullptr; - ui::InitializeContextFactoryForTests(false /* enable_pixel_output */, - &context_factory, - &context_factory_private); - aura_test_helper_ = std::make_unique<aura::test::AuraTestHelper>(); - aura_test_helper_->SetUp(context_factory, context_factory_private); - - aura::client::SetScreenPositionClient(aura_test_helper_->root_window(), - &screen_position_client_); - } - - service_manager::BinderRegistry registry_; - - mojo::BindingSet<service_manager::mojom::ServiceFactory> - service_factory_bindings_; - - // Handles the ServiceRequest. Owns the WindowService instance. - std::unique_ptr<service_manager::ServiceContext> service_context_; - - std::unique_ptr<aura::test::AuraTestHelper> aura_test_helper_; - - std::unique_ptr<discardable_memory::DiscardableSharedMemoryManager> - discardable_shared_memory_manager_; - std::unique_ptr<gpu_host::DefaultGpuHost> gpu_host_; - - // For drag and drop code to convert to/from screen coordinates. - wm::DefaultScreenPositionClient screen_position_client_; - - TestDragDropClient drag_drop_client_; - - bool started_ = false; - bool ui_service_created_ = false; - - base::OnceClosure pending_create_service_; - - DISALLOW_COPY_AND_ASSIGN(TestWindowService); -}; - -} // namespace test -} // namespace ws MojoResult ServiceMain(MojoHandle service_request_handle) { - service_manager::ServiceRunner runner(new ws::test::TestWindowService); + gfx::RegisterPathProvider(); + ui::RegisterPathProvider(); + + // |runner| takes ownership of the created test_ws service. + service_manager::ServiceRunner runner( + ws::test::CreateOutOfProcessWindowService().release()); runner.set_message_loop_type(base::MessageLoop::TYPE_UI); return runner.Run(service_request_handle); } diff --git a/chromium/services/ws/test_ws/test_ws.mojom b/chromium/services/ws/test_ws/test_ws.mojom new file mode 100644 index 00000000000..14e97982a4b --- /dev/null +++ b/chromium/services/ws/test_ws/test_ws.mojom @@ -0,0 +1,15 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +module test_ws.mojom; + +const string kServiceName = "test_ws"; + +// Implemented by TestWindowService. +interface TestWs { + // Used when caller needs to explicitly shutdown the window service hosted + // in test_ws. Callback is provided so that caller can resume its shutdown + // sequence. + Shutdown() => (); +}; diff --git a/chromium/services/ws/topmost_window_observer.cc b/chromium/services/ws/topmost_window_observer.cc index ee2babe9b28..d3796052d1a 100644 --- a/chromium/services/ws/topmost_window_observer.cc +++ b/chromium/services/ws/topmost_window_observer.cc @@ -56,7 +56,7 @@ TopmostWindowObserver::~TopmostWindowObserver() { root_->RemovePreTargetHandler(this); if (topmost_) topmost_->RemoveObserver(this); - if (real_topmost_) + if (real_topmost_ && topmost_ != real_topmost_) real_topmost_->RemoveObserver(this); } @@ -118,20 +118,22 @@ void TopmostWindowObserver::UpdateTopmostWindows() { if (topmost == topmost_ && real_topmost == real_topmost_) return; - if (topmost_ != topmost) { - if (topmost_) - topmost_->RemoveObserver(this); - topmost_ = topmost; - if (topmost_) - topmost_->AddObserver(this); - } - if (real_topmost_ != real_topmost) { - if (real_topmost_) - real_topmost_->RemoveObserver(this); - real_topmost_ = real_topmost; - if (real_topmost_) - real_topmost_->RemoveObserver(this); - } + // Since |topmost_| and |real_topmost_| could be same, updating observation + // for those windows is really complicated. To simplify the logic, here always + // removes this from the old windows and then adds to the new windows. This + // means removing and adding can happen on the same window when |topmost_| or + // |real_topmost_| are same. See topmost_window_observer_unittest.cc for the + // corner cases of the updates. + if (topmost_) + topmost_->RemoveObserver(this); + if (real_topmost_ && real_topmost_ != topmost_) + real_topmost_->RemoveObserver(this); + topmost_ = topmost; + real_topmost_ = real_topmost; + if (topmost_) + topmost_->AddObserver(this); + if (real_topmost_ && real_topmost_ != topmost_) + real_topmost_->AddObserver(this); std::vector<aura::Window*> windows; if (real_topmost_) diff --git a/chromium/services/ws/topmost_window_observer.h b/chromium/services/ws/topmost_window_observer.h index 7cd092e913a..7744dbf902b 100644 --- a/chromium/services/ws/topmost_window_observer.h +++ b/chromium/services/ws/topmost_window_observer.h @@ -22,8 +22,9 @@ class WindowTree; // windows under the mouse cursor or touch location. // TODO(mukai): support multiple displays. -class TopmostWindowObserver : public ui::EventHandler, - public aura::WindowObserver { +class COMPONENT_EXPORT(WINDOW_SERVICE) TopmostWindowObserver + : public ui::EventHandler, + public aura::WindowObserver { public: // |source| determines the type of the event, and |initial_target| is the // initial target of the event. This will report the topmost window under the @@ -34,6 +35,8 @@ class TopmostWindowObserver : public ui::EventHandler, ~TopmostWindowObserver() override; private: + friend class TopmostWindowObserverTest; + // ui::EventHandler: void OnMouseEvent(ui::MouseEvent* event) override; void OnTouchEvent(ui::TouchEvent* event) override; diff --git a/chromium/services/ws/topmost_window_observer_unittest.cc b/chromium/services/ws/topmost_window_observer_unittest.cc new file mode 100644 index 00000000000..950970a3539 --- /dev/null +++ b/chromium/services/ws/topmost_window_observer_unittest.cc @@ -0,0 +1,248 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "services/ws/topmost_window_observer.h" + +#include "services/ws/window_service_test_setup.h" +#include "services/ws/window_tree_test_helper.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "ui/aura/client/screen_position_client.h" +#include "ui/aura/window.h" +#include "ui/wm/core/default_screen_position_client.h" + +namespace ws { + +// This class primarily tests observation of TopmostWindowObserver class. +// The actual logic of observing topmosts needs to be tested with Ash, so those +// tests are done in ash/ws/window_service_delegate_impl_unittest.cc. +class TopmostWindowObserverTest : public testing::Test { + public: + TopmostWindowObserverTest() = default; + + void SetUp() override { + aura::client::SetScreenPositionClient(setup_.root(), + &screen_position_client_); + } + void TearDown() override { + aura::client::SetScreenPositionClient(setup_.root(), nullptr); + } + + protected: + aura::Window* NewWindow() { + aura::Window* window = setup_.window_tree_test_helper()->NewWindow(); + setup_.root()->AddChild(window); + return window; + } + void SetupTopmosts(aura::Window* topmost, aura::Window* real_topmost) { + setup_.delegate()->set_topmost(topmost); + setup_.delegate()->set_real_topmost(real_topmost); + } + std::unique_ptr<TopmostWindowObserver> CreateTopmostWindowObserver( + aura::Window* window) { + return std::make_unique<TopmostWindowObserver>( + setup_.window_tree(), mojom::MoveLoopSource::MOUSE, window); + } + void UpdateTopmostWindows(TopmostWindowObserver* observer) { + observer->UpdateTopmostWindows(); + } + void DeleteWindow(aura::Window* window) { + Id id = setup_.window_tree_test_helper()->TransportIdForWindow(window); + static_cast<mojom::WindowTree*>(setup_.window_tree())->DeleteWindow(1, id); + } + + private: + WindowServiceTestSetup setup_; + wm::DefaultScreenPositionClient screen_position_client_; + + DISALLOW_COPY_AND_ASSIGN(TopmostWindowObserverTest); +}; + +TEST_F(TopmostWindowObserverTest, BasicObserving) { + aura::Window* w1 = NewWindow(); + aura::Window* w2 = NewWindow(); + SetupTopmosts(w1, w2); + auto observer = CreateTopmostWindowObserver(w2); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + EXPECT_TRUE(w2->HasObserver(observer.get())); +} + +TEST_F(TopmostWindowObserverTest, RealTopmostIsNull) { + aura::Window* w1 = NewWindow(); + SetupTopmosts(w1, nullptr); + auto observer = CreateTopmostWindowObserver(w1); + + EXPECT_TRUE(w1->HasObserver(observer.get())); +} + +TEST_F(TopmostWindowObserverTest, TopmostIsNull) { + aura::Window* w1 = NewWindow(); + SetupTopmosts(nullptr, w1); + auto observer = CreateTopmostWindowObserver(w1); + + EXPECT_TRUE(w1->HasObserver(observer.get())); +} + +TEST_F(TopmostWindowObserverTest, UpdateTopmost) { + aura::Window* w1 = NewWindow(); + aura::Window* w2 = NewWindow(); + SetupTopmosts(w1, w2); + auto observer = CreateTopmostWindowObserver(w1); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + EXPECT_TRUE(w2->HasObserver(observer.get())); + + aura::Window* w3 = NewWindow(); + SetupTopmosts(w3, w2); + UpdateTopmostWindows(observer.get()); + + EXPECT_FALSE(w1->HasObserver(observer.get())); + EXPECT_TRUE(w2->HasObserver(observer.get())); + EXPECT_TRUE(w3->HasObserver(observer.get())); +} + +TEST_F(TopmostWindowObserverTest, UpdateRealTopmost) { + aura::Window* w1 = NewWindow(); + aura::Window* w2 = NewWindow(); + SetupTopmosts(w1, w2); + auto observer = CreateTopmostWindowObserver(w1); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + EXPECT_TRUE(w2->HasObserver(observer.get())); + + aura::Window* w3 = NewWindow(); + SetupTopmosts(w1, w3); + UpdateTopmostWindows(observer.get()); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + EXPECT_FALSE(w2->HasObserver(observer.get())); + EXPECT_TRUE(w3->HasObserver(observer.get())); +} + +TEST_F(TopmostWindowObserverTest, ToSameTopmost) { + aura::Window* w1 = NewWindow(); + aura::Window* w2 = NewWindow(); + SetupTopmosts(w1, w2); + auto observer = CreateTopmostWindowObserver(w2); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + EXPECT_TRUE(w2->HasObserver(observer.get())); + + SetupTopmosts(w1, w1); + UpdateTopmostWindows(observer.get()); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + EXPECT_FALSE(w2->HasObserver(observer.get())); +} + +TEST_F(TopmostWindowObserverTest, ToSameRealTopmost) { + aura::Window* w1 = NewWindow(); + aura::Window* w2 = NewWindow(); + SetupTopmosts(w1, w2); + auto observer = CreateTopmostWindowObserver(w2); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + EXPECT_TRUE(w2->HasObserver(observer.get())); + + SetupTopmosts(w2, w2); + UpdateTopmostWindows(observer.get()); + + EXPECT_FALSE(w1->HasObserver(observer.get())); + EXPECT_TRUE(w2->HasObserver(observer.get())); +} + +TEST_F(TopmostWindowObserverTest, SameToDifferent) { + aura::Window* w1 = NewWindow(); + aura::Window* w2 = NewWindow(); + SetupTopmosts(w1, w1); + auto observer = CreateTopmostWindowObserver(w1); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + EXPECT_FALSE(w2->HasObserver(observer.get())); + + SetupTopmosts(w1, w2); + UpdateTopmostWindows(observer.get()); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + EXPECT_TRUE(w2->HasObserver(observer.get())); +} + +TEST_F(TopmostWindowObserverTest, SameToDifferent2) { + aura::Window* w1 = NewWindow(); + aura::Window* w2 = NewWindow(); + SetupTopmosts(w1, w1); + auto observer = CreateTopmostWindowObserver(w1); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + EXPECT_FALSE(w2->HasObserver(observer.get())); + + SetupTopmosts(w2, w1); + UpdateTopmostWindows(observer.get()); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + EXPECT_TRUE(w2->HasObserver(observer.get())); +} + +TEST_F(TopmostWindowObserverTest, SameToDifferent3) { + aura::Window* w1 = NewWindow(); + SetupTopmosts(w1, w1); + auto observer = CreateTopmostWindowObserver(w1); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + + aura::Window* w2 = NewWindow(); + aura::Window* w3 = NewWindow(); + + SetupTopmosts(w2, w3); + UpdateTopmostWindows(observer.get()); + + EXPECT_FALSE(w1->HasObserver(observer.get())); + EXPECT_TRUE(w2->HasObserver(observer.get())); + EXPECT_TRUE(w3->HasObserver(observer.get())); +} + +TEST_F(TopmostWindowObserverTest, SameToSame) { + aura::Window* w1 = NewWindow(); + aura::Window* w2 = NewWindow(); + SetupTopmosts(w1, w1); + auto observer = CreateTopmostWindowObserver(w1); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + EXPECT_FALSE(w2->HasObserver(observer.get())); + + SetupTopmosts(w2, w2); + UpdateTopmostWindows(observer.get()); + + EXPECT_FALSE(w1->HasObserver(observer.get())); + EXPECT_TRUE(w2->HasObserver(observer.get())); +} + +TEST_F(TopmostWindowObserverTest, SwapObservingWindows) { + aura::Window* w1 = NewWindow(); + aura::Window* w2 = NewWindow(); + SetupTopmosts(w1, w2); + auto observer = CreateTopmostWindowObserver(w1); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + EXPECT_TRUE(w2->HasObserver(observer.get())); + + SetupTopmosts(w2, w1); + UpdateTopmostWindows(observer.get()); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + EXPECT_TRUE(w2->HasObserver(observer.get())); +} + +TEST_F(TopmostWindowObserverTest, WindowDestroying) { + aura::Window* w1 = NewWindow(); + SetupTopmosts(nullptr, w1); + auto observer = CreateTopmostWindowObserver(w1); + + EXPECT_TRUE(w1->HasObserver(observer.get())); + + SetupTopmosts(nullptr, nullptr); + DeleteWindow(w1); +} + +} // namespace ws diff --git a/chromium/services/ws/window_manager_interface.h b/chromium/services/ws/window_manager_interface.h new file mode 100644 index 00000000000..72b920e2afa --- /dev/null +++ b/chromium/services/ws/window_manager_interface.h @@ -0,0 +1,23 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef SERVICES_WS_WINDOW_MANAGER_INTERFACE_H_ +#define SERVICES_WS_WINDOW_MANAGER_INTERFACE_H_ + +#include "base/component_export.h" +#include "base/macros.h" + +namespace ws { + +// Used for any associated interfaces that the local environment exposes to +// clients. +class COMPONENT_EXPORT(WINDOW_SERVICE) WindowManagerInterface { + public: + WindowManagerInterface() {} + virtual ~WindowManagerInterface() {} +}; + +} // namespace ws + +#endif // SERVICES_WS_WINDOW_MANAGER_INTERFACE_H_ diff --git a/chromium/services/ws/window_service.cc b/chromium/services/ws/window_service.cc index 0c58fd7b485..9f9282930da 100644 --- a/chromium/services/ws/window_service.cc +++ b/chromium/services/ws/window_service.cc @@ -12,7 +12,7 @@ #include "services/ws/common/switches.h" #include "services/ws/embedding.h" #include "services/ws/event_injector.h" -#include "services/ws/gpu_interface_provider.h" +#include "services/ws/public/cpp/host/gpu_interface_provider.h" #include "services/ws/public/mojom/window_manager.mojom.h" #include "services/ws/remoting_event_injector.h" #include "services/ws/screen_provider.h" diff --git a/chromium/services/ws/window_service_delegate.cc b/chromium/services/ws/window_service_delegate.cc index f4c4466c609..dc3b0c87f6a 100644 --- a/chromium/services/ws/window_service_delegate.cc +++ b/chromium/services/ws/window_service_delegate.cc @@ -4,6 +4,8 @@ #include "services/ws/window_service_delegate.h" +#include "services/ws/window_manager_interface.h" + namespace ws { bool WindowServiceDelegate::StoreAndSetCursor(aura::Window* window, @@ -44,4 +46,12 @@ aura::Window* WindowServiceDelegate::GetTopmostWindowAtPoint( return nullptr; } +std::unique_ptr<WindowManagerInterface> +WindowServiceDelegate::CreateWindowManagerInterface( + WindowTree* window_tree, + const std::string& name, + mojo::ScopedInterfaceEndpointHandle handle) { + return nullptr; +} + } // namespace ws diff --git a/chromium/services/ws/window_service_delegate.h b/chromium/services/ws/window_service_delegate.h index b59d0b229e8..d1f2c684342 100644 --- a/chromium/services/ws/window_service_delegate.h +++ b/chromium/services/ws/window_service_delegate.h @@ -30,6 +30,10 @@ namespace gfx { class Point; } +namespace mojo { +class ScopedInterfaceEndpointHandle; +} + namespace ui { class KeyEvent; class OSExchangeData; @@ -38,6 +42,9 @@ class SystemInputInjector; namespace ws { +class WindowManagerInterface; +class WindowTree; + // A delegate used by the WindowService for context-specific operations. class COMPONENT_EXPORT(WINDOW_SERVICE) WindowServiceDelegate { public: @@ -124,6 +131,21 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) WindowServiceDelegate { const std::set<aura::Window*>& ignore, aura::Window** real_topmost); + // Creates and binds a request for an interface provided by the local + // environment. The interface request originated from the client associated + // with |tree|. |name| is the name of the requested interface. The return + // value is owned by |tree|. Return null if |name| is not the name of a known + // interface. + // The following shows how to bind |handle|: + // TestWmInterface* wm_interface_impl = ...; + // mojo::AssociatedBindingTestWmInterface> binding( + // wm_interface_impl, + // mojo::AssociatedInterfaceRequest<TestWmInterface>(std::move(handle))); + virtual std::unique_ptr<WindowManagerInterface> CreateWindowManagerInterface( + WindowTree* tree, + const std::string& name, + mojo::ScopedInterfaceEndpointHandle handle); + protected: virtual ~WindowServiceDelegate() = default; }; diff --git a/chromium/services/ws/window_service_test_setup.cc b/chromium/services/ws/window_service_test_setup.cc index e3f16615eb7..e626b5f6092 100644 --- a/chromium/services/ws/window_service_test_setup.cc +++ b/chromium/services/ws/window_service_test_setup.cc @@ -5,7 +5,7 @@ #include "services/ws/window_service_test_setup.h" #include "services/ws/embedding.h" -#include "services/ws/gpu_interface_provider.h" +#include "services/ws/public/cpp/host/gpu_interface_provider.h" #include "services/ws/window_service.h" #include "services/ws/window_tree.h" #include "services/ws/window_tree_binding.h" diff --git a/chromium/services/ws/window_service_unittest.cc b/chromium/services/ws/window_service_unittest.cc index be552d82801..84dde5b5634 100644 --- a/chromium/services/ws/window_service_unittest.cc +++ b/chromium/services/ws/window_service_unittest.cc @@ -9,9 +9,11 @@ #include "base/run_loop.h" #include "services/service_manager/public/cpp/connector.h" #include "services/service_manager/public/cpp/test/test_connector_factory.h" -#include "services/ws/gpu_interface_provider.h" +#include "services/ws/public/cpp/host/gpu_interface_provider.h" #include "services/ws/public/mojom/constants.mojom.h" #include "services/ws/public/mojom/window_tree.mojom.h" +#include "services/ws/test_wm.mojom.h" +#include "services/ws/window_manager_interface.h" #include "services/ws/window_service_test_setup.h" #include "services/ws/window_tree.h" #include "services/ws/window_tree_test_helper.h" @@ -63,6 +65,87 @@ TEST(WindowServiceTest, DeleteWithClients) { // ensure a DCHECK isn't hit in ~WindowTree. } +// Implementation of mojom::TestWm that sets a boolean when DoIt() is called. +class TestWm : public WindowManagerInterface, public test::mojom::TestWm { + public: + TestWm(mojo::ScopedInterfaceEndpointHandle handle, bool* do_it_called) + : binding_(this, + mojo::AssociatedInterfaceRequest<test::mojom::TestWm>( + std::move(handle))), + do_it_called_(do_it_called) {} + + // test::mojom::TestWm: + void DoIt() override { *do_it_called_ = true; } + + private: + mojo::AssociatedBinding<test::mojom::TestWm> binding_; + bool* do_it_called_; + + DISALLOW_COPY_AND_ASSIGN(TestWm); +}; + +// Subclass os TestWindowServiceDelegate that creates TestWm. +class TestWindowServiceDelegateWithInterface + : public TestWindowServiceDelegate { + public: + TestWindowServiceDelegateWithInterface() = default; + ~TestWindowServiceDelegateWithInterface() override = default; + + bool do_it_called() const { return do_it_called_; } + + // TestWindowServiceDelegate: + std::unique_ptr<WindowManagerInterface> CreateWindowManagerInterface( + WindowTree* window_tree, + const std::string& name, + mojo::ScopedInterfaceEndpointHandle handle) override { + if (name != test::mojom::TestWm::Name_) + return nullptr; + + return std::make_unique<TestWm>(std::move(handle), &do_it_called_); + } + + private: + bool do_it_called_ = false; + + DISALLOW_COPY_AND_ASSIGN(TestWindowServiceDelegateWithInterface); +}; + +TEST(WindowServiceTest, GetWindowManagerInterface) { + // Use |test_setup| to configure aura and other state. + WindowServiceTestSetup test_setup; + + // Create another WindowService. + TestWindowServiceDelegateWithInterface test_window_service_delegate; + std::unique_ptr<WindowService> window_service_ptr = + std::make_unique<WindowService>(&test_window_service_delegate, nullptr, + test_setup.focus_controller()); + std::unique_ptr<service_manager::TestConnectorFactory> factory = + service_manager::TestConnectorFactory::CreateForUniqueService( + std::move(window_service_ptr)); + std::unique_ptr<service_manager::Connector> connector = + factory->CreateConnector(); + + // Connect to |window_service| and ask for a new WindowTree. + mojom::WindowTreeFactoryPtr window_tree_factory; + connector->BindInterface(mojom::kServiceName, &window_tree_factory); + mojom::WindowTreePtr window_tree; + mojom::WindowTreeClientPtr client; + mojom::WindowTreeClientRequest client_request = MakeRequest(&client); + window_tree_factory->CreateWindowTree(MakeRequest(&window_tree), + std::move(client)); + + // Request the TestWm interface and call a function on it. + mojom::WindowManagerAssociatedPtr wm; + window_tree->BindWindowManagerInterface(test::mojom::TestWm::Name_, + MakeRequest(&wm)); + test::mojom::TestWmAssociatedPtr test_wm( + mojo::AssociatedInterfacePtrInfo<test::mojom::TestWm>( + wm.PassInterface().PassHandle(), test::mojom::TestWm::Version_)); + test_wm->DoIt(); + test_wm.FlushForTesting(); + EXPECT_TRUE(test_window_service_delegate.do_it_called()); +} + // Test client ids assigned to window trees that connect to the window service. TEST(WindowServiceTest, ClientIds) { // Use |test_setup| to configure aura and other state. diff --git a/chromium/services/ws/window_tree.cc b/chromium/services/ws/window_tree.cc index 90a19614ee4..18bd02d2305 100644 --- a/chromium/services/ws/window_tree.cc +++ b/chromium/services/ws/window_tree.cc @@ -12,6 +12,7 @@ #include "base/unguessable_token.h" #include "components/viz/common/surfaces/parent_local_surface_id_allocator.h" #include "components/viz/common/surfaces/surface_info.h" +#include "mojo/public/cpp/bindings/associated_binding.h" #include "mojo/public/cpp/bindings/map.h" #include "services/ws/client_change.h" #include "services/ws/client_change_tracker.h" @@ -20,9 +21,11 @@ #include "services/ws/drag_drop_delegate.h" #include "services/ws/embedding.h" #include "services/ws/pointer_watcher.h" +#include "services/ws/public/cpp/property_type_converters.h" #include "services/ws/server_window.h" #include "services/ws/topmost_window_observer.h" #include "services/ws/window_delegate_impl.h" +#include "services/ws/window_manager_interface.h" #include "services/ws/window_service.h" #include "services/ws/window_service_delegate.h" #include "services/ws/window_service_observer.h" @@ -43,6 +46,8 @@ #include "ui/compositor/layer_type.h" #include "ui/display/display.h" #include "ui/display/screen.h" +#include "ui/events/event_utils.h" +#include "ui/events/gestures/gesture_recognizer.h" #include "ui/gfx/image/image_skia.h" #include "ui/wm/core/capture_controller.h" #include "ui/wm/core/window_modality_controller.h" @@ -67,6 +72,11 @@ uint32_t GenerateEventAckId() { return 0x1000000 | (rand() & 0xffffff); } +gfx::Insets MakeInsetsPositive(const gfx::Insets& insets) { + return gfx::Insets(std::max(0, insets.top()), std::max(0, insets.left()), + std::max(0, insets.bottom()), std::max(0, insets.right())); +} + } // namespace // Used to track events sent to the client. @@ -80,6 +90,9 @@ struct WindowTree::InFlightEvent { std::unique_ptr<ui::Event> event; }; +WindowTree::KnownWindow::KnownWindow() = default; +WindowTree::KnownWindow::~KnownWindow() = default; + WindowTree::WindowTree(WindowService* window_service, ClientSpecificId client_id, mojom::WindowTreeClient* client, @@ -102,11 +115,11 @@ WindowTree::~WindowTree() { DeleteClientRootReason::kDestructor); } - while (!client_created_windows_.empty()) { + while (FindFirstClientCreatedWindow()) { // RemoveWindowFromKnownWindows() should make it such that the Window is no // longer recognized as being created (owned) by this client. const bool delete_if_owned = true; - RemoveWindowFromKnownWindows(client_created_windows_.begin()->first, + RemoveWindowFromKnownWindows(FindFirstClientCreatedWindow(), delete_if_owned); } @@ -119,15 +132,14 @@ void WindowTree::InitForEmbed(aura::Window* root, ServerWindow* server_window = window_service_->GetServerWindowForWindowCreateIfNecessary(root); const ClientWindowId client_window_id = server_window->frame_sink_id(); - AddWindowToKnownWindows(root, client_window_id); + AddWindowToKnownWindows(root, client_window_id, nullptr); const bool is_top_level = false; ClientRoot* client_root = CreateClientRoot(root, is_top_level); const int64_t display_id = display::Screen::GetScreen()->GetDisplayNearestWindow(root).id(); const ClientWindowId focused_window_id = - root->HasFocus() ? window_to_client_window_id_map_[root] - : ClientWindowId(); + root->HasFocus() ? ClientWindowIdForWindow(root) : ClientWindowId(); const bool drawn = root->IsVisible() && root->GetHost(); window_tree_client_->OnEmbed(WindowToWindowData(root), std::move(window_tree_ptr), display_id, @@ -180,25 +192,27 @@ void WindowTree::SendEventToClient(aura::Window* window, for (WindowServiceObserver& observer : window_service_->observers()) observer.OnWillSendEventToClient(client_id_, event_id); + std::unique_ptr<ui::Event> event_to_send = ui::Event::Clone(event); // Translate the root location for located events. Event's root location // should be in the coordinate of the root window, however the root for the // target window in the client can be different from the one in the server, // thus the root location needs to be converted from the original coordinate // to the one used in the client. See also 'WindowTreeTest.EventLocation' test // case. - std::unique_ptr<ui::Event> event_to_send = - PointerWatcher::CreateEventForClient(event); if (event.IsLocatedEvent()) { - aura::Window* client_root_window = GetClientRootWindowFor(window); - // The client_root_ may have been removed on shutdown. - if (client_root_window) { + ClientRoot* client_root = FindClientRootContaining(window); + // The |client_root| may have been removed on shutdown. + if (client_root) { gfx::PointF root_location = event_to_send->AsLocatedEvent()->root_location_f(); aura::Window::ConvertPointToTarget(window->GetRootWindow(), - client_root_window, &root_location); + client_root->window(), &root_location); event_to_send->AsLocatedEvent()->set_root_location_f(root_location); } } + DVLOG(4) << "SendEventToClient window=" + << ServerWindow::GetMayBeNull(window)->GetIdForDebugging() + << " event_type=" << ui::EventTypeName(event.type()); window_tree_client_->OnWindowInputEvent( event_id, TransportIdForWindow(window), display_id, std::move(event_to_send), matches_pointer_watcher); @@ -216,6 +230,10 @@ bool WindowTree::IsTopLevel(aura::Window* window) { return iter != client_roots_.end() && (*iter)->is_top_level(); } +aura::Window* WindowTree::GetWindowByTransportId(Id transport_window_id) { + return GetWindowByClientId(MakeClientWindowId(transport_window_id)); +} + void WindowTree::RequestClose(ServerWindow* window) { DCHECK(window->IsTopLevel()); DCHECK_EQ(this, window->owning_window_tree()); @@ -248,7 +266,7 @@ void WindowTree::CompleteScheduleEmbedForExistingClient( aura::Window* window, const ClientWindowId& id, const base::UnguessableToken& token) { - AddWindowToKnownWindows(window, id); + AddWindowToKnownWindows(window, id, nullptr); const bool is_top_level = false; ClientRoot* client_root = CreateClientRoot(window, is_top_level); @@ -279,10 +297,19 @@ bool WindowTree::HasAtLeastOneRootWithCompositorFrameSink() { return false; } -ClientWindowId WindowTree::ClientWindowIdForWindow(aura::Window* window) { - auto iter = window_to_client_window_id_map_.find(window); - return iter == window_to_client_window_id_map_.end() ? ClientWindowId() - : iter->second; +bool WindowTree::IsWindowKnown(aura::Window* window) const { + return window && known_windows_map_.count(window) > 0u; +} + +ClientWindowId WindowTree::ClientWindowIdForWindow(aura::Window* window) const { + auto iter = known_windows_map_.find(window); + return iter == known_windows_map_.end() ? ClientWindowId() + : iter->second.client_window_id; +} + +ClientRoot* WindowTree::GetClientRootForWindow(aura::Window* window) { + auto iter = FindClientRootWithRoot(window); + return iter == client_roots_.end() ? nullptr : iter->get(); } ClientRoot* WindowTree::CreateClientRoot(aura::Window* window, @@ -313,6 +340,7 @@ void WindowTree::DeleteClientRoot(ClientRoot* client_root, aura::Window* window = client_root->window(); ServerWindow* server_window = ServerWindow::GetMayBeNull(window); + client_root->UnattachChildFrameSinkIdRecursive(server_window); if (server_window->capture_owner() == this) { // This client will no longer know about |window|, so it should not receive // any events sent to the client. @@ -392,25 +420,23 @@ aura::Window* WindowTree::GetWindowByClientId(const ClientWindowId& id) { return iter == client_window_id_to_window_map_.end() ? nullptr : iter->second; } -aura::Window* WindowTree::GetWindowByTransportId(Id transport_window_id) { - return GetWindowByClientId(MakeClientWindowId(transport_window_id)); -} - bool WindowTree::IsClientCreatedWindow(aura::Window* window) { - return window && client_created_windows_.count(window) > 0u; + auto iter = known_windows_map_.find(window); + return iter == known_windows_map_.end() ? false + : iter->second.is_client_created; } bool WindowTree::IsClientRootWindow(aura::Window* window) { return window && FindClientRootWithRoot(window) != client_roots_.end(); } -aura::Window* WindowTree::GetClientRootWindowFor(aura::Window* window) { +ClientRoot* WindowTree::FindClientRootContaining(aura::Window* window) { if (!window) return nullptr; auto iter = FindClientRootWithRoot(window); if (iter != client_roots_.end()) - return iter->get()->window(); - return GetClientRootWindowFor(window->parent()); + return iter->get(); + return FindClientRootContaining(window->parent()); } WindowTree::ClientRoots::iterator WindowTree::FindClientRootWithRoot( @@ -424,16 +450,21 @@ WindowTree::ClientRoots::iterator WindowTree::FindClientRootWithRoot( return client_roots_.end(); } -bool WindowTree::IsWindowKnown(aura::Window* window) const { - return window && window_to_client_window_id_map_.count(window) > 0u; -} - bool WindowTree::IsWindowRootOfAnotherClient(aura::Window* window) const { ServerWindow* server_window = ServerWindow::GetMayBeNull(window); return server_window && server_window->embedded_window_tree() != nullptr && server_window->embedded_window_tree() != this; } +bool WindowTree::DoesAnyAncestorInterceptEvents(ServerWindow* window) { + if (window->embedding() && window->embedding()->embedding_tree() != this && + window->embedding()->embedding_tree_intercepts_events()) { + return true; + } + ServerWindow* parent = ServerWindow::GetMayBeNull(window->window()->parent()); + return parent && DoesAnyAncestorInterceptEvents(parent); +} + void WindowTree::OnCaptureLost(aura::Window* lost_capture) { DCHECK(IsWindowKnown(lost_capture)); window_tree_client_->OnCaptureChanged(kInvalidTransportId, @@ -492,21 +523,33 @@ void WindowTree::OnPerformDragDropDone(uint32_t change_id, int drag_result) { change_id, drag_result != ui::DragDropTypes::DRAG_NONE, drag_result); } +aura::Window* WindowTree::FindFirstClientCreatedWindow() { + for (auto& pair : known_windows_map_) { + if (pair.second.is_client_created) + return pair.first; + } + return nullptr; +} + aura::Window* WindowTree::AddClientCreatedWindow( const ClientWindowId& id, bool is_top_level, std::unique_ptr<aura::Window> window_ptr) { aura::Window* window = window_ptr.get(); - client_created_windows_[window] = std::move(window_ptr); ServerWindow::Create(window, this, id, is_top_level); - AddWindowToKnownWindows(window, id); + AddWindowToKnownWindows(window, id, std::move(window_ptr)); return window; } -void WindowTree::AddWindowToKnownWindows(aura::Window* window, - const ClientWindowId& id) { - DCHECK_EQ(0u, window_to_client_window_id_map_.count(window)); - window_to_client_window_id_map_[window] = id; +void WindowTree::AddWindowToKnownWindows( + aura::Window* window, + const ClientWindowId& id, + std::unique_ptr<aura::Window> owned_window) { + DCHECK(!IsWindowKnown(window)); + KnownWindow& known_window = known_windows_map_[window]; + known_window.client_window_id = id; + known_window.is_client_created = owned_window.get() != nullptr; + known_window.owned_window = std::move(owned_window); DCHECK(IsWindowKnown(window)); client_window_id_to_window_map_[id] = window; @@ -517,21 +560,33 @@ void WindowTree::AddWindowToKnownWindows(aura::Window* window, void WindowTree::RemoveWindowFromKnownWindows(aura::Window* window, bool delete_if_owned) { DCHECK(IsWindowKnown(window)); - auto client_iter = client_created_windows_.find(window); - if (client_iter != client_created_windows_.end()) { + + ServerWindow* server_window = ServerWindow::GetMayBeNull(window); + ClientRoot* client_root = FindClientRootContaining(window); + if (client_root) + client_root->UnattachChildFrameSinkIdRecursive(server_window); + + server_window->set_attached_frame_sink_id(viz::FrameSinkId()); + + auto iter = known_windows_map_.find(window); + DCHECK(iter != known_windows_map_.end()); + if (iter->second.owned_window) { window->RemoveObserver(this); if (!delete_if_owned) { // |window| is in the process of being deleted, release() to avoid double // deletion. - client_iter->second.release(); + iter->second.owned_window.release(); } - client_created_windows_.erase(client_iter); + iter->second.owned_window.reset(); } + // Sanity check to make sure deletion didn't result in removal + DCHECK(iter == known_windows_map_.find(window)); + // Remove from these maps after destruction. This is necessary as destruction // may end up expecting to find a ServerWindow. - auto iter = window_to_client_window_id_map_.find(window); - client_window_id_to_window_map_.erase(iter->second); - window_to_client_window_id_map_.erase(iter); + DCHECK(iter != known_windows_map_.end()); + client_window_id_to_window_map_.erase(iter->second.client_window_id); + known_windows_map_.erase(iter); } void WindowTree::RemoveWindowFromKnownWindowsRecursive( @@ -569,9 +624,7 @@ Id WindowTree::ClientWindowIdToTransportId( Id WindowTree::TransportIdForWindow(aura::Window* window) const { DCHECK(IsWindowKnown(window)); - auto iter = window_to_client_window_id_map_.find(window); - DCHECK(iter != window_to_client_window_id_map_.end()); - return ClientWindowIdToTransportId(iter->second); + return ClientWindowIdToTransportId(ClientWindowIdForWindow(window)); } ClientWindowId WindowTree::MakeClientWindowId(Id transport_window_id) const { @@ -733,6 +786,14 @@ bool WindowTree::SetCaptureImpl(const ClientWindowId& window_id) { ServerWindow* server_window = ServerWindow::GetMayBeNull(window); + if (DoesAnyAncestorInterceptEvents(server_window)) { + // If an ancestor is intercepting events, than the descendants are not + // allowed to set capture. This is primarily to prevent renderers from + // setting capture. + DVLOG(1) << "SetCapture failed (ancestor intercepts events)"; + return false; + } + wm::CaptureController* capture_controller = wm::CaptureController::Get(); DCHECK(capture_controller); @@ -935,28 +996,6 @@ bool WindowTree::SetModalTypeImpl(const ClientWindowId& client_window_id, return true; } -bool WindowTree::SetChildModalParentImpl(const ClientWindowId& child_id, - const ClientWindowId& parent_id) { - DVLOG(3) << "setting child window modal parent client=" << client_id_ - << " child_id=" << child_id << " parent_id=" << parent_id; - aura::Window* child = GetWindowByClientId(child_id); - aura::Window* parent = GetWindowByClientId(parent_id); - // A value of null for |parent_id| resets the modal parent. - if (!child) { - DVLOG(1) << "SetChildModalParent failed (invalid id)"; - return false; - } - - if (!IsClientCreatedWindow(child) || - (parent && !IsClientCreatedWindow(parent))) { - DVLOG(1) << "SetChildModalParent failed (access denied)"; - return false; - } - - wm::SetModalParent(child, parent); - return true; -} - bool WindowTree::SetWindowVisibilityImpl(const ClientWindowId& window_id, bool visible) { aura::Window* window = GetWindowByClientId(window_id); @@ -991,21 +1030,39 @@ bool WindowTree::SetWindowPropertyImpl( DVLOG(1) << "SetWindowProperty failed (no window)"; return false; } - DCHECK(window_service_->property_converter()->IsTransportNameRegistered(name)) - << "Attempting to set an unregistered property; this is not implemented. " - << "property name=" << name; + aura::PropertyConverter* property_converter = + window_service_->property_converter(); + if (!property_converter->IsTransportNameRegistered(name)) { + NOTREACHED() << "Attempting to set an unregistered property; this is not " + "implemented. property name=" + << name; + return false; + } if (!IsClientCreatedWindow(window) && !IsClientRootWindow(window)) { DVLOG(1) << "SetWindowProperty failed (access policy denied change)"; return false; } - ClientChange change(property_change_tracker_.get(), window, - ClientChangeType::kProperty); + ClientChange change( + property_change_tracker_.get(), window, ClientChangeType::kProperty, + property_converter->GetPropertyKeyFromTransportName(name)); + + // Special handle the property whose value is a pointer to aura::Window since + // property converter can't convert the transported value. + const aura::WindowProperty<aura::Window*>* property = + property_converter->GetWindowPtrProperty(name); + if (property) { + aura::Window* prop_window = nullptr; + if (value.has_value()) + prop_window = GetWindowByTransportId(mojo::ConvertTo<Id>(value.value())); + window->SetProperty(property, prop_window); + return true; + } + std::unique_ptr<std::vector<uint8_t>> data; if (value.has_value()) data = std::make_unique<std::vector<uint8_t>>(value.value()); - window_service_->property_converter()->SetPropertyFromTransportValue( - window, name, data.get()); + property_converter->SetPropertyFromTransportValue(window, name, data.get()); return true; } @@ -1025,7 +1082,8 @@ bool WindowTree::EmbedImpl(const ClientWindowId& window_id, } const bool owner_intercept_events = - (flags & mojom::kEmbedFlagEmbedderInterceptsEvents) != 0; + (connection_type_ != ConnectionType::kEmbedding && + (flags & mojom::kEmbedFlagEmbedderInterceptsEvents) != 0); std::unique_ptr<Embedding> embedding = std::make_unique<Embedding>(this, window, owner_intercept_events); embedding->Init(window_service_, std::move(window_tree_client_ptr), @@ -1081,11 +1139,11 @@ bool WindowTree::SetWindowBoundsImpl( ServerWindow* server_window = ServerWindow::GetMayBeNull(window); const gfx::Rect original_bounds = IsTopLevel(window) ? window->GetBoundsInScreen() : window->bounds(); + const bool local_surface_id_changed = + server_window->local_surface_id() != local_surface_id; - if (original_bounds == bounds && - server_window->local_surface_id() == local_surface_id) { + if (original_bounds == bounds && !local_surface_id_changed) return true; - } ClientChange change(property_change_tracker_.get(), window, ClientChangeType::kBounds); @@ -1114,8 +1172,21 @@ bool WindowTree::SetWindowBoundsImpl( return false; } - if (window->bounds() == original_bounds) - return false; + if (window->bounds() == original_bounds) { + if (local_surface_id_changed) { + // If the bounds didn't change, but the LocalSurfaceId did, then the + // LocalSurfaceId needs to be propagated to any embeddings. + if (server_window->HasEmbedding() && + server_window->embedding()->embedding_tree() == this) { + WindowTree* embedded_tree = server_window->embedding()->embedded_tree(); + ClientRoot* embedded_client_root = + embedded_tree->GetClientRootForWindow(window); + DCHECK(embedded_client_root); + embedded_client_root->OnLocalSurfaceIdChanged(); + } + } + return (bounds == original_bounds); + } if (window->bounds() == bounds && server_window->local_surface_id() == local_surface_id) { @@ -1243,6 +1314,24 @@ void WindowTree::OnEmbeddedClientConnectionLost(Embedding* embedding) { ServerWindow::GetMayBeNull(embedding->window())->SetEmbedding(nullptr); } +void WindowTree::OnWindowHierarchyChanging( + const HierarchyChangeParams& params) { + if (params.target != params.receiver || !IsClientCreatedWindow(params.target)) + return; + + ServerWindow* server_window = ServerWindow::GetMayBeNull(params.target); + DCHECK(server_window); // non-null because of IsClientCreatedWindow() check. + ClientRoot* old_root = FindClientRootContaining(params.old_parent); + ClientRoot* new_root = FindClientRootContaining(params.new_parent); + if (old_root == new_root) + return; + + if (old_root) + old_root->UnattachChildFrameSinkIdRecursive(server_window); + if (new_root) + new_root->AttachChildFrameSinkIdRecursive(server_window); +} + void WindowTree::OnWindowDestroyed(aura::Window* window) { DCHECK(IsWindowKnown(window)); @@ -1254,11 +1343,8 @@ void WindowTree::OnWindowDestroyed(aura::Window* window) { if (iter != client_roots_.end()) DeleteClientRoot(iter->get(), WindowTree::DeleteClientRootReason::kDeleted); - DCHECK_NE(0u, window_to_client_window_id_map_.count(window)); - const ClientWindowId client_window_id = - window_to_client_window_id_map_[window]; - window_tree_client_->OnWindowDeleted( - ClientWindowIdToTransportId(client_window_id)); + DCHECK(IsWindowKnown(window)); + window_tree_client_->OnWindowDeleted(TransportIdForWindow(window)); const bool delete_if_owned = false; RemoveWindowFromKnownWindows(window, delete_if_owned); @@ -1423,29 +1509,72 @@ void WindowTree::SetClientArea( insets, additional_client_areas.value_or(std::vector<gfx::Rect>())); } -void WindowTree::SetHitTestMask(Id transport_window_id, - const base::Optional<gfx::Rect>& mask) { +void WindowTree::SetHitTestInsets(Id transport_window_id, + const gfx::Insets& mouse, + const gfx::Insets& touch) { const ClientWindowId window_id = MakeClientWindowId(transport_window_id); aura::Window* window = GetWindowByClientId(window_id); - DVLOG(3) << "SetHitTestMask client window_id=" << window_id.ToString() - << " mask=" << (mask ? mask.value().ToString() : "null"); + DVLOG(3) << "SetHitTestInsets client window_id=" << window_id.ToString() + << " mouse=" << mouse.ToString() << " touch=" << touch.ToString(); if (!window) { - DVLOG(1) << "SetHitTestMask failed (invalid window id)"; + DVLOG(1) << "SetHitTestInsets failed (invalid window id)"; return; } if (!IsClientCreatedWindow(window)) { - DVLOG(1) << "SetHitTestMask failed (access denied)"; + DVLOG(1) << "SetHitTestInsets failed (access denied)"; return; } - const gfx::Rect window_local_bounds(window->bounds().size()); - if (mask && !window_local_bounds.Contains(mask.value())) { - DVLOG(1) << "SetHitTestMask failed (mask extends beyond window bounds)"; + + ServerWindow* server_window = ServerWindow::GetMayBeNull(window); + DCHECK(server_window); // Must exist because of preceding conditionals. + server_window->SetHitTestInsets(MakeInsetsPositive(mouse), + MakeInsetsPositive(touch)); +} + +void WindowTree::AttachFrameSinkId(Id transport_window_id, + const viz::FrameSinkId& f) { + if (!f.is_valid()) { + DVLOG(3) << "AttachFrameSinkId failed (invalid frame sink)"; return; } + const ClientWindowId window_id = MakeClientWindowId(transport_window_id); + aura::Window* window = GetWindowByClientId(window_id); + if (!window || !IsClientCreatedWindow(window)) { + DVLOG(3) << "AttachFrameSinkId failed (invalid window id)"; + return; + } + ServerWindow* server_window = ServerWindow::GetMayBeNull(window); + DCHECK(server_window); // Must exist because of preceding conditionals. + if (server_window->attached_frame_sink_id() == f) + return; + if (f.is_valid() && server_window->attached_frame_sink_id().is_valid()) { + DVLOG(3) << "AttachFrameSinkId failed (window already has frame sink)"; + return; + } + server_window->set_attached_frame_sink_id(f); + ClientRoot* client_root = FindClientRootContaining(window); + if (client_root) + client_root->AttachChildFrameSinkId(server_window); +} +void WindowTree::UnattachFrameSinkId(Id transport_window_id) { + const ClientWindowId window_id = MakeClientWindowId(transport_window_id); + aura::Window* window = GetWindowByClientId(window_id); + if (!window || !IsClientCreatedWindow(window)) { + DVLOG(3) << "UnattachFrameSinkId failed (invalid window id)"; + return; + } ServerWindow* server_window = ServerWindow::GetMayBeNull(window); DCHECK(server_window); // Must exist because of preceding conditionals. - server_window->SetHitTestMask(mask); + if (!server_window->attached_frame_sink_id().is_valid()) { + DVLOG(3) << "UnattachFrameSinkId failed (frame sink already cleared)"; + return; + } + + ClientRoot* client_root = FindClientRootContaining(window); + if (client_root) + client_root->UnattachChildFrameSinkId(server_window); + server_window->set_attached_frame_sink_id(viz::FrameSinkId()); } void WindowTree::SetCanAcceptDrops(Id window_id, bool accepts_drops) { @@ -1558,14 +1687,6 @@ void WindowTree::SetModalType(uint32_t change_id, change_id, SetModalTypeImpl(MakeClientWindowId(window_id), type)); } -void WindowTree::SetChildModalParent(uint32_t change_id, - Id window_id, - Id parent_window_id) { - window_tree_client_->OnChangeCompleted( - change_id, SetChildModalParentImpl(MakeClientWindowId(window_id), - MakeClientWindowId(parent_window_id))); -} - void WindowTree::ReorderWindow(uint32_t change_id, Id transport_window_id, Id transport_relative_window_id, @@ -1658,7 +1779,8 @@ void WindowTree::EmbedUsingToken(Id transport_window_id, ServerWindow* server_window = ServerWindow::GetMayBeNull(window); const bool owner_intercept_events = - (embed_flags & mojom::kEmbedFlagEmbedderInterceptsEvents) != 0; + (connection_type_ != ConnectionType::kEmbedding && + (embed_flags & mojom::kEmbedFlagEmbedderInterceptsEvents) != 0); tree_and_id.tree->CompleteScheduleEmbedForExistingClient( window, tree_and_id.id, token); std::unique_ptr<Embedding> embedding = @@ -1796,8 +1918,13 @@ void WindowTree::StackAtTop(uint32_t change_id, Id window_id) { window_tree_client_->OnChangeCompleted(change_id, result); } -void WindowTree::PerformWmAction(Id window_id, const std::string& action) { - NOTIMPLEMENTED_LOG_ONCE(); +void WindowTree::BindWindowManagerInterface( + const std::string& name, + mojom::WindowManagerAssociatedRequest window_manager) { + auto wm_interface = window_service_->delegate()->CreateWindowManagerInterface( + this, name, window_manager.PassHandle()); + if (wm_interface) + window_manager_interfaces_.push_back(std::move(wm_interface)); } void WindowTree::GetCursorLocationMemory( @@ -1925,4 +2052,65 @@ void WindowTree::StopObservingTopmostWindow() { topmost_window_observer_.reset(); } +void WindowTree::CancelActiveTouchesExcept(Id not_cancelled_window_id) { + if (connection_type_ == ConnectionType::kEmbedding) { + DVLOG(1) << "CancelActiveTouchesExcept failed (access denied)"; + return; + } + DVLOG(3) << "CancelActiveTouchesExcept not_cancelled_window_id=" + << MakeClientWindowId(not_cancelled_window_id).ToString(); + aura::Window* not_cancelled_window = nullptr; + if (not_cancelled_window_id != kInvalidTransportId) { + not_cancelled_window = GetWindowByTransportId(not_cancelled_window_id); + if (!not_cancelled_window || !IsClientCreatedWindow(not_cancelled_window)) { + DVLOG(1) << "CancelActiveTouchesExcept failed (invalid window)"; + return; + } + } + window_service_->env()->gesture_recognizer()->CancelActiveTouchesExcept( + not_cancelled_window); +} + +void WindowTree::CancelActiveTouches(Id window_id) { + if (connection_type_ == ConnectionType::kEmbedding) { + DVLOG(1) << "CancelActiveTouches failed (access denied)"; + return; + } + DVLOG(3) << "CancelActiveTouches window_id=" + << MakeClientWindowId(window_id).ToString(); + aura::Window* window = GetWindowByTransportId(window_id); + if (!window || !IsClientCreatedWindow(window)) { + DVLOG(1) << "CancelActiveTouches failed (invalid window)"; + return; + } + window_service_->env()->gesture_recognizer()->CancelActiveTouches(window); +} + +void WindowTree::TransferGestureEventsTo(Id current_id, + Id new_id, + bool should_cancel) { + if (connection_type_ == ConnectionType::kEmbedding) { + DVLOG(1) << "TransferGestureEventsTo failed (access denied)"; + return; + } + DVLOG(3) << "TransferGestureEventsTo current_id=" + << MakeClientWindowId(current_id).ToString() + << " new_id=" << MakeClientWindowId(new_id) + << " should_cancel=" << should_cancel; + aura::Window* current_window = GetWindowByTransportId(current_id); + aura::Window* new_window = GetWindowByTransportId(new_id); + if (!current_window || !IsClientCreatedWindow(current_window)) { + DVLOG(1) << "TransferGestureEventsTo failed (invalid current_window)"; + return; + } + if (!new_window || !IsClientCreatedWindow(new_window)) { + DVLOG(1) << "TransferGestureEventsTo failed (invalid new_window)"; + return; + } + window_service_->env()->gesture_recognizer()->TransferEventsTo( + current_window, new_window, + should_cancel ? ui::TransferTouchesBehavior::kCancel + : ui::TransferTouchesBehavior::kDontCancel); +} + } // namespace ws diff --git a/chromium/services/ws/window_tree.h b/chromium/services/ws/window_tree.h index a979162e624..a38ea1f9f1d 100644 --- a/chromium/services/ws/window_tree.h +++ b/chromium/services/ws/window_tree.h @@ -39,6 +39,7 @@ class FocusHandler; class PointerWatcher; class ServerWindow; class TopmostWindowObserver; +class WindowManagerInterface; class WindowService; // WindowTree manages a client connected to the Window Service. WindowTree @@ -86,6 +87,10 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) WindowTree void SendPointerWatcherEventToClient(int64_t display_id, std::unique_ptr<ui::Event> event); + // Returns the aura::Window associated with the specified transport id; null + // if |transport_window_id| is not a valid id for a window. + aura::Window* GetWindowByTransportId(Id transport_window_id); + // Returns true if |window| was created by the client calling // NewTopLevelWindow(). bool IsTopLevel(aura::Window* window); @@ -123,7 +128,16 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) WindowTree // one of the roots. bool HasAtLeastOneRootWithCompositorFrameSink(); - ClientWindowId ClientWindowIdForWindow(aura::Window* window); + // Returns true if |window| has been exposed to this client. A client + // typically only sees a limited set of windows that may exist. The set of + // windows exposed to the client are referred to as the known windows. + bool IsWindowKnown(aura::Window* window) const; + + ClientWindowId ClientWindowIdForWindow(aura::Window* window) const; + + // If |window| is a client root, the ClientRoot is returned. This does not + // recurse. + ClientRoot* GetClientRootForWindow(aura::Window* window); private: friend class ClientRoot; @@ -161,6 +175,22 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) WindowTree kDestructor, }; + // Used to track every window known to the client. + struct KnownWindow { + KnownWindow(); + ~KnownWindow(); + + // Id for the window. + ClientWindowId client_window_id; + + // If non-null, the client created the window and owns it. During window + // destruction this may be destroyed before the entry is moved. If you need + // to know if the client created the window, use the |is_client_created|. + std::unique_ptr<aura::Window> owned_window; + + bool is_client_created = false; + }; + // Creates a new ClientRoot. The returned ClientRoot is owned by this. // |is_top_level| is true if this is called from // WindowTree::NewTopLevelWindow(). @@ -169,24 +199,22 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) WindowTree void DeleteClientRootWithRoot(aura::Window* window); aura::Window* GetWindowByClientId(const ClientWindowId& id); - aura::Window* GetWindowByTransportId(Id transport_window_id); // Returns true if |this| created |window|. bool IsClientCreatedWindow(aura::Window* window); bool IsClientRootWindow(aura::Window* window); - // Returns the window which is corresponded with the root window for the - // specified |window| in the client. - aura::Window* GetClientRootWindowFor(aura::Window* window); + // Returns the ClientRoot that |window| is parented to, null if |window| is + // not in a ClientRoot. + ClientRoot* FindClientRootContaining(aura::Window* window); ClientRoots::iterator FindClientRootWithRoot(aura::Window* window); - // Returns true if |window| has been exposed to this client. A client - // typically only sees a limited set of windows that may exist. The set of - // windows exposed to the client are referred to as the known windows. - bool IsWindowKnown(aura::Window* window) const; bool IsWindowRootOfAnotherClient(aura::Window* window) const; + // Returns true if |window| has an ancestor that intercepts events. + bool DoesAnyAncestorInterceptEvents(ServerWindow* window); + // Called when one of the windows known to the client loses capture. // |lost_capture| is the window that had capture. void OnCaptureLost(aura::Window* lost_capture); @@ -213,6 +241,10 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) WindowTree // fails or gets canceled). void OnPerformDragDropDone(uint32_t change_id, int drag_result); + // Returns the first window in |known_windows_map_| that was created by + // the client; null if the client did not create an windows. + aura::Window* FindFirstClientCreatedWindow(); + // Called for windows created by the client (including top-levels). aura::Window* AddClientCreatedWindow( const ClientWindowId& id, @@ -221,7 +253,9 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) WindowTree // Adds/removes a Window from the set of windows known to the client. This // also adds or removes any observers that may need to be installed. - void AddWindowToKnownWindows(aura::Window* window, const ClientWindowId& id); + void AddWindowToKnownWindows(aura::Window* window, + const ClientWindowId& id, + std::unique_ptr<aura::Window> owned_window); // |delete_if_owned| indicates if |window| should be deleted if this client // created it. |delete_if_owned| is false only if the window was externally @@ -288,8 +322,6 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) WindowTree bool RemoveTransientWindowFromParentImpl(const ClientWindowId& transient_id); bool SetModalTypeImpl(const ClientWindowId& client_window_id, ui::ModalType type); - bool SetChildModalParentImpl(const ClientWindowId& child_id, - const ClientWindowId& parent_id); bool SetWindowVisibilityImpl(const ClientWindowId& window_id, bool visible); bool SetWindowPropertyImpl(const ClientWindowId& window_id, const std::string& name, @@ -321,6 +353,7 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) WindowTree void OnEmbeddedClientConnectionLost(Embedding* embedding); // aura::WindowObserver: + void OnWindowHierarchyChanging(const HierarchyChangeParams& params) override; void OnWindowDestroyed(aura::Window* window) override; // aura::client::CaptureClientObserver: @@ -355,8 +388,12 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) WindowTree const gfx::Insets& insets, const base::Optional<std::vector<gfx::Rect>>& additional_client_areas) override; - void SetHitTestMask(Id transport_window_id, - const base::Optional<gfx::Rect>& mask) override; + void SetHitTestInsets(Id transport_window_id, + const gfx::Insets& mouse, + const gfx::Insets& touch) override; + void AttachFrameSinkId(Id transport_window_id, + const viz::FrameSinkId& f) override; + void UnattachFrameSinkId(Id transport_window_id) override; void SetCanAcceptDrops(Id window_id, bool accepts_drops) override; void SetWindowVisibility(uint32_t change_id, Id transport_window_id, @@ -383,9 +420,6 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) WindowTree void SetModalType(uint32_t change_id, Id window_id, ui::ModalType type) override; - void SetChildModalParent(uint32_t change_id, - Id window_id, - Id parent_window_id) override; void ReorderWindow(uint32_t change_id, Id transport_window_id, Id transport_relative_window_id, @@ -421,7 +455,9 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) WindowTree void DeactivateWindow(Id transport_window_id) override; void StackAbove(uint32_t change_id, Id above_id, Id below_id) override; void StackAtTop(uint32_t change_id, Id window_id) override; - void PerformWmAction(Id window_id, const std::string& action) override; + void BindWindowManagerInterface( + const std::string& name, + mojom::WindowManagerAssociatedRequest window_manager) override; void GetCursorLocationMemory( GetCursorLocationMemoryCallback callback) override; void PerformWindowMove(uint32_t change_id, @@ -442,6 +478,11 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) WindowTree void ObserveTopmostWindow(mojom::MoveLoopSource source, Id window_id) override; void StopObservingTopmostWindow() override; + void CancelActiveTouchesExcept(Id not_cancelled_window_id) override; + void CancelActiveTouches(Id window_id) override; + void TransferGestureEventsTo(Id current_id, + Id new_id, + bool should_cancel) override; WindowService* window_service_; @@ -461,20 +502,13 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) WindowTree ClientRoots client_roots_; - // Set of windows this client created. The values are the same as key, but - // put inside a unique_ptr to reinforce this class owns these Windows. - // Ideally set would be used, but sets have some painful restrictions - // (c++17's set::extract() may make it possible to use a set again). - std::unordered_map<aura::Window*, std::unique_ptr<aura::Window>> - client_created_windows_; - - // These contain mappings for known windows. At a minimum this contains the - // windows in |client_created_windows_|. It will also contain any windows - // that are exposed (known) to this client for various reasons. For example, - // if this client is the result of an embedding then the window at the embed - // point (the root window of the ClientRoot) was not created by this client, - // but is known and in these mappings. - std::map<aura::Window*, ClientWindowId> window_to_client_window_id_map_; + // These contain mappings for known windows, see KnownWindow for details on + // it. This contains all windows created by the client, as well as windows + // known to the client. For example,if this client is the result of an + // embedding then the window at the embed point (the root window of the + // ClientRoot) was not created by this client, but is known and in these + // mappings. + std::map<aura::Window*, KnownWindow> known_windows_map_; std::unordered_map<ClientWindowId, aura::Window*, ClientWindowIdHash> client_window_id_to_window_map_; @@ -516,6 +550,9 @@ class COMPONENT_EXPORT(WINDOW_SERVICE) WindowTree // Set while a drag loop is in progress. Id pending_drag_source_window_id_ = kInvalidTransportId; + std::vector<std::unique_ptr<WindowManagerInterface>> + window_manager_interfaces_; + base::WeakPtrFactory<WindowTree> weak_factory_{this}; DISALLOW_COPY_AND_ASSIGN(WindowTree); diff --git a/chromium/services/ws/window_tree_client_unittest.cc b/chromium/services/ws/window_tree_client_unittest.cc index 48b03f664b1..6bbe7ae1d7b 100644 --- a/chromium/services/ws/window_tree_client_unittest.cc +++ b/chromium/services/ws/window_tree_client_unittest.cc @@ -394,15 +394,9 @@ class TestWindowTreeClient2 : public TestWindowTreeClient { void OnWindowCursorChanged(Id window_id, ui::CursorData cursor) override { tracker_.OnWindowCursorChanged(window_id, cursor); } - void OnDragDropStart(const base::flat_map<std::string, std::vector<uint8_t>>& drag_data) override {} - void OnWindowSurfaceChanged(Id window_id, - const viz::SurfaceInfo& surface_info) override { - tracker_.OnWindowSurfaceChanged(window_id, surface_info); - } - void OnDragEnter(Id window, uint32_t key_state, const gfx::Point& position, @@ -2174,92 +2168,6 @@ TEST_F(WindowTreeClientTest, DISABLED_ExplicitCapturePropagation) { EXPECT_TRUE(changes1()->empty()); } -TEST_F(WindowTreeClientTest, DISABLED_SurfaceIdPropagation) { - const Id window_1_100 = wt_client1()->NewWindow(100); - ASSERT_TRUE(window_1_100); - ASSERT_TRUE(wt_client1()->AddWindow(root_window_id(), window_1_100)); - - // Establish the second client at client_id_1(),100. - ASSERT_NO_FATAL_FAILURE(EstablishSecondClientWithRoot(window_1_100)); - changes2()->clear(); - - // client_id_1(),100 is the id in the wt_client1's id space. The new client - // should see client_id_2(),1 (the server id). - const Id window_1_100_in_ws2 = BuildWindowId(client_id_1(), 100); - EXPECT_EQ(window_1_100_in_ws2, wt_client2()->root_window_id()); - - // Submit a CompositorFrame to window_1_100_in_ws2 (the embedded window in - // wt2) and make sure the server gets it. - { - viz::mojom::CompositorFrameSinkPtr surface_ptr; - viz::mojom::CompositorFrameSinkClientRequest client_request; - viz::mojom::CompositorFrameSinkClientPtr surface_client_ptr; - client_request = mojo::MakeRequest(&surface_client_ptr); - wt2()->AttachCompositorFrameSink(window_1_100_in_ws2, - mojo::MakeRequest(&surface_ptr), - std::move(surface_client_ptr)); - viz::CompositorFrame compositor_frame; - std::unique_ptr<viz::RenderPass> render_pass = viz::RenderPass::Create(); - gfx::Rect frame_rect(0, 0, 100, 100); - render_pass->SetNew(1, frame_rect, frame_rect, gfx::Transform()); - compositor_frame.render_pass_list.push_back(std::move(render_pass)); - compositor_frame.metadata.device_scale_factor = 1.f; - compositor_frame.metadata.begin_frame_ack = viz::BeginFrameAck(0, 1, true); - viz::LocalSurfaceId local_surface_id(1, base::UnguessableToken::Create()); - surface_ptr->SubmitCompositorFrame( - local_surface_id, std::move(compositor_frame), base::nullopt, 0); - } - // Make sure the parent connection gets the surface ID. - wt_client1()->WaitForChangeCount(1); - // Verify that the submitted frame is for |window_2_101|. - viz::FrameSinkId frame_sink_id = - changes1()->back().surface_id.frame_sink_id(); - // FrameSinkId is based on window's ClientWindowId. - EXPECT_EQ(static_cast<size_t>(client_id_2()), frame_sink_id.client_id()); - EXPECT_EQ(0u, frame_sink_id.sink_id()); - changes1()->clear(); - - // The first window created in the second client gets a server id of - // client_id_2(),1 regardless of the id the client uses. - const Id window_2_101 = wt_client2()->NewWindow(101); - ASSERT_TRUE(wt_client2()->AddWindow(window_1_100_in_ws2, window_2_101)); - const Id window_2_101_in_ws2 = BuildWindowId(client_id_2(), 101); - wt_client1()->WaitForChangeCount(1); - EXPECT_EQ("HierarchyChanged window=" + IdToString(window_2_101_in_ws2) + - " old_parent=null new_parent=" + IdToString(window_1_100), - SingleChangeToDescription(*changes1())); - // Submit a CompositorFrame to window_2_101_in_ws2 (a regular window in - // wt2) and make sure client gets it. - { - viz::mojom::CompositorFrameSinkPtr surface_ptr; - viz::mojom::CompositorFrameSinkClientRequest client_request; - viz::mojom::CompositorFrameSinkClientPtr surface_client_ptr; - client_request = mojo::MakeRequest(&surface_client_ptr); - wt2()->AttachCompositorFrameSink(window_2_101, - mojo::MakeRequest(&surface_ptr), - std::move(surface_client_ptr)); - viz::CompositorFrame compositor_frame; - std::unique_ptr<viz::RenderPass> render_pass = viz::RenderPass::Create(); - gfx::Rect frame_rect(0, 0, 100, 100); - render_pass->SetNew(1, frame_rect, frame_rect, gfx::Transform()); - compositor_frame.render_pass_list.push_back(std::move(render_pass)); - compositor_frame.metadata.device_scale_factor = 1.f; - compositor_frame.metadata.begin_frame_ack = viz::BeginFrameAck(0, 1, true); - viz::LocalSurfaceId local_surface_id(2, base::UnguessableToken::Create()); - surface_ptr->SubmitCompositorFrame( - local_surface_id, std::move(compositor_frame), base::nullopt, 0); - } - // Make sure the parent connection gets the surface ID. - wt_client2()->WaitForChangeCount(1); - // Verify that the submitted frame is for |window_2_101|. - viz::FrameSinkId frame_sink_id2 = - changes2()->back().surface_id.frame_sink_id(); - // FrameSinkId is based on window's ClientWindowId. - EXPECT_NE(0u, frame_sink_id2.client_id()); - EXPECT_EQ(ClientWindowIdFromTransportId(window_2_101), - frame_sink_id2.sink_id()); -} - // Verifies when an unknown window with a known child is added to a hierarchy // the known child is identified in the WindowData. TEST_F(WindowTreeClientTest, DISABLED_AddUnknownWindowKnownParent) { diff --git a/chromium/services/ws/window_tree_test_helper.cc b/chromium/services/ws/window_tree_test_helper.cc index c03e875483e..e74dd6bced5 100644 --- a/chromium/services/ws/window_tree_test_helper.cc +++ b/chromium/services/ws/window_tree_test_helper.cc @@ -95,9 +95,10 @@ void WindowTreeTestHelper::SetClientArea( additional_client_areas); } -void WindowTreeTestHelper::SetHitTestMask(aura::Window* window, - base::Optional<gfx::Rect> mask) { - window_tree_->SetHitTestMask(TransportIdForWindow(window), mask); +void WindowTreeTestHelper::SetHitTestInsets(aura::Window* window, + const gfx::Insets& mouse, + const gfx::Insets& touch) { + window_tree_->SetHitTestInsets(TransportIdForWindow(window), mouse, touch); } void WindowTreeTestHelper::SetWindowProperty(aura::Window* window, diff --git a/chromium/services/ws/window_tree_test_helper.h b/chromium/services/ws/window_tree_test_helper.h index 0a1fb6ada0e..df326cd0633 100644 --- a/chromium/services/ws/window_tree_test_helper.h +++ b/chromium/services/ws/window_tree_test_helper.h @@ -80,7 +80,9 @@ class WindowTreeTestHelper { const gfx::Insets& insets, base::Optional<std::vector<gfx::Rect>> additional_client_areas = base::Optional<std::vector<gfx::Rect>>()); - void SetHitTestMask(aura::Window* window, base::Optional<gfx::Rect> mask); + void SetHitTestInsets(aura::Window* window, + const gfx::Insets& mouse, + const gfx::Insets& touch); void SetWindowProperty(aura::Window* window, const std::string& name, const std::vector<uint8_t>& value, diff --git a/chromium/services/ws/window_tree_unittest.cc b/chromium/services/ws/window_tree_unittest.cc index 414a9d3021f..edaa0fcd9d4 100644 --- a/chromium/services/ws/window_tree_unittest.cc +++ b/chromium/services/ws/window_tree_unittest.cc @@ -12,6 +12,8 @@ #include "base/run_loop.h" #include "base/unguessable_token.h" +#include "components/viz/host/host_frame_sink_manager.h" +#include "components/viz/test/fake_host_frame_sink_client.h" #include "services/ws/event_test_utils.h" #include "services/ws/public/cpp/property_type_converters.h" #include "services/ws/public/mojom/window_manager.mojom.h" @@ -37,6 +39,9 @@ namespace ws { namespace { +DEFINE_UI_CLASS_PROPERTY_KEY(aura::Window*, kTestPropertyKey, nullptr); +const char kTestPropertyServerKey[] = "test-property-server"; + // Passed to Embed() to give the default behavior (see kEmbedFlag* in mojom for // details). constexpr uint32_t kDefaultEmbedFlags = 0; @@ -304,6 +309,30 @@ TEST(WindowTreeTest, SetBoundsAtEmbedWindow) { EXPECT_EQ(CHANGE_TYPE_NODE_BOUNDS_CHANGED, bounds_change.type); EXPECT_EQ(bounds2, bounds_change.bounds2); EXPECT_EQ(local_surface_id, bounds_change.local_surface_id); + embedding_helper->window_tree_client.tracker()->changes()->clear(); + + // Set the bounds from the parent, only updating the LocalSurfaceId (bounds + // remains the same). The client should be notified. + base::Optional<viz::LocalSurfaceId> local_surface_id2( + viz::LocalSurfaceId(1, 3, base::UnguessableToken::Create())); + EXPECT_TRUE(setup.window_tree_test_helper()->SetWindowBounds( + window, bounds2, local_surface_id2)); + EXPECT_EQ(bounds2, window->bounds()); + ASSERT_EQ(1u, + embedding_helper->window_tree_client.tracker()->changes()->size()); + const Change bounds_change2 = + (*(embedding_helper->window_tree_client.tracker()->changes()))[0]; + EXPECT_EQ(CHANGE_TYPE_NODE_BOUNDS_CHANGED, bounds_change2.type); + EXPECT_EQ(bounds2, bounds_change2.bounds2); + EXPECT_EQ(local_surface_id2, bounds_change2.local_surface_id); + embedding_helper->window_tree_client.tracker()->changes()->clear(); + + // Try again with the same values. This should succeed, but not notify the + // client. + EXPECT_TRUE(setup.window_tree_test_helper()->SetWindowBounds( + window, bounds2, local_surface_id2)); + EXPECT_TRUE( + embedding_helper->window_tree_client.tracker()->changes()->empty()); } // Tests the ability of the client to change properties on the server. @@ -351,6 +380,55 @@ TEST(WindowTreeTest, WindowToWindowData) { data->properties[mojom::WindowManager::kAlwaysOnTop_Property])); } +TEST(WindowTreeTest, SetWindowPointerProperty) { + WindowServiceTestSetup setup; + setup.service()->property_converter()->RegisterWindowPtrProperty( + kTestPropertyKey, kTestPropertyServerKey); + + WindowTreeTestHelper* helper = setup.window_tree_test_helper(); + aura::Window* top_level1 = helper->NewTopLevelWindow(); + aura::Window* top_level2 = helper->NewTopLevelWindow(); + Id id1 = helper->TransportIdForWindow(top_level1); + Id id2 = helper->TransportIdForWindow(top_level2); + + base::Optional<std::vector<uint8_t>> value = + mojo::ConvertTo<std::vector<uint8_t>>(id2); + setup.window_tree_test_helper()->window_tree()->SetWindowProperty( + 1, id1, kTestPropertyServerKey, value); + EXPECT_EQ(top_level2, top_level1->GetProperty(kTestPropertyKey)); + + value.reset(); + setup.window_tree_test_helper()->window_tree()->SetWindowProperty( + 1, id1, kTestPropertyServerKey, value); + EXPECT_FALSE(top_level1->GetProperty(kTestPropertyKey)); +} + +TEST(WindowTreeTest, SetWindowPointerPropertyWithInvalidValues) { + WindowServiceTestSetup setup; + setup.service()->property_converter()->RegisterWindowPtrProperty( + kTestPropertyKey, kTestPropertyServerKey); + + WindowTreeTestHelper* helper = setup.window_tree_test_helper(); + aura::Window* top_level = helper->NewTopLevelWindow(); + Id id = helper->TransportIdForWindow(top_level); + base::Optional<std::vector<uint8_t>> value = + mojo::ConvertTo<std::vector<uint8_t>>(kInvalidTransportId); + setup.window_tree_test_helper()->window_tree()->SetWindowProperty( + 1, id, kTestPropertyServerKey, value); + EXPECT_FALSE(top_level->GetProperty(kTestPropertyKey)); + + value = mojo::ConvertTo<std::vector<uint8_t>>(10); + setup.window_tree_test_helper()->window_tree()->SetWindowProperty( + 1, id, kTestPropertyServerKey, value); + EXPECT_FALSE(top_level->GetProperty(kTestPropertyKey)); + + value->clear(); + value->push_back(1); + setup.window_tree_test_helper()->window_tree()->SetWindowProperty( + 1, id, kTestPropertyServerKey, value); + EXPECT_FALSE(top_level->GetProperty(kTestPropertyKey)); +} + TEST(WindowTreeTest, OnWindowInputEventAck) { WindowServiceTestSetup setup; TestWindowTreeClient* window_tree_client = setup.window_tree_client(); @@ -525,22 +603,22 @@ TEST(WindowTreeTest, MovePressDragRelease) { ui::test::EventGenerator event_generator(setup.root()); event_generator.MoveMouseTo(50, 50); - EXPECT_EQ("POINTER_MOVED 40,40", + EXPECT_EQ("MOUSE_MOVED 40,40", LocatedEventToEventTypeAndLocation( window_tree_client->PopInputEvent().event.get())); event_generator.PressLeftButton(); - EXPECT_EQ("POINTER_DOWN 40,40", + EXPECT_EQ("MOUSE_PRESSED 40,40", LocatedEventToEventTypeAndLocation( window_tree_client->PopInputEvent().event.get())); event_generator.MoveMouseTo(0, 0); - EXPECT_EQ("POINTER_MOVED -10,-10", + EXPECT_EQ("MOUSE_DRAGGED -10,-10", LocatedEventToEventTypeAndLocation( window_tree_client->PopInputEvent().event.get())); event_generator.ReleaseLeftButton(); - EXPECT_EQ("POINTER_UP -10,-10", + EXPECT_EQ("MOUSE_RELEASED -10,-10", LocatedEventToEventTypeAndLocation( window_tree_client->PopInputEvent().event.get())); } @@ -571,17 +649,17 @@ TEST(WindowTreeTest, TouchPressDragRelease) { ui::test::EventGenerator event_generator(setup.root()); event_generator.set_current_location(gfx::Point(50, 51)); event_generator.PressTouch(); - EXPECT_EQ("POINTER_DOWN 40,40", + EXPECT_EQ("ET_TOUCH_PRESSED 40,40", LocatedEventToEventTypeAndLocation( window_tree_client->PopInputEvent().event.get())); event_generator.MoveTouch(gfx::Point(5, 6)); - EXPECT_EQ("POINTER_MOVED -5,-5", + EXPECT_EQ("ET_TOUCH_MOVED -5,-5", LocatedEventToEventTypeAndLocation( window_tree_client->PopInputEvent().event.get())); event_generator.ReleaseTouch(); - EXPECT_EQ("POINTER_UP -5,-5", + EXPECT_EQ("ET_TOUCH_RELEASED -5,-5", LocatedEventToEventTypeAndLocation( window_tree_client->PopInputEvent().event.get())); } @@ -635,7 +713,7 @@ TEST(WindowTreeTest, MoveFromClientToNonClient) { ui::test::EventGenerator event_generator(setup.root()); event_generator.MoveMouseTo(50, 50); - EXPECT_EQ("POINTER_MOVED 40,40", + EXPECT_EQ("MOUSE_MOVED 40,40", LocatedEventToEventTypeAndLocation( window_tree_client->PopInputEvent().event.get())); @@ -648,7 +726,7 @@ TEST(WindowTreeTest, MoveFromClientToNonClient) { // Move the mouse over the non-client area. // The event is still sent to the client, and the delegate. event_generator.MoveMouseTo(15, 16); - EXPECT_EQ("POINTER_MOVED 5,6", + EXPECT_EQ("MOUSE_MOVED 5,6", LocatedEventToEventTypeAndLocation( window_tree_client->PopInputEvent().event.get())); @@ -678,7 +756,7 @@ TEST(WindowTreeTest, MoveFromClientToNonClient) { EventToEventType(window_delegate.PopEvent().get())); event_generator.MoveMouseTo(26, 50); - EXPECT_EQ("POINTER_MOVED 16,40", + EXPECT_EQ("MOUSE_MOVED 16,40", LocatedEventToEventTypeAndLocation( window_tree_client->PopInputEvent().event.get())); @@ -688,7 +766,7 @@ TEST(WindowTreeTest, MoveFromClientToNonClient) { // Press in client area. Only the client should get the event. event_generator.PressLeftButton(); - EXPECT_EQ("POINTER_DOWN 16,40", + EXPECT_EQ("MOUSE_PRESSED 16,40", LocatedEventToEventTypeAndLocation( window_tree_client->PopInputEvent().event.get())); @@ -721,7 +799,7 @@ TEST(WindowTreeTest, MouseDownInNonClientWithChildWindow) { // should get the event. ui::test::EventGenerator event_generator(setup.root()); event_generator.MoveMouseTo(15, 16); - EXPECT_EQ("POINTER_MOVED 5,6", + EXPECT_EQ("MOUSE_MOVED 5,6", LocatedEventToEventTypeAndLocation( window_tree_client->PopInputEvent().event.get())); EXPECT_TRUE(window_tree_client->input_events().empty()); @@ -773,7 +851,7 @@ TEST(WindowTreeTest, MouseDownInNonClientDragToClientWithChildWindow) { EXPECT_TRUE(window_tree_client->input_events().empty()); } -TEST(WindowTreeTest, SetHitTestMask) { +TEST(WindowTreeTest, SetHitTestInsets) { EventRecordingWindowDelegate window_delegate; WindowServiceTestSetup setup; setup.delegate()->set_delegate_for_next_top_level(&window_delegate); @@ -787,19 +865,19 @@ TEST(WindowTreeTest, SetHitTestMask) { window_tree_client->ClearInputEvents(); window_delegate.ClearEvents(); - // Set a hit test mask in the window's bounds that excludes the top half. - setup.window_tree_test_helper()->SetHitTestMask(top_level, - gfx::Rect(0, 50, 100, 50)); + // Set the hit test insets in the window's bounds that excludes the top half. + setup.window_tree_test_helper()->SetHitTestInsets( + top_level, gfx::Insets(50, 0, 0, 0), gfx::Insets(50, 0, 0, 0)); - // Events outside the hit test mask are not seen by the delegate or client. + // Events outside the hit test insets are not seen by the delegate or client. ui::test::EventGenerator event_generator(setup.root()); event_generator.MoveMouseTo(50, 30); EXPECT_TRUE(window_tree_client->input_events().empty()); EXPECT_TRUE(window_delegate.events().empty()); - // Events in the hit test mask are seen by the delegate and client. + // Events in the hit test insets are seen by the delegate and client. event_generator.MoveMouseTo(50, 80); - EXPECT_EQ("POINTER_MOVED 40,70", + EXPECT_EQ("MOUSE_MOVED 40,70", LocatedEventToEventTypeAndLocation( window_tree_client->PopInputEvent().event.get())); EXPECT_EQ("MOUSE_ENTERED 40,70", LocatedEventToEventTypeAndLocation( @@ -814,10 +892,6 @@ TEST(WindowTreeTest, PointerWatcher) { aura::Window* top_level = setup.window_tree_test_helper()->NewTopLevelWindow(); ASSERT_TRUE(top_level); - setup.window_tree_test_helper()->SetEventTargetingPolicy( - top_level, mojom::EventTargetingPolicy::NONE); - EXPECT_EQ(mojom::EventTargetingPolicy::NONE, - top_level->event_targeting_policy()); // Start the pointer watcher only for pointer down/up. setup.window_tree_test_helper()->window_tree()->StartPointerWatcher(false); @@ -877,7 +951,7 @@ TEST(WindowTreeTest, MatchesPointerWatcherSet) { TestWindowTreeClient::InputEvent press_input = window_tree_client->PopInputEvent(); ASSERT_TRUE(press_input.event); - EXPECT_EQ("POINTER_DOWN 40,40", + EXPECT_EQ("MOUSE_PRESSED 40,40", LocatedEventToEventTypeAndLocation(press_input.event.get())); EXPECT_TRUE(press_input.matches_pointer_watcher); // Because the event matches a pointer event there should be no observed @@ -908,6 +982,21 @@ TEST(WindowTreeTest, Capture) { EXPECT_TRUE(setup.window_tree_test_helper()->ReleaseCapture(window)); } +TEST(WindowTreeTest, CaptureDisallowedWhenEmbedderInterceptsEvents) { + WindowServiceTestSetup setup; + aura::Window* top_level = + setup.window_tree_test_helper()->NewTopLevelWindow(); + ASSERT_TRUE(top_level); + top_level->Show(); + aura::Window* window = setup.window_tree_test_helper()->NewWindow(); + top_level->AddChild(window); + window->Show(); + std::unique_ptr<EmbeddingHelper> embedding_helper = + setup.CreateEmbedding(window, mojom::kEmbedFlagEmbedderInterceptsEvents); + ASSERT_TRUE(embedding_helper); + EXPECT_FALSE(embedding_helper->window_tree_test_helper->SetCapture(window)); +} + TEST(WindowTreeTest, TransferCaptureToClient) { EventRecordingWindowDelegate window_delegate; WindowServiceTestSetup setup; @@ -938,7 +1027,7 @@ TEST(WindowTreeTest, TransferCaptureToClient) { event_generator.MoveMouseTo(8, 8); // Now the event should go to the client and not local. EXPECT_TRUE(window_delegate.events().empty()); - EXPECT_EQ("POINTER_MOVED", + EXPECT_EQ("MOUSE_MOVED", EventToEventType( setup.window_tree_client()->PopInputEvent().event.get())); EXPECT_TRUE(setup.window_tree_client()->input_events().empty()); @@ -974,7 +1063,7 @@ TEST(WindowTreeTest, TransferCaptureBetweenParentAndChild) { EXPECT_TRUE(setup.window_tree_client()->input_events().empty()); EXPECT_TRUE(window_delegate.events().empty()); EXPECT_EQ( - "POINTER_MOVED", + "MOUSE_MOVED", EventToEventType( embedding_helper->window_tree_client.PopInputEvent().event.get())); EXPECT_TRUE(embedding_helper->window_tree_client.input_events().empty()); @@ -982,7 +1071,7 @@ TEST(WindowTreeTest, TransferCaptureBetweenParentAndChild) { // Set capture from the parent, only the parent should get the event now. EXPECT_TRUE(setup.window_tree_test_helper()->SetCapture(top_level)); event_generator.MoveMouseTo(8, 8); - EXPECT_EQ("POINTER_MOVED", + EXPECT_EQ("MOUSE_MOVED", EventToEventType( setup.window_tree_client()->PopInputEvent().event.get())); EXPECT_TRUE(setup.window_tree_client()->input_events().empty()); @@ -1108,56 +1197,10 @@ TEST(WindowTreeTest, EventsGoToCaptureWindow) { auto drag_event = setup.window_tree_client()->PopInputEvent(); EXPECT_EQ(setup.window_tree_test_helper()->TransportIdForWindow(window), drag_event.window_id); - EXPECT_EQ("POINTER_MOVED -4,-4", + EXPECT_EQ("MOUSE_DRAGGED -4,-4", LocatedEventToEventTypeAndLocation(drag_event.event.get())); } -TEST(WindowTreeTest, InterceptEventsOnEmbeddedWindowWithCapture) { - EventRecordingWindowDelegate window_delegate; - WindowServiceTestSetup setup; - aura::Window* window = setup.window_tree_test_helper()->NewWindow(); - ASSERT_TRUE(window); - setup.delegate()->set_delegate_for_next_top_level(&window_delegate); - aura::Window* top_level = - setup.window_tree_test_helper()->NewTopLevelWindow(); - ASSERT_TRUE(top_level); - top_level->AddChild(window); - top_level->Show(); - window->Show(); - - // Create an embedding, and a new window in the embedding. - std::unique_ptr<EmbeddingHelper> embedding_helper = - setup.CreateEmbedding(window, mojom::kEmbedFlagEmbedderInterceptsEvents); - ASSERT_TRUE(embedding_helper); - aura::Window* window_in_child = - embedding_helper->window_tree_test_helper->NewWindow(); - ASSERT_TRUE(window_in_child); - window_in_child->Show(); - window->AddChild(window_in_child); - EXPECT_TRUE( - embedding_helper->window_tree_test_helper->SetCapture(window_in_child)); - - // Do an initial move (which generates some additional events) and clear - // everything out. - ui::test::EventGenerator event_generator(setup.root()); - event_generator.MoveMouseTo(5, 5); - setup.window_tree_client()->ClearInputEvents(); - window_delegate.ClearEvents(); - embedding_helper->window_tree_client.ClearInputEvents(); - - // Move the mouse. Even though the window in the embedding has capture, the - // event should go to the parent client (setup.window_tree_client()), because - // the embedding was created such that the embedder (parent) intercepts the - // events. - event_generator.MoveMouseTo(6, 6); - EXPECT_TRUE(window_delegate.events().empty()); - EXPECT_EQ("POINTER_MOVED", - EventToEventType( - setup.window_tree_client()->PopInputEvent().event.get())); - EXPECT_TRUE(setup.window_tree_client()->input_events().empty()); - EXPECT_TRUE(embedding_helper->window_tree_client.input_events().empty()); -} - TEST(WindowTreeTest, PointerDownResetOnCaptureChange) { WindowServiceTestSetup setup; aura::Window* window = setup.window_tree_test_helper()->NewWindow(); @@ -2019,10 +2062,10 @@ TEST(WindowTreeTest, DontSendGestures) { // never be forwarded to the client, as it's assumed the client runs its own // gesture recognizer. event_generator.GestureTapAt(gfx::Point(10, 10)); - EXPECT_EQ("POINTER_DOWN", + EXPECT_EQ("ET_TOUCH_PRESSED", EventToEventType( setup.window_tree_client()->PopInputEvent().event.get())); - EXPECT_EQ("POINTER_UP", + EXPECT_EQ("ET_TOUCH_RELEASED", EventToEventType( setup.window_tree_client()->PopInputEvent().event.get())); EXPECT_TRUE(setup.window_tree_client()->input_events().empty()); @@ -2058,5 +2101,48 @@ TEST(WindowTreeTest, DeactivateWindow) { EXPECT_TRUE(wm::IsActiveWindow(top_level1)); } +TEST(WindowTreeTest, AttachFrameSinkId) { + // Create two top-levels and focuses (activates) the second. + WindowServiceTestSetup setup; + aura::Window* top_level = + setup.window_tree_test_helper()->NewTopLevelWindow(); + ASSERT_TRUE(top_level); + top_level->Show(); + + aura::Window* child_window = setup.window_tree_test_helper()->NewWindow(); + ASSERT_TRUE(child_window); + viz::FrameSinkId test_frame_sink_id(101, 102); + viz::HostFrameSinkManager* host_frame_sink_manager = + child_window->env()->context_factory_private()->GetHostFrameSinkManager(); + + // Attach a frame sink to |child_window|. This shouldn't immediately register. + setup.window_tree_test_helper()->window_tree()->AttachFrameSinkId( + setup.window_tree_test_helper()->TransportIdForWindow(child_window), + test_frame_sink_id); + EXPECT_FALSE( + host_frame_sink_manager->IsFrameSinkIdRegistered(test_frame_sink_id)); + + // Add the window to a parent, which should trigger registering the hierarchy. + viz::FakeHostFrameSinkClient test_host_frame_sink_client; + host_frame_sink_manager->RegisterFrameSinkId(test_frame_sink_id, + &test_host_frame_sink_client); + EXPECT_EQ(test_frame_sink_id, + ServerWindow::GetMayBeNull(child_window)->attached_frame_sink_id()); + top_level->AddChild(child_window); + EXPECT_TRUE(host_frame_sink_manager->IsFrameSinkHierarchyRegistered( + ServerWindow::GetMayBeNull(top_level)->frame_sink_id(), + test_frame_sink_id)); + + // Removing the window should remove the association. + top_level->RemoveChild(child_window); + EXPECT_FALSE(host_frame_sink_manager->IsFrameSinkHierarchyRegistered( + ServerWindow::GetMayBeNull(top_level)->frame_sink_id(), + test_frame_sink_id)); + + setup.window_tree_test_helper()->DeleteWindow(child_window); + + host_frame_sink_manager->InvalidateFrameSinkId(test_frame_sink_id); +} + } // namespace } // namespace ws |