diff options
Diffstat (limited to 'chromium/media/learning')
25 files changed, 263 insertions, 66 deletions
diff --git a/chromium/media/learning/common/BUILD.gn b/chromium/media/learning/common/BUILD.gn index 4e208d3a9b7..b86b09c0355 100644 --- a/chromium/media/learning/common/BUILD.gn +++ b/chromium/media/learning/common/BUILD.gn @@ -23,6 +23,8 @@ component("common") { defines = [ "IS_LEARNING_COMMON_IMPL" ] sources = [ + "feature_dictionary.cc", + "feature_dictionary.h", "feature_library.cc", "feature_library.h", "labelled_example.cc", @@ -45,6 +47,7 @@ component("common") { source_set("unit_tests") { testonly = true sources = [ + "feature_dictionary_unittest.cc", "labelled_example_unittest.cc", "value_unittest.cc", ] diff --git a/chromium/media/learning/common/feature_dictionary.cc b/chromium/media/learning/common/feature_dictionary.cc new file mode 100644 index 00000000000..66cf6081f6d --- /dev/null +++ b/chromium/media/learning/common/feature_dictionary.cc @@ -0,0 +1,38 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "media/learning/common/feature_dictionary.h" + +namespace media { +namespace learning { + +FeatureDictionary::FeatureDictionary() = default; + +FeatureDictionary::~FeatureDictionary() = default; + +void FeatureDictionary::Lookup(const LearningTask& task, + FeatureVector* features) { + const size_t num_features = task.feature_descriptions.size(); + + if (features->size() < num_features) + features->resize(num_features); + + for (size_t i = 0; i < num_features; i++) { + const auto& name = task.feature_descriptions[i].name; + auto entry = dictionary_.find(name); + if (entry == dictionary_.end()) + continue; + + // |name| appears in the dictionary, so add its value to |features|. + (*features)[i] = entry->second; + } +} + +void FeatureDictionary::Add(const std::string& name, + const FeatureValue& value) { + dictionary_[name] = value; +} + +} // namespace learning +} // namespace media diff --git a/chromium/media/learning/common/feature_dictionary.h b/chromium/media/learning/common/feature_dictionary.h new file mode 100644 index 00000000000..149804501a4 --- /dev/null +++ b/chromium/media/learning/common/feature_dictionary.h @@ -0,0 +1,52 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MEDIA_LEARNING_COMMON_FEATURE_DICTIONARY_H_ +#define MEDIA_LEARNING_COMMON_FEATURE_DICTIONARY_H_ + +#include <map> +#include <string> + +#include "base/component_export.h" +#include "base/macros.h" +#include "media/learning/common/labelled_example.h" +#include "media/learning/common/learning_task.h" + +namespace media { +namespace learning { + +// Dictionary of feature name => value pairs. +// +// This is useful if one simply wants to snapshot some features, and apply them +// to more than one task without recomputing anything. +// +// While it's not required, FeatureLibrary is useful to provide the descriptions +// that a FeatureDictionary will provide, so that the LearningTask and the +// dictionary agree on names. +class COMPONENT_EXPORT(LEARNING_COMMON) FeatureDictionary { + public: + // [feature name] => snapshotted value. + using Dictionary = std::map<std::string, FeatureValue>; + + FeatureDictionary(); + ~FeatureDictionary(); + + // Add features for |task| to |features| from our dictionary. Features that + // aren't present in the dictionary will be ignored. |features| will be + // expanded if needed to match |task|. + void Lookup(const LearningTask& task, FeatureVector* features); + + // Add |name| to the dictionary with value |value|. + void Add(const std::string& name, const FeatureValue& value); + + private: + Dictionary dictionary_; + + DISALLOW_COPY_AND_ASSIGN(FeatureDictionary); +}; + +} // namespace learning +} // namespace media + +#endif // MEDIA_LEARNING_COMMON_FEATURE_DICTIONARY_H_ diff --git a/chromium/media/learning/common/feature_dictionary_unittest.cc b/chromium/media/learning/common/feature_dictionary_unittest.cc new file mode 100644 index 00000000000..9d4a00fab6e --- /dev/null +++ b/chromium/media/learning/common/feature_dictionary_unittest.cc @@ -0,0 +1,45 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "media/learning/common/feature_dictionary.h" + +#include "testing/gtest/include/gtest/gtest.h" + +namespace media { +namespace learning { + +class FeatureDictionaryTest : public testing::Test {}; + +TEST_F(FeatureDictionaryTest, FillsInFeatures) { + FeatureDictionary dict; + const std::string feature_name_1("feature 1"); + const FeatureValue feature_value_1("feature value 1"); + + const std::string feature_name_2("feature 2"); + const FeatureValue feature_value_2("feature value 2"); + + const std::string feature_name_3("feature 3"); + const FeatureValue feature_value_3("feature value 3"); + + dict.Add(feature_name_1, feature_value_1); + dict.Add(feature_name_2, feature_value_2); + dict.Add(feature_name_3, feature_value_3); + + LearningTask task; + task.feature_descriptions.push_back({"some other feature"}); + task.feature_descriptions.push_back({feature_name_3}); + task.feature_descriptions.push_back({feature_name_1}); + + FeatureVector features; + features.push_back(FeatureValue(0)); // some other feature + + dict.Lookup(task, &features); + EXPECT_EQ(features.size(), 3u); + EXPECT_EQ(features[0], FeatureValue(0)); + EXPECT_EQ(features[1], feature_value_3); + EXPECT_EQ(features[2], feature_value_1); +} + +} // namespace learning +} // namespace media diff --git a/chromium/media/learning/common/learning_session.h b/chromium/media/learning/common/learning_session.h index f0fb9ba911e..6468871eab4 100644 --- a/chromium/media/learning/common/learning_session.h +++ b/chromium/media/learning/common/learning_session.h @@ -10,6 +10,7 @@ #include "base/component_export.h" #include "base/macros.h" +#include "base/supports_user_data.h" #include "media/learning/common/labelled_example.h" #include "media/learning/common/learning_task.h" @@ -19,10 +20,11 @@ namespace learning { class LearningTaskController; // Interface to provide a Learner given the task name. -class COMPONENT_EXPORT(LEARNING_COMMON) LearningSession { +class COMPONENT_EXPORT(LEARNING_COMMON) LearningSession + : public base::SupportsUserData::Data { public: LearningSession(); - virtual ~LearningSession(); + ~LearningSession() override; // Return a LearningTaskController for the given task. virtual std::unique_ptr<LearningTaskController> GetController( diff --git a/chromium/media/learning/common/learning_task.cc b/chromium/media/learning/common/learning_task.cc index fa5c088c8e2..75c0ae59eb8 100644 --- a/chromium/media/learning/common/learning_task.cc +++ b/chromium/media/learning/common/learning_task.cc @@ -5,6 +5,7 @@ #include "media/learning/common/learning_task.h" #include "base/hash/hash.h" +#include "base/no_destructor.h" namespace media { namespace learning { @@ -29,5 +30,11 @@ LearningTask::Id LearningTask::GetId() const { return base::PersistentHash(name); } +// static +const LearningTask& LearningTask::Empty() { + static const base::NoDestructor<LearningTask> empty_task; + return *empty_task; +} + } // namespace learning } // namespace media diff --git a/chromium/media/learning/common/learning_task.h b/chromium/media/learning/common/learning_task.h index 258dd39016f..9d70b93e1e1 100644 --- a/chromium/media/learning/common/learning_task.h +++ b/chromium/media/learning/common/learning_task.h @@ -91,6 +91,9 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask { // unique |name| for the task. This is used to identify this task in UKM. Id GetId() const; + // Returns a reference to an empty learning task. + static const LearningTask& Empty(); + // Unique name for this task. std::string name; diff --git a/chromium/media/learning/common/learning_task_controller.h b/chromium/media/learning/common/learning_task_controller.h index ef098330458..d4fd6ae7104 100644 --- a/chromium/media/learning/common/learning_task_controller.h +++ b/chromium/media/learning/common/learning_task_controller.h @@ -67,6 +67,9 @@ class COMPONENT_EXPORT(LEARNING_COMMON) LearningTaskController { // Notify the LearningTaskController that no completion will be sent. virtual void CancelObservation(base::UnguessableToken id) = 0; + // Returns the LearningTask associated with |this|. + virtual const LearningTask& GetLearningTask() = 0; + private: DISALLOW_COPY_AND_ASSIGN(LearningTaskController); }; diff --git a/chromium/media/learning/impl/distribution_reporter_unittest.cc b/chromium/media/learning/impl/distribution_reporter_unittest.cc index 55f1f4a2f47..3e968c5061e 100644 --- a/chromium/media/learning/impl/distribution_reporter_unittest.cc +++ b/chromium/media/learning/impl/distribution_reporter_unittest.cc @@ -6,7 +6,7 @@ #include <vector> #include "base/bind.h" -#include "base/test/scoped_task_environment.h" +#include "base/test/task_environment.h" #include "components/ukm/test_ukm_recorder.h" #include "media/learning/common/learning_task.h" #include "media/learning/impl/distribution_reporter.h" @@ -25,7 +25,7 @@ class DistributionReporterTest : public testing::Test { task_.target_description.ordering = LearningTask::Ordering::kNumeric; } - base::test::ScopedTaskEnvironment scoped_task_environment_; + base::test::TaskEnvironment task_environment_; std::unique_ptr<ukm::TestAutoSetUkmRecorder> ukm_recorder_; diff --git a/chromium/media/learning/impl/extra_trees_trainer_unittest.cc b/chromium/media/learning/impl/extra_trees_trainer_unittest.cc index d9e18970ced..ed8d1679e3c 100644 --- a/chromium/media/learning/impl/extra_trees_trainer_unittest.cc +++ b/chromium/media/learning/impl/extra_trees_trainer_unittest.cc @@ -6,7 +6,7 @@ #include "base/bind.h" #include "base/memory/ref_counted.h" -#include "base/test/scoped_task_environment.h" +#include "base/test/task_environment.h" #include "media/learning/impl/fisher_iris_dataset.h" #include "media/learning/impl/test_random_number_generator.h" #include "testing/gtest/include/gtest/gtest.h" @@ -38,11 +38,11 @@ class ExtraTreesTest : public testing::TestWithParam<LearningTask::Ordering> { [](std::unique_ptr<Model>* model_out, std::unique_ptr<Model> model) { *model_out = std::move(model); }, &model)); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); return model; } - base::test::ScopedTaskEnvironment scoped_task_environment_; + base::test::TaskEnvironment task_environment_; TestRandomNumberGenerator rng_; ExtraTreesTrainer trainer_; diff --git a/chromium/media/learning/impl/learning_fuzzertest.cc b/chromium/media/learning/impl/learning_fuzzertest.cc index cfff3d410e7..385cf8748bf 100644 --- a/chromium/media/learning/impl/learning_fuzzertest.cc +++ b/chromium/media/learning/impl/learning_fuzzertest.cc @@ -2,9 +2,10 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "base/test/scoped_task_environment.h" +#include <fuzzer/FuzzedDataProvider.h> + +#include "base/test/task_environment.h" #include "media/learning/impl/learning_task_controller_impl.h" -#include "third_party/libFuzzer/src/utils/FuzzedDataProvider.h" using media::learning::FeatureValue; using media::learning::FeatureVector; @@ -40,7 +41,7 @@ FeatureVector ConsumeFeatureVector(FuzzedDataProvider* provider) { } extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { - base::test::ScopedTaskEnvironment scoped_task_environment; + base::test::TaskEnvironment task_environment; FuzzedDataProvider provider(data, size); LearningTask task; @@ -67,7 +68,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { controller.CompleteObservation( id, ObservationCompletion(TargetValue(ConsumeDouble(&provider)), ConsumeDouble(&provider))); - scoped_task_environment.RunUntilIdle(); + task_environment.RunUntilIdle(); } return 0; diff --git a/chromium/media/learning/impl/learning_session_impl.cc b/chromium/media/learning/impl/learning_session_impl.cc index 7f4e94f86cf..cbe0f5dd30c 100644 --- a/chromium/media/learning/impl/learning_session_impl.cc +++ b/chromium/media/learning/impl/learning_session_impl.cc @@ -22,8 +22,11 @@ class WeakLearningTaskController : public LearningTaskController { public: WeakLearningTaskController( base::WeakPtr<LearningSessionImpl> weak_session, - base::SequenceBound<LearningTaskController>* controller) - : weak_session_(std::move(weak_session)), controller_(controller) {} + base::SequenceBound<LearningTaskController>* controller, + const LearningTask& task) + : weak_session_(std::move(weak_session)), + controller_(controller), + task_(task) {} ~WeakLearningTaskController() override { if (!weak_session_) @@ -63,8 +66,11 @@ class WeakLearningTaskController : public LearningTaskController { id); } + const LearningTask& GetLearningTask() override { return task_; } + base::WeakPtr<LearningSessionImpl> weak_session_; base::SequenceBound<LearningTaskController>* controller_; + LearningTask task_; // Set of ids that have been started but not completed / cancelled yet. std::set<base::UnguessableToken> outstanding_ids_; @@ -92,23 +98,25 @@ void LearningSessionImpl::SetTaskControllerFactoryCBForTesting( std::unique_ptr<LearningTaskController> LearningSessionImpl::GetController( const std::string& task_name) { - auto iter = task_map_.find(task_name); - if (iter == task_map_.end()) + auto iter = controller_map_.find(task_name); + if (iter == controller_map_.end()) return nullptr; // If there were any way to replace / destroy a controller other than when we // destroy |this|, then this wouldn't be such a good idea. return std::make_unique<WeakLearningTaskController>( - weak_factory_.GetWeakPtr(), &iter->second); + weak_factory_.GetWeakPtr(), &iter->second, task_map_[task_name]); } void LearningSessionImpl::RegisterTask( const LearningTask& task, SequenceBoundFeatureProvider feature_provider) { - DCHECK(task_map_.count(task.name) == 0); - task_map_.emplace( + DCHECK(controller_map_.count(task.name) == 0); + controller_map_.emplace( task.name, controller_factory_.Run(task_runner_, task, std::move(feature_provider))); + + task_map_.emplace(task.name, task); } } // namespace learning diff --git a/chromium/media/learning/impl/learning_session_impl.h b/chromium/media/learning/impl/learning_session_impl.h index 06c3eedb513..dd43123d53a 100644 --- a/chromium/media/learning/impl/learning_session_impl.h +++ b/chromium/media/learning/impl/learning_session_impl.h @@ -52,9 +52,12 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningSessionImpl scoped_refptr<base::SequencedTaskRunner> task_runner_; // [task_name] = task controller. - using LearningTaskMap = + using LearningTaskControllerMap = std::map<std::string, base::SequenceBound<LearningTaskController>>; - LearningTaskMap task_map_; + LearningTaskControllerMap controller_map_; + + // Used to fetch registered LearningTasks from their name. + std::map<std::string, LearningTask> task_map_; CreateTaskControllerCB controller_factory_; diff --git a/chromium/media/learning/impl/learning_session_impl_unittest.cc b/chromium/media/learning/impl/learning_session_impl_unittest.cc index d69ec98280d..f8f23018fa7 100644 --- a/chromium/media/learning/impl/learning_session_impl_unittest.cc +++ b/chromium/media/learning/impl/learning_session_impl_unittest.cc @@ -7,7 +7,7 @@ #include <vector> #include "base/bind.h" -#include "base/test/scoped_task_environment.h" +#include "base/test/task_environment.h" #include "base/threading/sequenced_task_runner_handle.h" #include "media/learning/common/learning_task_controller.h" #include "media/learning/impl/learning_session_impl.h" @@ -58,6 +58,11 @@ class LearningSessionImplTest : public testing::Test { cancelled_id_ = id; } + const LearningTask& GetLearningTask() override { + NOTREACHED(); + return LearningTask::Empty(); + } + SequenceBoundFeatureProvider feature_provider_; base::UnguessableToken id_; FeatureVector features_; @@ -104,10 +109,10 @@ class LearningSessionImplTest : public testing::Test { // To prevent a memory leak, reset the session. This will post destruction // of other objects, so RunUntilIdle(). session_.reset(); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); } - base::test::ScopedTaskEnvironment scoped_task_environment_; + base::test::TaskEnvironment task_environment_; scoped_refptr<base::SequencedTaskRunner> task_runner_; @@ -125,16 +130,29 @@ TEST_F(LearningSessionImplTest, RegisteringTasksCreatesControllers) { EXPECT_EQ(task_runners_.size(), 0u); session_->RegisterTask(task_0_); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_EQ(task_controllers_.size(), 1u); EXPECT_EQ(task_runners_.size(), 1u); EXPECT_EQ(task_runners_[0], task_runner_.get()); session_->RegisterTask(task_1_); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_EQ(task_controllers_.size(), 2u); EXPECT_EQ(task_runners_.size(), 2u); EXPECT_EQ(task_runners_[1], task_runner_.get()); + + // Make sure controllers are being returned for the right tasks. + // Note: this test passes because LearningSessionController::GetController() + // returns a wrapper around a FakeLTC, instead of the FakeLTC itself. The + // wrapper internally built by LearningSessionImpl has a proper implementation + // of GetLearningTask(), whereas the FakeLTC does not. + std::unique_ptr<LearningTaskController> ltc_0 = + session_->GetController(task_0_.name); + EXPECT_EQ(ltc_0->GetLearningTask().name, task_0_.name); + + std::unique_ptr<LearningTaskController> ltc_1 = + session_->GetController(task_1_.name); + EXPECT_EQ(ltc_1->GetLearningTask().name, task_1_.name); } TEST_F(LearningSessionImplTest, ExamplesAreForwardedToCorrectTask) { @@ -160,7 +178,7 @@ TEST_F(LearningSessionImplTest, ExamplesAreForwardedToCorrectTask) { ltc_1->CompleteObservation( id, ObservationCompletion(example_1.target_value, example_1.weight)); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_EQ(task_controllers_[0]->example_, example_0); EXPECT_EQ(task_controllers_[1]->example_, example_1); } @@ -174,7 +192,7 @@ TEST_F(LearningSessionImplTest, ControllerLifetimeScopedToSession) { // Destroy the session. |controller| should still be usable, though it won't // forward requests anymore. session_.reset(); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); // Should not crash. controller->BeginObservation(base::UnguessableToken::Create(), @@ -186,7 +204,7 @@ TEST_F(LearningSessionImplTest, FeatureProviderIsForwarded) { bool flag = false; session_->RegisterTask( task_0_, base::SequenceBound<FakeFeatureProvider>(task_runner_, &flag)); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); // Registering the task should create a FakeLearningTaskController, which will // call AddFeatures on the fake FeatureProvider. EXPECT_TRUE(flag); @@ -197,18 +215,18 @@ TEST_F(LearningSessionImplTest, DestroyingControllerCancelsObservations) { std::unique_ptr<LearningTaskController> controller = session_->GetController(task_0_.name); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); // Start an observation and verify that it starts. base::UnguessableToken id = base::UnguessableToken::Create(); controller->BeginObservation(id, FeatureVector()); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_EQ(task_controllers_[0]->id_, id); EXPECT_NE(task_controllers_[0]->cancelled_id_, id); // Should result in cancelling the observation. controller.reset(); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_EQ(task_controllers_[0]->cancelled_id_, id); } diff --git a/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc b/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc index 606c7a7ece2..003cc0c95f6 100644 --- a/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc +++ b/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc @@ -7,7 +7,7 @@ #include <vector> #include "base/bind.h" -#include "base/test/scoped_task_environment.h" +#include "base/test/task_environment.h" #include "base/threading/sequenced_task_runner_handle.h" #include "media/learning/impl/learning_task_controller_helper.h" #include "testing/gtest/include/gtest/gtest.h" @@ -51,7 +51,7 @@ class LearningTaskControllerHelperTest : public testing::Test { // To prevent a memory leak, reset the helper. This will post destruction // of other objects, so RunUntilIdle(). helper_.reset(); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); } void CreateClient(bool include_fp) { @@ -60,7 +60,7 @@ class LearningTaskControllerHelperTest : public testing::Test { if (include_fp) { sb_fp = base::SequenceBound<FakeFeatureProvider>(task_runner_, &fp_features_, &fp_cb_); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); } // TODO(liberato): make sure this works without a fp. @@ -82,7 +82,7 @@ class LearningTaskControllerHelperTest : public testing::Test { return helper_->pending_example_count_for_testing(); } - base::test::ScopedTaskEnvironment scoped_task_environment_; + base::test::TaskEnvironment task_environment_; scoped_refptr<base::SequencedTaskRunner> task_runner_; @@ -126,7 +126,7 @@ TEST_F(LearningTaskControllerHelperTest, DropTargetValueWithoutFPWorks) { helper_->BeginObservation(id_, example_.features); EXPECT_EQ(pending_example_count(), 1u); helper_->CancelObservation(id_); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_FALSE(most_recent_example_); EXPECT_EQ(pending_example_count(), 0u); } @@ -136,7 +136,7 @@ TEST_F(LearningTaskControllerHelperTest, AddTargetValueBeforeFP) { CreateClient(true); helper_->BeginObservation(id_, example_.features); EXPECT_EQ(pending_example_count(), 1u); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); // The feature provider should know about the example. EXPECT_EQ(fp_features_, example_.features); @@ -149,7 +149,7 @@ TEST_F(LearningTaskControllerHelperTest, AddTargetValueBeforeFP) { // Add the features, and verify that they arrive at the AddExampleCB. example_.features[0] = FeatureValue(456); std::move(fp_cb_).Run(example_.features); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_EQ(pending_example_count(), 0u); EXPECT_TRUE(most_recent_example_); EXPECT_EQ(*most_recent_example_, example_); @@ -161,7 +161,7 @@ TEST_F(LearningTaskControllerHelperTest, DropTargetValueBeforeFP) { CreateClient(true); helper_->BeginObservation(id_, example_.features); EXPECT_EQ(pending_example_count(), 1u); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); // The feature provider should know about the example. EXPECT_EQ(fp_features_, example_.features); @@ -174,7 +174,7 @@ TEST_F(LearningTaskControllerHelperTest, DropTargetValueBeforeFP) { // example was sent to us. example_.features[0] = FeatureValue(456); std::move(fp_cb_).Run(example_.features); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_EQ(pending_example_count(), 0u); EXPECT_FALSE(most_recent_example_); } @@ -184,7 +184,7 @@ TEST_F(LearningTaskControllerHelperTest, AddTargetValueAfterFP) { CreateClient(true); helper_->BeginObservation(id_, example_.features); EXPECT_EQ(pending_example_count(), 1u); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); // The feature provider should know about the example. EXPECT_EQ(fp_features_, example_.features); EXPECT_EQ(pending_example_count(), 1u); @@ -192,7 +192,7 @@ TEST_F(LearningTaskControllerHelperTest, AddTargetValueAfterFP) { // Add the features, and verify that the example isn't sent yet. example_.features[0] = FeatureValue(456); std::move(fp_cb_).Run(example_.features); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_FALSE(most_recent_example_); EXPECT_EQ(pending_example_count(), 1u); @@ -210,7 +210,7 @@ TEST_F(LearningTaskControllerHelperTest, DropTargetValueAfterFP) { CreateClient(true); helper_->BeginObservation(id_, example_.features); EXPECT_EQ(pending_example_count(), 1u); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); // The feature provider should know about the example. EXPECT_EQ(fp_features_, example_.features); EXPECT_EQ(pending_example_count(), 1u); @@ -220,14 +220,14 @@ TEST_F(LearningTaskControllerHelperTest, DropTargetValueAfterFP) { // callback yet; we might send a TargetValue. example_.features[0] = FeatureValue(456); std::move(fp_cb_).Run(example_.features); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_FALSE(most_recent_example_); EXPECT_EQ(pending_example_count(), 1u); // Cancel the observation, and verify that the pending example has been // removed, and no example was sent to us. helper_->CancelObservation(id_); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_FALSE(most_recent_example_); EXPECT_EQ(pending_example_count(), 0u); } diff --git a/chromium/media/learning/impl/learning_task_controller_impl.cc b/chromium/media/learning/impl/learning_task_controller_impl.cc index 50a89482cdb..c9bb0f35365 100644 --- a/chromium/media/learning/impl/learning_task_controller_impl.cc +++ b/chromium/media/learning/impl/learning_task_controller_impl.cc @@ -75,6 +75,10 @@ void LearningTaskControllerImpl::CancelObservation(base::UnguessableToken id) { helper_->CancelObservation(id); } +const LearningTask& LearningTaskControllerImpl::GetLearningTask() { + return task_; +} + void LearningTaskControllerImpl::AddFinishedExample(LabelledExample example, ukm::SourceId source_id) { // Verify that we have a trainer and that we got the right number of features. diff --git a/chromium/media/learning/impl/learning_task_controller_impl.h b/chromium/media/learning/impl/learning_task_controller_impl.h index 06df120b045..ae011fde55b 100644 --- a/chromium/media/learning/impl/learning_task_controller_impl.h +++ b/chromium/media/learning/impl/learning_task_controller_impl.h @@ -51,6 +51,7 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl void CompleteObservation(base::UnguessableToken id, const ObservationCompletion& completion) override; void CancelObservation(base::UnguessableToken id) override; + const LearningTask& GetLearningTask() override; private: // Add |example| to the training data, and process it. diff --git a/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc b/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc index 9daec3aeaf1..0faa166be69 100644 --- a/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc +++ b/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc @@ -7,7 +7,7 @@ #include <utility> #include "base/bind.h" -#include "base/test/scoped_task_environment.h" +#include "base/test/task_environment.h" #include "base/threading/sequenced_task_runner_handle.h" #include "media/learning/impl/distribution_reporter.h" #include "testing/gtest/include/gtest/gtest.h" @@ -112,7 +112,7 @@ class LearningTaskControllerImplTest : public testing::Test { // To prevent a memory leak, reset the controller. This may post // destruction of other objects, so RunUntilIdle(). controller_.reset(); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); } void CreateController(SequenceBoundFeatureProvider feature_provider = @@ -137,7 +137,7 @@ class LearningTaskControllerImplTest : public testing::Test { id, ObservationCompletion(example.target_value, example.weight)); } - base::test::ScopedTaskEnvironment scoped_task_environment_; + base::test::TaskEnvironment task_environment_; // Number of models that we trained. int num_models_ = 0; @@ -208,7 +208,7 @@ TEST_F(LearningTaskControllerImplTest, FeatureProviderIsUsed) { example.features.push_back(FeatureValue(123)); example.weight = 321u; AddExample(example); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_EQ(trainer_raw_->training_data()[0].features[0], FeatureValue(124)); EXPECT_EQ(trainer_raw_->training_data()[0].weight, example.weight); } diff --git a/chromium/media/learning/impl/lookup_table_trainer_unittest.cc b/chromium/media/learning/impl/lookup_table_trainer_unittest.cc index 47618746617..fa0fc88f65c 100644 --- a/chromium/media/learning/impl/lookup_table_trainer_unittest.cc +++ b/chromium/media/learning/impl/lookup_table_trainer_unittest.cc @@ -6,7 +6,7 @@ #include "base/bind.h" #include "base/run_loop.h" -#include "base/test/scoped_task_environment.h" +#include "base/test/task_environment.h" #include "testing/gtest/include/gtest/gtest.h" namespace media { @@ -23,11 +23,11 @@ class LookupTableTrainerTest : public testing::Test { [](std::unique_ptr<Model>* model_out, std::unique_ptr<Model> model) { *model_out = std::move(model); }, &model)); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); return model; } - base::test::ScopedTaskEnvironment scoped_task_environment_; + base::test::TaskEnvironment task_environment_; LookupTableTrainer trainer_; LearningTask task_; diff --git a/chromium/media/learning/impl/random_tree_trainer_unittest.cc b/chromium/media/learning/impl/random_tree_trainer_unittest.cc index f9face03115..e289f06e07c 100644 --- a/chromium/media/learning/impl/random_tree_trainer_unittest.cc +++ b/chromium/media/learning/impl/random_tree_trainer_unittest.cc @@ -6,7 +6,7 @@ #include "base/bind.h" #include "base/run_loop.h" -#include "base/test/scoped_task_environment.h" +#include "base/test/task_environment.h" #include "media/learning/impl/test_random_number_generator.h" #include "testing/gtest/include/gtest/gtest.h" @@ -38,11 +38,11 @@ class RandomTreeTest : public testing::TestWithParam<LearningTask::Ordering> { [](std::unique_ptr<Model>* model_out, std::unique_ptr<Model> model) { *model_out = std::move(model); }, &model)); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); return model; } - base::test::ScopedTaskEnvironment scoped_task_environment_; + base::test::TaskEnvironment task_environment_; TestRandomNumberGenerator rng_; RandomTreeTrainer trainer_; diff --git a/chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc b/chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc index 99e8a391269..b2a38ad25d5 100644 --- a/chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc +++ b/chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc @@ -8,7 +8,7 @@ #include "base/bind.h" #include "base/macros.h" #include "base/memory/ptr_util.h" -#include "base/test/scoped_task_environment.h" +#include "base/test/task_environment.h" #include "base/threading/thread.h" #include "media/learning/mojo/mojo_learning_task_controller_service.h" #include "testing/gtest/include/gtest/gtest.h" @@ -36,6 +36,10 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test { cancel_args_.id_ = id; } + const LearningTask& GetLearningTask() override { + return LearningTask::Empty(); + } + struct { base::UnguessableToken id_; FeatureVector features_; @@ -72,7 +76,7 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test { LearningTask task_; // Mojo stuff. - base::test::ScopedTaskEnvironment scoped_task_environment_; + base::test::TaskEnvironment task_environment_; FakeLearningTaskController* controller_raw_ = nullptr; diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc index f4648b38bcb..7cd0cb19589 100644 --- a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc +++ b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc @@ -35,5 +35,9 @@ void MojoLearningTaskController::CancelObservation(base::UnguessableToken id) { controller_ptr_->CancelObservation(id); } +const LearningTask& MojoLearningTaskController::GetLearningTask() { + return LearningTask::Empty(); +} + } // namespace learning } // namespace media diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h index 893d1f0890e..d6c9bed38f1 100644 --- a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h +++ b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h @@ -29,6 +29,7 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskController void CompleteObservation(base::UnguessableToken id, const ObservationCompletion& completion) override; void CancelObservation(base::UnguessableToken id) override; + const LearningTask& GetLearningTask() override; private: mojom::LearningTaskControllerPtr controller_ptr_; diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc index 546aae171b4..b7af774dfbc 100644 --- a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc +++ b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc @@ -8,7 +8,7 @@ #include "base/bind.h" #include "base/macros.h" #include "base/memory/ptr_util.h" -#include "base/test/scoped_task_environment.h" +#include "base/test/task_environment.h" #include "base/threading/thread.h" #include "media/learning/mojo/public/cpp/mojo_learning_task_controller.h" #include "mojo/public/cpp/bindings/binding.h" @@ -70,7 +70,7 @@ class MojoLearningTaskControllerTest : public ::testing::Test { } // Mojo stuff. - base::test::ScopedTaskEnvironment scoped_task_environment_; + base::test::TaskEnvironment task_environment_; FakeMojoLearningTaskController fake_learning_controller_; mojo::Binding<mojom::LearningTaskController> learning_controller_binding_; @@ -83,7 +83,7 @@ TEST_F(MojoLearningTaskControllerTest, Begin) { base::UnguessableToken id = base::UnguessableToken::Create(); FeatureVector features = {FeatureValue(123), FeatureValue(456)}; learning_controller_->BeginObservation(id, features); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_EQ(id, fake_learning_controller_.begin_args_.id_); EXPECT_EQ(features, fake_learning_controller_.begin_args_.features_); } @@ -92,7 +92,7 @@ TEST_F(MojoLearningTaskControllerTest, Complete) { base::UnguessableToken id = base::UnguessableToken::Create(); ObservationCompletion completion(TargetValue(1234)); learning_controller_->CompleteObservation(id, completion); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_EQ(id, fake_learning_controller_.complete_args_.id_); EXPECT_EQ(completion.target_value, fake_learning_controller_.complete_args_.completion_.target_value); @@ -101,7 +101,7 @@ TEST_F(MojoLearningTaskControllerTest, Complete) { TEST_F(MojoLearningTaskControllerTest, Cancel) { base::UnguessableToken id = base::UnguessableToken::Create(); learning_controller_->CancelObservation(id); - scoped_task_environment_.RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_EQ(id, fake_learning_controller_.cancel_args_.id_); } diff --git a/chromium/media/learning/mojo/public/mojom/learning_types.typemap b/chromium/media/learning/mojo/public/mojom/learning_types.typemap index beaf1467335..cb7d19c0c5c 100644 --- a/chromium/media/learning/mojo/public/mojom/learning_types.typemap +++ b/chromium/media/learning/mojo/public/mojom/learning_types.typemap @@ -13,8 +13,8 @@ public_deps = [ "//media/learning/common", ] type_mappings = [ - "media.learning.mojom.LabelledExample=media::learning::LabelledExample", - "media.learning.mojom.FeatureValue=media::learning::FeatureValue", - "media.learning.mojom.TargetValue=media::learning::TargetValue", - "media.learning.mojom.ObservationCompletion=media::learning::ObservationCompletion", + "media.learning.mojom.LabelledExample=::media::learning::LabelledExample", + "media.learning.mojom.FeatureValue=::media::learning::FeatureValue", + "media.learning.mojom.TargetValue=::media::learning::TargetValue", + "media.learning.mojom.ObservationCompletion=::media::learning::ObservationCompletion", ] |