summaryrefslogtreecommitdiff
path: root/chromium/media/learning
diff options
context:
space:
mode:
Diffstat (limited to 'chromium/media/learning')
-rw-r--r--chromium/media/learning/common/BUILD.gn3
-rw-r--r--chromium/media/learning/common/feature_dictionary.cc38
-rw-r--r--chromium/media/learning/common/feature_dictionary.h52
-rw-r--r--chromium/media/learning/common/feature_dictionary_unittest.cc45
-rw-r--r--chromium/media/learning/common/learning_session.h6
-rw-r--r--chromium/media/learning/common/learning_task.cc7
-rw-r--r--chromium/media/learning/common/learning_task.h3
-rw-r--r--chromium/media/learning/common/learning_task_controller.h3
-rw-r--r--chromium/media/learning/impl/distribution_reporter_unittest.cc4
-rw-r--r--chromium/media/learning/impl/extra_trees_trainer_unittest.cc6
-rw-r--r--chromium/media/learning/impl/learning_fuzzertest.cc9
-rw-r--r--chromium/media/learning/impl/learning_session_impl.cc22
-rw-r--r--chromium/media/learning/impl/learning_session_impl.h7
-rw-r--r--chromium/media/learning/impl/learning_session_impl_unittest.cc40
-rw-r--r--chromium/media/learning/impl/learning_task_controller_helper_unittest.cc28
-rw-r--r--chromium/media/learning/impl/learning_task_controller_impl.cc4
-rw-r--r--chromium/media/learning/impl/learning_task_controller_impl.h1
-rw-r--r--chromium/media/learning/impl/learning_task_controller_impl_unittest.cc8
-rw-r--r--chromium/media/learning/impl/lookup_table_trainer_unittest.cc6
-rw-r--r--chromium/media/learning/impl/random_tree_trainer_unittest.cc6
-rw-r--r--chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc8
-rw-r--r--chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc4
-rw-r--r--chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h1
-rw-r--r--chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc10
-rw-r--r--chromium/media/learning/mojo/public/mojom/learning_types.typemap8
25 files changed, 263 insertions, 66 deletions
diff --git a/chromium/media/learning/common/BUILD.gn b/chromium/media/learning/common/BUILD.gn
index 4e208d3a9b7..b86b09c0355 100644
--- a/chromium/media/learning/common/BUILD.gn
+++ b/chromium/media/learning/common/BUILD.gn
@@ -23,6 +23,8 @@ component("common") {
defines = [ "IS_LEARNING_COMMON_IMPL" ]
sources = [
+ "feature_dictionary.cc",
+ "feature_dictionary.h",
"feature_library.cc",
"feature_library.h",
"labelled_example.cc",
@@ -45,6 +47,7 @@ component("common") {
source_set("unit_tests") {
testonly = true
sources = [
+ "feature_dictionary_unittest.cc",
"labelled_example_unittest.cc",
"value_unittest.cc",
]
diff --git a/chromium/media/learning/common/feature_dictionary.cc b/chromium/media/learning/common/feature_dictionary.cc
new file mode 100644
index 00000000000..66cf6081f6d
--- /dev/null
+++ b/chromium/media/learning/common/feature_dictionary.cc
@@ -0,0 +1,38 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "media/learning/common/feature_dictionary.h"
+
+namespace media {
+namespace learning {
+
+FeatureDictionary::FeatureDictionary() = default;
+
+FeatureDictionary::~FeatureDictionary() = default;
+
+void FeatureDictionary::Lookup(const LearningTask& task,
+ FeatureVector* features) {
+ const size_t num_features = task.feature_descriptions.size();
+
+ if (features->size() < num_features)
+ features->resize(num_features);
+
+ for (size_t i = 0; i < num_features; i++) {
+ const auto& name = task.feature_descriptions[i].name;
+ auto entry = dictionary_.find(name);
+ if (entry == dictionary_.end())
+ continue;
+
+ // |name| appears in the dictionary, so add its value to |features|.
+ (*features)[i] = entry->second;
+ }
+}
+
+void FeatureDictionary::Add(const std::string& name,
+ const FeatureValue& value) {
+ dictionary_[name] = value;
+}
+
+} // namespace learning
+} // namespace media
diff --git a/chromium/media/learning/common/feature_dictionary.h b/chromium/media/learning/common/feature_dictionary.h
new file mode 100644
index 00000000000..149804501a4
--- /dev/null
+++ b/chromium/media/learning/common/feature_dictionary.h
@@ -0,0 +1,52 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef MEDIA_LEARNING_COMMON_FEATURE_DICTIONARY_H_
+#define MEDIA_LEARNING_COMMON_FEATURE_DICTIONARY_H_
+
+#include <map>
+#include <string>
+
+#include "base/component_export.h"
+#include "base/macros.h"
+#include "media/learning/common/labelled_example.h"
+#include "media/learning/common/learning_task.h"
+
+namespace media {
+namespace learning {
+
+// Dictionary of feature name => value pairs.
+//
+// This is useful if one simply wants to snapshot some features, and apply them
+// to more than one task without recomputing anything.
+//
+// While it's not required, FeatureLibrary is useful to provide the descriptions
+// that a FeatureDictionary will provide, so that the LearningTask and the
+// dictionary agree on names.
+class COMPONENT_EXPORT(LEARNING_COMMON) FeatureDictionary {
+ public:
+ // [feature name] => snapshotted value.
+ using Dictionary = std::map<std::string, FeatureValue>;
+
+ FeatureDictionary();
+ ~FeatureDictionary();
+
+ // Add features for |task| to |features| from our dictionary. Features that
+ // aren't present in the dictionary will be ignored. |features| will be
+ // expanded if needed to match |task|.
+ void Lookup(const LearningTask& task, FeatureVector* features);
+
+ // Add |name| to the dictionary with value |value|.
+ void Add(const std::string& name, const FeatureValue& value);
+
+ private:
+ Dictionary dictionary_;
+
+ DISALLOW_COPY_AND_ASSIGN(FeatureDictionary);
+};
+
+} // namespace learning
+} // namespace media
+
+#endif // MEDIA_LEARNING_COMMON_FEATURE_DICTIONARY_H_
diff --git a/chromium/media/learning/common/feature_dictionary_unittest.cc b/chromium/media/learning/common/feature_dictionary_unittest.cc
new file mode 100644
index 00000000000..9d4a00fab6e
--- /dev/null
+++ b/chromium/media/learning/common/feature_dictionary_unittest.cc
@@ -0,0 +1,45 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "media/learning/common/feature_dictionary.h"
+
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace media {
+namespace learning {
+
+class FeatureDictionaryTest : public testing::Test {};
+
+TEST_F(FeatureDictionaryTest, FillsInFeatures) {
+ FeatureDictionary dict;
+ const std::string feature_name_1("feature 1");
+ const FeatureValue feature_value_1("feature value 1");
+
+ const std::string feature_name_2("feature 2");
+ const FeatureValue feature_value_2("feature value 2");
+
+ const std::string feature_name_3("feature 3");
+ const FeatureValue feature_value_3("feature value 3");
+
+ dict.Add(feature_name_1, feature_value_1);
+ dict.Add(feature_name_2, feature_value_2);
+ dict.Add(feature_name_3, feature_value_3);
+
+ LearningTask task;
+ task.feature_descriptions.push_back({"some other feature"});
+ task.feature_descriptions.push_back({feature_name_3});
+ task.feature_descriptions.push_back({feature_name_1});
+
+ FeatureVector features;
+ features.push_back(FeatureValue(0)); // some other feature
+
+ dict.Lookup(task, &features);
+ EXPECT_EQ(features.size(), 3u);
+ EXPECT_EQ(features[0], FeatureValue(0));
+ EXPECT_EQ(features[1], feature_value_3);
+ EXPECT_EQ(features[2], feature_value_1);
+}
+
+} // namespace learning
+} // namespace media
diff --git a/chromium/media/learning/common/learning_session.h b/chromium/media/learning/common/learning_session.h
index f0fb9ba911e..6468871eab4 100644
--- a/chromium/media/learning/common/learning_session.h
+++ b/chromium/media/learning/common/learning_session.h
@@ -10,6 +10,7 @@
#include "base/component_export.h"
#include "base/macros.h"
+#include "base/supports_user_data.h"
#include "media/learning/common/labelled_example.h"
#include "media/learning/common/learning_task.h"
@@ -19,10 +20,11 @@ namespace learning {
class LearningTaskController;
// Interface to provide a Learner given the task name.
-class COMPONENT_EXPORT(LEARNING_COMMON) LearningSession {
+class COMPONENT_EXPORT(LEARNING_COMMON) LearningSession
+ : public base::SupportsUserData::Data {
public:
LearningSession();
- virtual ~LearningSession();
+ ~LearningSession() override;
// Return a LearningTaskController for the given task.
virtual std::unique_ptr<LearningTaskController> GetController(
diff --git a/chromium/media/learning/common/learning_task.cc b/chromium/media/learning/common/learning_task.cc
index fa5c088c8e2..75c0ae59eb8 100644
--- a/chromium/media/learning/common/learning_task.cc
+++ b/chromium/media/learning/common/learning_task.cc
@@ -5,6 +5,7 @@
#include "media/learning/common/learning_task.h"
#include "base/hash/hash.h"
+#include "base/no_destructor.h"
namespace media {
namespace learning {
@@ -29,5 +30,11 @@ LearningTask::Id LearningTask::GetId() const {
return base::PersistentHash(name);
}
+// static
+const LearningTask& LearningTask::Empty() {
+ static const base::NoDestructor<LearningTask> empty_task;
+ return *empty_task;
+}
+
} // namespace learning
} // namespace media
diff --git a/chromium/media/learning/common/learning_task.h b/chromium/media/learning/common/learning_task.h
index 258dd39016f..9d70b93e1e1 100644
--- a/chromium/media/learning/common/learning_task.h
+++ b/chromium/media/learning/common/learning_task.h
@@ -91,6 +91,9 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask {
// unique |name| for the task. This is used to identify this task in UKM.
Id GetId() const;
+ // Returns a reference to an empty learning task.
+ static const LearningTask& Empty();
+
// Unique name for this task.
std::string name;
diff --git a/chromium/media/learning/common/learning_task_controller.h b/chromium/media/learning/common/learning_task_controller.h
index ef098330458..d4fd6ae7104 100644
--- a/chromium/media/learning/common/learning_task_controller.h
+++ b/chromium/media/learning/common/learning_task_controller.h
@@ -67,6 +67,9 @@ class COMPONENT_EXPORT(LEARNING_COMMON) LearningTaskController {
// Notify the LearningTaskController that no completion will be sent.
virtual void CancelObservation(base::UnguessableToken id) = 0;
+ // Returns the LearningTask associated with |this|.
+ virtual const LearningTask& GetLearningTask() = 0;
+
private:
DISALLOW_COPY_AND_ASSIGN(LearningTaskController);
};
diff --git a/chromium/media/learning/impl/distribution_reporter_unittest.cc b/chromium/media/learning/impl/distribution_reporter_unittest.cc
index 55f1f4a2f47..3e968c5061e 100644
--- a/chromium/media/learning/impl/distribution_reporter_unittest.cc
+++ b/chromium/media/learning/impl/distribution_reporter_unittest.cc
@@ -6,7 +6,7 @@
#include <vector>
#include "base/bind.h"
-#include "base/test/scoped_task_environment.h"
+#include "base/test/task_environment.h"
#include "components/ukm/test_ukm_recorder.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/impl/distribution_reporter.h"
@@ -25,7 +25,7 @@ class DistributionReporterTest : public testing::Test {
task_.target_description.ordering = LearningTask::Ordering::kNumeric;
}
- base::test::ScopedTaskEnvironment scoped_task_environment_;
+ base::test::TaskEnvironment task_environment_;
std::unique_ptr<ukm::TestAutoSetUkmRecorder> ukm_recorder_;
diff --git a/chromium/media/learning/impl/extra_trees_trainer_unittest.cc b/chromium/media/learning/impl/extra_trees_trainer_unittest.cc
index d9e18970ced..ed8d1679e3c 100644
--- a/chromium/media/learning/impl/extra_trees_trainer_unittest.cc
+++ b/chromium/media/learning/impl/extra_trees_trainer_unittest.cc
@@ -6,7 +6,7 @@
#include "base/bind.h"
#include "base/memory/ref_counted.h"
-#include "base/test/scoped_task_environment.h"
+#include "base/test/task_environment.h"
#include "media/learning/impl/fisher_iris_dataset.h"
#include "media/learning/impl/test_random_number_generator.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -38,11 +38,11 @@ class ExtraTreesTest : public testing::TestWithParam<LearningTask::Ordering> {
[](std::unique_ptr<Model>* model_out,
std::unique_ptr<Model> model) { *model_out = std::move(model); },
&model));
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
return model;
}
- base::test::ScopedTaskEnvironment scoped_task_environment_;
+ base::test::TaskEnvironment task_environment_;
TestRandomNumberGenerator rng_;
ExtraTreesTrainer trainer_;
diff --git a/chromium/media/learning/impl/learning_fuzzertest.cc b/chromium/media/learning/impl/learning_fuzzertest.cc
index cfff3d410e7..385cf8748bf 100644
--- a/chromium/media/learning/impl/learning_fuzzertest.cc
+++ b/chromium/media/learning/impl/learning_fuzzertest.cc
@@ -2,9 +2,10 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#include "base/test/scoped_task_environment.h"
+#include <fuzzer/FuzzedDataProvider.h>
+
+#include "base/test/task_environment.h"
#include "media/learning/impl/learning_task_controller_impl.h"
-#include "third_party/libFuzzer/src/utils/FuzzedDataProvider.h"
using media::learning::FeatureValue;
using media::learning::FeatureVector;
@@ -40,7 +41,7 @@ FeatureVector ConsumeFeatureVector(FuzzedDataProvider* provider) {
}
extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
- base::test::ScopedTaskEnvironment scoped_task_environment;
+ base::test::TaskEnvironment task_environment;
FuzzedDataProvider provider(data, size);
LearningTask task;
@@ -67,7 +68,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
controller.CompleteObservation(
id, ObservationCompletion(TargetValue(ConsumeDouble(&provider)),
ConsumeDouble(&provider)));
- scoped_task_environment.RunUntilIdle();
+ task_environment.RunUntilIdle();
}
return 0;
diff --git a/chromium/media/learning/impl/learning_session_impl.cc b/chromium/media/learning/impl/learning_session_impl.cc
index 7f4e94f86cf..cbe0f5dd30c 100644
--- a/chromium/media/learning/impl/learning_session_impl.cc
+++ b/chromium/media/learning/impl/learning_session_impl.cc
@@ -22,8 +22,11 @@ class WeakLearningTaskController : public LearningTaskController {
public:
WeakLearningTaskController(
base::WeakPtr<LearningSessionImpl> weak_session,
- base::SequenceBound<LearningTaskController>* controller)
- : weak_session_(std::move(weak_session)), controller_(controller) {}
+ base::SequenceBound<LearningTaskController>* controller,
+ const LearningTask& task)
+ : weak_session_(std::move(weak_session)),
+ controller_(controller),
+ task_(task) {}
~WeakLearningTaskController() override {
if (!weak_session_)
@@ -63,8 +66,11 @@ class WeakLearningTaskController : public LearningTaskController {
id);
}
+ const LearningTask& GetLearningTask() override { return task_; }
+
base::WeakPtr<LearningSessionImpl> weak_session_;
base::SequenceBound<LearningTaskController>* controller_;
+ LearningTask task_;
// Set of ids that have been started but not completed / cancelled yet.
std::set<base::UnguessableToken> outstanding_ids_;
@@ -92,23 +98,25 @@ void LearningSessionImpl::SetTaskControllerFactoryCBForTesting(
std::unique_ptr<LearningTaskController> LearningSessionImpl::GetController(
const std::string& task_name) {
- auto iter = task_map_.find(task_name);
- if (iter == task_map_.end())
+ auto iter = controller_map_.find(task_name);
+ if (iter == controller_map_.end())
return nullptr;
// If there were any way to replace / destroy a controller other than when we
// destroy |this|, then this wouldn't be such a good idea.
return std::make_unique<WeakLearningTaskController>(
- weak_factory_.GetWeakPtr(), &iter->second);
+ weak_factory_.GetWeakPtr(), &iter->second, task_map_[task_name]);
}
void LearningSessionImpl::RegisterTask(
const LearningTask& task,
SequenceBoundFeatureProvider feature_provider) {
- DCHECK(task_map_.count(task.name) == 0);
- task_map_.emplace(
+ DCHECK(controller_map_.count(task.name) == 0);
+ controller_map_.emplace(
task.name,
controller_factory_.Run(task_runner_, task, std::move(feature_provider)));
+
+ task_map_.emplace(task.name, task);
}
} // namespace learning
diff --git a/chromium/media/learning/impl/learning_session_impl.h b/chromium/media/learning/impl/learning_session_impl.h
index 06c3eedb513..dd43123d53a 100644
--- a/chromium/media/learning/impl/learning_session_impl.h
+++ b/chromium/media/learning/impl/learning_session_impl.h
@@ -52,9 +52,12 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningSessionImpl
scoped_refptr<base::SequencedTaskRunner> task_runner_;
// [task_name] = task controller.
- using LearningTaskMap =
+ using LearningTaskControllerMap =
std::map<std::string, base::SequenceBound<LearningTaskController>>;
- LearningTaskMap task_map_;
+ LearningTaskControllerMap controller_map_;
+
+ // Used to fetch registered LearningTasks from their name.
+ std::map<std::string, LearningTask> task_map_;
CreateTaskControllerCB controller_factory_;
diff --git a/chromium/media/learning/impl/learning_session_impl_unittest.cc b/chromium/media/learning/impl/learning_session_impl_unittest.cc
index d69ec98280d..f8f23018fa7 100644
--- a/chromium/media/learning/impl/learning_session_impl_unittest.cc
+++ b/chromium/media/learning/impl/learning_session_impl_unittest.cc
@@ -7,7 +7,7 @@
#include <vector>
#include "base/bind.h"
-#include "base/test/scoped_task_environment.h"
+#include "base/test/task_environment.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "media/learning/common/learning_task_controller.h"
#include "media/learning/impl/learning_session_impl.h"
@@ -58,6 +58,11 @@ class LearningSessionImplTest : public testing::Test {
cancelled_id_ = id;
}
+ const LearningTask& GetLearningTask() override {
+ NOTREACHED();
+ return LearningTask::Empty();
+ }
+
SequenceBoundFeatureProvider feature_provider_;
base::UnguessableToken id_;
FeatureVector features_;
@@ -104,10 +109,10 @@ class LearningSessionImplTest : public testing::Test {
// To prevent a memory leak, reset the session. This will post destruction
// of other objects, so RunUntilIdle().
session_.reset();
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
}
- base::test::ScopedTaskEnvironment scoped_task_environment_;
+ base::test::TaskEnvironment task_environment_;
scoped_refptr<base::SequencedTaskRunner> task_runner_;
@@ -125,16 +130,29 @@ TEST_F(LearningSessionImplTest, RegisteringTasksCreatesControllers) {
EXPECT_EQ(task_runners_.size(), 0u);
session_->RegisterTask(task_0_);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_.size(), 1u);
EXPECT_EQ(task_runners_.size(), 1u);
EXPECT_EQ(task_runners_[0], task_runner_.get());
session_->RegisterTask(task_1_);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_.size(), 2u);
EXPECT_EQ(task_runners_.size(), 2u);
EXPECT_EQ(task_runners_[1], task_runner_.get());
+
+ // Make sure controllers are being returned for the right tasks.
+ // Note: this test passes because LearningSessionController::GetController()
+ // returns a wrapper around a FakeLTC, instead of the FakeLTC itself. The
+ // wrapper internally built by LearningSessionImpl has a proper implementation
+ // of GetLearningTask(), whereas the FakeLTC does not.
+ std::unique_ptr<LearningTaskController> ltc_0 =
+ session_->GetController(task_0_.name);
+ EXPECT_EQ(ltc_0->GetLearningTask().name, task_0_.name);
+
+ std::unique_ptr<LearningTaskController> ltc_1 =
+ session_->GetController(task_1_.name);
+ EXPECT_EQ(ltc_1->GetLearningTask().name, task_1_.name);
}
TEST_F(LearningSessionImplTest, ExamplesAreForwardedToCorrectTask) {
@@ -160,7 +178,7 @@ TEST_F(LearningSessionImplTest, ExamplesAreForwardedToCorrectTask) {
ltc_1->CompleteObservation(
id, ObservationCompletion(example_1.target_value, example_1.weight));
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_[0]->example_, example_0);
EXPECT_EQ(task_controllers_[1]->example_, example_1);
}
@@ -174,7 +192,7 @@ TEST_F(LearningSessionImplTest, ControllerLifetimeScopedToSession) {
// Destroy the session. |controller| should still be usable, though it won't
// forward requests anymore.
session_.reset();
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
// Should not crash.
controller->BeginObservation(base::UnguessableToken::Create(),
@@ -186,7 +204,7 @@ TEST_F(LearningSessionImplTest, FeatureProviderIsForwarded) {
bool flag = false;
session_->RegisterTask(
task_0_, base::SequenceBound<FakeFeatureProvider>(task_runner_, &flag));
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
// Registering the task should create a FakeLearningTaskController, which will
// call AddFeatures on the fake FeatureProvider.
EXPECT_TRUE(flag);
@@ -197,18 +215,18 @@ TEST_F(LearningSessionImplTest, DestroyingControllerCancelsObservations) {
std::unique_ptr<LearningTaskController> controller =
session_->GetController(task_0_.name);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
// Start an observation and verify that it starts.
base::UnguessableToken id = base::UnguessableToken::Create();
controller->BeginObservation(id, FeatureVector());
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_[0]->id_, id);
EXPECT_NE(task_controllers_[0]->cancelled_id_, id);
// Should result in cancelling the observation.
controller.reset();
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_[0]->cancelled_id_, id);
}
diff --git a/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc b/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc
index 606c7a7ece2..003cc0c95f6 100644
--- a/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc
+++ b/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc
@@ -7,7 +7,7 @@
#include <vector>
#include "base/bind.h"
-#include "base/test/scoped_task_environment.h"
+#include "base/test/task_environment.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "media/learning/impl/learning_task_controller_helper.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -51,7 +51,7 @@ class LearningTaskControllerHelperTest : public testing::Test {
// To prevent a memory leak, reset the helper. This will post destruction
// of other objects, so RunUntilIdle().
helper_.reset();
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
}
void CreateClient(bool include_fp) {
@@ -60,7 +60,7 @@ class LearningTaskControllerHelperTest : public testing::Test {
if (include_fp) {
sb_fp = base::SequenceBound<FakeFeatureProvider>(task_runner_,
&fp_features_, &fp_cb_);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
}
// TODO(liberato): make sure this works without a fp.
@@ -82,7 +82,7 @@ class LearningTaskControllerHelperTest : public testing::Test {
return helper_->pending_example_count_for_testing();
}
- base::test::ScopedTaskEnvironment scoped_task_environment_;
+ base::test::TaskEnvironment task_environment_;
scoped_refptr<base::SequencedTaskRunner> task_runner_;
@@ -126,7 +126,7 @@ TEST_F(LearningTaskControllerHelperTest, DropTargetValueWithoutFPWorks) {
helper_->BeginObservation(id_, example_.features);
EXPECT_EQ(pending_example_count(), 1u);
helper_->CancelObservation(id_);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
EXPECT_FALSE(most_recent_example_);
EXPECT_EQ(pending_example_count(), 0u);
}
@@ -136,7 +136,7 @@ TEST_F(LearningTaskControllerHelperTest, AddTargetValueBeforeFP) {
CreateClient(true);
helper_->BeginObservation(id_, example_.features);
EXPECT_EQ(pending_example_count(), 1u);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
// The feature provider should know about the example.
EXPECT_EQ(fp_features_, example_.features);
@@ -149,7 +149,7 @@ TEST_F(LearningTaskControllerHelperTest, AddTargetValueBeforeFP) {
// Add the features, and verify that they arrive at the AddExampleCB.
example_.features[0] = FeatureValue(456);
std::move(fp_cb_).Run(example_.features);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
EXPECT_EQ(pending_example_count(), 0u);
EXPECT_TRUE(most_recent_example_);
EXPECT_EQ(*most_recent_example_, example_);
@@ -161,7 +161,7 @@ TEST_F(LearningTaskControllerHelperTest, DropTargetValueBeforeFP) {
CreateClient(true);
helper_->BeginObservation(id_, example_.features);
EXPECT_EQ(pending_example_count(), 1u);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
// The feature provider should know about the example.
EXPECT_EQ(fp_features_, example_.features);
@@ -174,7 +174,7 @@ TEST_F(LearningTaskControllerHelperTest, DropTargetValueBeforeFP) {
// example was sent to us.
example_.features[0] = FeatureValue(456);
std::move(fp_cb_).Run(example_.features);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
EXPECT_EQ(pending_example_count(), 0u);
EXPECT_FALSE(most_recent_example_);
}
@@ -184,7 +184,7 @@ TEST_F(LearningTaskControllerHelperTest, AddTargetValueAfterFP) {
CreateClient(true);
helper_->BeginObservation(id_, example_.features);
EXPECT_EQ(pending_example_count(), 1u);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
// The feature provider should know about the example.
EXPECT_EQ(fp_features_, example_.features);
EXPECT_EQ(pending_example_count(), 1u);
@@ -192,7 +192,7 @@ TEST_F(LearningTaskControllerHelperTest, AddTargetValueAfterFP) {
// Add the features, and verify that the example isn't sent yet.
example_.features[0] = FeatureValue(456);
std::move(fp_cb_).Run(example_.features);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
EXPECT_FALSE(most_recent_example_);
EXPECT_EQ(pending_example_count(), 1u);
@@ -210,7 +210,7 @@ TEST_F(LearningTaskControllerHelperTest, DropTargetValueAfterFP) {
CreateClient(true);
helper_->BeginObservation(id_, example_.features);
EXPECT_EQ(pending_example_count(), 1u);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
// The feature provider should know about the example.
EXPECT_EQ(fp_features_, example_.features);
EXPECT_EQ(pending_example_count(), 1u);
@@ -220,14 +220,14 @@ TEST_F(LearningTaskControllerHelperTest, DropTargetValueAfterFP) {
// callback yet; we might send a TargetValue.
example_.features[0] = FeatureValue(456);
std::move(fp_cb_).Run(example_.features);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
EXPECT_FALSE(most_recent_example_);
EXPECT_EQ(pending_example_count(), 1u);
// Cancel the observation, and verify that the pending example has been
// removed, and no example was sent to us.
helper_->CancelObservation(id_);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
EXPECT_FALSE(most_recent_example_);
EXPECT_EQ(pending_example_count(), 0u);
}
diff --git a/chromium/media/learning/impl/learning_task_controller_impl.cc b/chromium/media/learning/impl/learning_task_controller_impl.cc
index 50a89482cdb..c9bb0f35365 100644
--- a/chromium/media/learning/impl/learning_task_controller_impl.cc
+++ b/chromium/media/learning/impl/learning_task_controller_impl.cc
@@ -75,6 +75,10 @@ void LearningTaskControllerImpl::CancelObservation(base::UnguessableToken id) {
helper_->CancelObservation(id);
}
+const LearningTask& LearningTaskControllerImpl::GetLearningTask() {
+ return task_;
+}
+
void LearningTaskControllerImpl::AddFinishedExample(LabelledExample example,
ukm::SourceId source_id) {
// Verify that we have a trainer and that we got the right number of features.
diff --git a/chromium/media/learning/impl/learning_task_controller_impl.h b/chromium/media/learning/impl/learning_task_controller_impl.h
index 06df120b045..ae011fde55b 100644
--- a/chromium/media/learning/impl/learning_task_controller_impl.h
+++ b/chromium/media/learning/impl/learning_task_controller_impl.h
@@ -51,6 +51,7 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl
void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) override;
void CancelObservation(base::UnguessableToken id) override;
+ const LearningTask& GetLearningTask() override;
private:
// Add |example| to the training data, and process it.
diff --git a/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc b/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc
index 9daec3aeaf1..0faa166be69 100644
--- a/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc
+++ b/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc
@@ -7,7 +7,7 @@
#include <utility>
#include "base/bind.h"
-#include "base/test/scoped_task_environment.h"
+#include "base/test/task_environment.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "media/learning/impl/distribution_reporter.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -112,7 +112,7 @@ class LearningTaskControllerImplTest : public testing::Test {
// To prevent a memory leak, reset the controller. This may post
// destruction of other objects, so RunUntilIdle().
controller_.reset();
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
}
void CreateController(SequenceBoundFeatureProvider feature_provider =
@@ -137,7 +137,7 @@ class LearningTaskControllerImplTest : public testing::Test {
id, ObservationCompletion(example.target_value, example.weight));
}
- base::test::ScopedTaskEnvironment scoped_task_environment_;
+ base::test::TaskEnvironment task_environment_;
// Number of models that we trained.
int num_models_ = 0;
@@ -208,7 +208,7 @@ TEST_F(LearningTaskControllerImplTest, FeatureProviderIsUsed) {
example.features.push_back(FeatureValue(123));
example.weight = 321u;
AddExample(example);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
EXPECT_EQ(trainer_raw_->training_data()[0].features[0], FeatureValue(124));
EXPECT_EQ(trainer_raw_->training_data()[0].weight, example.weight);
}
diff --git a/chromium/media/learning/impl/lookup_table_trainer_unittest.cc b/chromium/media/learning/impl/lookup_table_trainer_unittest.cc
index 47618746617..fa0fc88f65c 100644
--- a/chromium/media/learning/impl/lookup_table_trainer_unittest.cc
+++ b/chromium/media/learning/impl/lookup_table_trainer_unittest.cc
@@ -6,7 +6,7 @@
#include "base/bind.h"
#include "base/run_loop.h"
-#include "base/test/scoped_task_environment.h"
+#include "base/test/task_environment.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
@@ -23,11 +23,11 @@ class LookupTableTrainerTest : public testing::Test {
[](std::unique_ptr<Model>* model_out,
std::unique_ptr<Model> model) { *model_out = std::move(model); },
&model));
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
return model;
}
- base::test::ScopedTaskEnvironment scoped_task_environment_;
+ base::test::TaskEnvironment task_environment_;
LookupTableTrainer trainer_;
LearningTask task_;
diff --git a/chromium/media/learning/impl/random_tree_trainer_unittest.cc b/chromium/media/learning/impl/random_tree_trainer_unittest.cc
index f9face03115..e289f06e07c 100644
--- a/chromium/media/learning/impl/random_tree_trainer_unittest.cc
+++ b/chromium/media/learning/impl/random_tree_trainer_unittest.cc
@@ -6,7 +6,7 @@
#include "base/bind.h"
#include "base/run_loop.h"
-#include "base/test/scoped_task_environment.h"
+#include "base/test/task_environment.h"
#include "media/learning/impl/test_random_number_generator.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -38,11 +38,11 @@ class RandomTreeTest : public testing::TestWithParam<LearningTask::Ordering> {
[](std::unique_ptr<Model>* model_out,
std::unique_ptr<Model> model) { *model_out = std::move(model); },
&model));
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
return model;
}
- base::test::ScopedTaskEnvironment scoped_task_environment_;
+ base::test::TaskEnvironment task_environment_;
TestRandomNumberGenerator rng_;
RandomTreeTrainer trainer_;
diff --git a/chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc b/chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc
index 99e8a391269..b2a38ad25d5 100644
--- a/chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc
+++ b/chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc
@@ -8,7 +8,7 @@
#include "base/bind.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h"
-#include "base/test/scoped_task_environment.h"
+#include "base/test/task_environment.h"
#include "base/threading/thread.h"
#include "media/learning/mojo/mojo_learning_task_controller_service.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -36,6 +36,10 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test {
cancel_args_.id_ = id;
}
+ const LearningTask& GetLearningTask() override {
+ return LearningTask::Empty();
+ }
+
struct {
base::UnguessableToken id_;
FeatureVector features_;
@@ -72,7 +76,7 @@ class MojoLearningTaskControllerServiceTest : public ::testing::Test {
LearningTask task_;
// Mojo stuff.
- base::test::ScopedTaskEnvironment scoped_task_environment_;
+ base::test::TaskEnvironment task_environment_;
FakeLearningTaskController* controller_raw_ = nullptr;
diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc
index f4648b38bcb..7cd0cb19589 100644
--- a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc
+++ b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc
@@ -35,5 +35,9 @@ void MojoLearningTaskController::CancelObservation(base::UnguessableToken id) {
controller_ptr_->CancelObservation(id);
}
+const LearningTask& MojoLearningTaskController::GetLearningTask() {
+ return LearningTask::Empty();
+}
+
} // namespace learning
} // namespace media
diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h
index 893d1f0890e..d6c9bed38f1 100644
--- a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h
+++ b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h
@@ -29,6 +29,7 @@ class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskController
void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) override;
void CancelObservation(base::UnguessableToken id) override;
+ const LearningTask& GetLearningTask() override;
private:
mojom::LearningTaskControllerPtr controller_ptr_;
diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc
index 546aae171b4..b7af774dfbc 100644
--- a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc
+++ b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc
@@ -8,7 +8,7 @@
#include "base/bind.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h"
-#include "base/test/scoped_task_environment.h"
+#include "base/test/task_environment.h"
#include "base/threading/thread.h"
#include "media/learning/mojo/public/cpp/mojo_learning_task_controller.h"
#include "mojo/public/cpp/bindings/binding.h"
@@ -70,7 +70,7 @@ class MojoLearningTaskControllerTest : public ::testing::Test {
}
// Mojo stuff.
- base::test::ScopedTaskEnvironment scoped_task_environment_;
+ base::test::TaskEnvironment task_environment_;
FakeMojoLearningTaskController fake_learning_controller_;
mojo::Binding<mojom::LearningTaskController> learning_controller_binding_;
@@ -83,7 +83,7 @@ TEST_F(MojoLearningTaskControllerTest, Begin) {
base::UnguessableToken id = base::UnguessableToken::Create();
FeatureVector features = {FeatureValue(123), FeatureValue(456)};
learning_controller_->BeginObservation(id, features);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
EXPECT_EQ(id, fake_learning_controller_.begin_args_.id_);
EXPECT_EQ(features, fake_learning_controller_.begin_args_.features_);
}
@@ -92,7 +92,7 @@ TEST_F(MojoLearningTaskControllerTest, Complete) {
base::UnguessableToken id = base::UnguessableToken::Create();
ObservationCompletion completion(TargetValue(1234));
learning_controller_->CompleteObservation(id, completion);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
EXPECT_EQ(id, fake_learning_controller_.complete_args_.id_);
EXPECT_EQ(completion.target_value,
fake_learning_controller_.complete_args_.completion_.target_value);
@@ -101,7 +101,7 @@ TEST_F(MojoLearningTaskControllerTest, Complete) {
TEST_F(MojoLearningTaskControllerTest, Cancel) {
base::UnguessableToken id = base::UnguessableToken::Create();
learning_controller_->CancelObservation(id);
- scoped_task_environment_.RunUntilIdle();
+ task_environment_.RunUntilIdle();
EXPECT_EQ(id, fake_learning_controller_.cancel_args_.id_);
}
diff --git a/chromium/media/learning/mojo/public/mojom/learning_types.typemap b/chromium/media/learning/mojo/public/mojom/learning_types.typemap
index beaf1467335..cb7d19c0c5c 100644
--- a/chromium/media/learning/mojo/public/mojom/learning_types.typemap
+++ b/chromium/media/learning/mojo/public/mojom/learning_types.typemap
@@ -13,8 +13,8 @@ public_deps = [
"//media/learning/common",
]
type_mappings = [
- "media.learning.mojom.LabelledExample=media::learning::LabelledExample",
- "media.learning.mojom.FeatureValue=media::learning::FeatureValue",
- "media.learning.mojom.TargetValue=media::learning::TargetValue",
- "media.learning.mojom.ObservationCompletion=media::learning::ObservationCompletion",
+ "media.learning.mojom.LabelledExample=::media::learning::LabelledExample",
+ "media.learning.mojom.FeatureValue=::media::learning::FeatureValue",
+ "media.learning.mojom.TargetValue=::media::learning::TargetValue",
+ "media.learning.mojom.ObservationCompletion=::media::learning::ObservationCompletion",
]