diff options
author | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2019-05-24 11:40:17 +0200 |
---|---|---|
committer | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2019-05-24 12:42:11 +0000 |
commit | 5d87695f37678f96492b258bbab36486c59866b4 (patch) | |
tree | be9783bbaf04fb930c4d74ca9c00b5e7954c8bc6 /chromium/media/learning | |
parent | 6c11fb357ec39bf087b8b632e2b1e375aef1b38b (diff) | |
download | qtwebengine-chromium-5d87695f37678f96492b258bbab36486c59866b4.tar.gz |
BASELINE: Update Chromium to 75.0.3770.56
Change-Id: I86d2007fd27a45d5797eee06f4c9369b8b50ac4f
Reviewed-by: Alexandru Croitor <alexandru.croitor@qt.io>
Diffstat (limited to 'chromium/media/learning')
55 files changed, 1692 insertions, 838 deletions
diff --git a/chromium/media/learning/common/learning_session.h b/chromium/media/learning/common/learning_session.h index 22db890c2c4..f0fb9ba911e 100644 --- a/chromium/media/learning/common/learning_session.h +++ b/chromium/media/learning/common/learning_session.h @@ -5,6 +5,7 @@ #ifndef MEDIA_LEARNING_COMMON_LEARNING_SESSION_H_ #define MEDIA_LEARNING_COMMON_LEARNING_SESSION_H_ +#include <memory> #include <string> #include "base/component_export.h" @@ -15,18 +16,17 @@ namespace media { namespace learning { +class LearningTaskController; + // Interface to provide a Learner given the task name. class COMPONENT_EXPORT(LEARNING_COMMON) LearningSession { public: LearningSession(); virtual ~LearningSession(); - // Add an observed example |example| to the learning task |task_name|. - // TODO(liberato): Consider making this an enum to match mojo. - virtual void AddExample(const std::string& task_name, - const LabelledExample& example) = 0; - - // TODO(liberato): Add prediction API. + // Return a LearningTaskController for the given task. + virtual std::unique_ptr<LearningTaskController> GetController( + const std::string& task_name) = 0; private: DISALLOW_COPY_AND_ASSIGN(LearningSession); diff --git a/chromium/media/learning/common/learning_task.h b/chromium/media/learning/common/learning_task.h index a7ca9b9f05b..ed5f4f76ac3 100644 --- a/chromium/media/learning/common/learning_task.h +++ b/chromium/media/learning/common/learning_task.h @@ -10,6 +10,7 @@ #include <vector> #include "base/component_export.h" +#include "base/optional.h" #include "media/learning/common/value.h" namespace media { @@ -29,6 +30,9 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask { enum class Model { kExtraTrees, kLookupTable, + + // For the fuzzer. + kMaxValue = kLookupTable }; enum class Ordering { @@ -43,6 +47,9 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask { // ints that represent the number of elapsed milliseconds are numerically // ordered in a meaningful way. kNumeric, + + // For the fuzzer. + kMaxValue = kNumeric }; enum class PrivacyMode { @@ -52,6 +59,9 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask { // Value does not represent private information, such as video width. kPublic, + + // For the fuzzer. + kMaxValue = kPublic }; // Description of how a Value should be interpreted. @@ -97,9 +107,11 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask { // we currently have, which might be less than |max_data_set_size|. double min_new_data_fraction = 0.1; - // If set, then we'll record a confusion matrix hackily to UMA using this as - // the histogram name. - std::string uma_hacky_confusion_matrix; + // If provided, then we'll randomly select a |*feature_subset_size|-sized set + // of feature to train the model with, to allow for feature importance + // measurement. Note that UMA reporting only supports subsets of size one, or + // the whole set. + base::Optional<int> feature_subset_size; // RandomForest parameters @@ -120,7 +132,57 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask { // // In particular, if the percentage of dropped frames is greater than this, // then report "false" (not smooth), else we report true. + // + // A better, non-hacky approach would be to report the predictions and + // observations directly, and do offline analysis with whatever threshold we + // like. This would remove the thresholding requirement, and also permit + // additional types of analysis for general regression tasks, such measuring + // the prediction error directly. + // + // The UKM reporter will support this. double smoothness_threshold = 0.1; + + // If set, then we'll record a confusion matrix (hackily, see + // |smoothness_threshold|, above, for what that means) to UMA for all + // predictions. Add this task's name to histograms.xml, in the histogram + // suffixes for "Media.Learning.BinaryThreshold.Aggregate". The threshold is + // chosen by |smoothness_threshold|. + // + // This option is ignored if feature subset selection is in use. + bool uma_hacky_aggregate_confusion_matrix = false; + + // If set, then we'll record a histogram of many confusion matrices, split out + // by the total training data weight that was used to construct the model. Be + // sure to add this task's name to histograms.xml, in the histogram suffixes + // for "Media.Learning.BinaryThreshold.ByTrainingWeight". The threshold is + // chosen by |smoothness_threshold|. + // + // This option is ignored if feature subset selection is in use. + bool uma_hacky_by_training_weight_confusion_matrix = false; + + // If set, then we'll record a histogram of many confusion matrices, split out + // by the (single) selected feature subset. This does nothing if we're not + // using feature subsets, or if the subset size isn't one. Be sure to add + // this tasks' name to histograms.xml, in the histogram suffixes for + // "Media.Learning.BinaryThreshold.ByFeature" too. + bool uma_hacky_by_feature_subset_confusion_matrix = false; + + // Maximum training weight for UMA reporting. We'll report results offset + // into different confusion matrices in the same histogram, evenly spaced + // from 0 to |max_reporting_weight|, with one additional bucket for everything + // larger than that. The number of buckets is |num_reporting_weight_buckets|. + double max_reporting_weight = 99.; + + // Number of buckets that we'll use to split out the confusion matrix by + // training weight. The last one is reserved for "all", while the others are + // split evenly from 0 to |max_reporting_weight|, inclusive. One can select + // up to 15 buckets. We use 11 by default, so it breaks up the default weight + // into buckets of size 10. + // + // In other words, the defaults will make these buckets: + // [0-9] [10-19] ... [90-99] [100 and up]. This makes sense if the training + // set maximum size is the default of 100, and each example has a weight of 1. + int num_reporting_weight_buckets = 11; }; } // namespace learning diff --git a/chromium/media/learning/common/learning_task_controller.h b/chromium/media/learning/common/learning_task_controller.h index 65cd91f6d44..1e224bde59e 100644 --- a/chromium/media/learning/common/learning_task_controller.h +++ b/chromium/media/learning/common/learning_task_controller.h @@ -8,6 +8,7 @@ #include "base/callback.h" #include "base/component_export.h" #include "base/macros.h" +#include "base/unguessable_token.h" #include "media/learning/common/labelled_example.h" #include "media/learning/common/learning_task.h" @@ -40,31 +41,27 @@ class COMPONENT_EXPORT(LEARNING_COMMON) LearningTaskController { LearningTaskController() = default; virtual ~LearningTaskController() = default; - // TODO(liberato): what is the scope of this id? can it be local to whoever - // owns the LTC? otherwise, consider making it an unguessable token. - // TODO(liberato): consider making a special id that means "i will not send a - // target value", to save a call to CancelObservation. - using ObservationId = int32_t; - - // new example. Call this at the time one would try to predict the - // TargetValue. This lets the framework snapshot any framework-provided + // Start a new observation. Call this at the time one would try to predict + // the TargetValue. This lets the framework snapshot any framework-provided // feature values at prediction time. Later, if you want to turn these - // features into an example for training a model, then call the returned CB - // with the TargetValue and weight. Otherwise, you may discard the CB. + // features into an example for training a model, then call + // CompleteObservation with the same id and an ObservationCompletion. + // Otherwise, call CancelObservation with |id|. It's also okay to destroy the + // controller with outstanding observations; these will be cancelled. // TODO(liberato): This should optionally take a callback to receive a // prediction for the FeatureVector. // TODO(liberato): See if this ends up generating smaller code with pass-by- // value or with |FeatureVector&&|, once we have callers that can actually // benefit from it. - virtual void BeginObservation(ObservationId id, + virtual void BeginObservation(base::UnguessableToken id, const FeatureVector& features) = 0; // Complete an observation by sending a completion. - virtual void CompleteObservation(ObservationId id, + virtual void CompleteObservation(base::UnguessableToken id, const ObservationCompletion& completion) = 0; - // Notify the LTC that no completion will be sent. - virtual void CancelObservation(ObservationId id) = 0; + // Notify the LearningTaskController that no completion will be sent. + virtual void CancelObservation(base::UnguessableToken id) = 0; private: DISALLOW_COPY_AND_ASSIGN(LearningTaskController); diff --git a/chromium/media/learning/common/value.cc b/chromium/media/learning/common/value.cc index 12ea399d24c..caf94e0eb0e 100644 --- a/chromium/media/learning/common/value.cc +++ b/chromium/media/learning/common/value.cc @@ -4,7 +4,7 @@ #include "media/learning/common/value.h" -#include "base/hash.h" +#include "base/hash/hash.h" namespace media { namespace learning { diff --git a/chromium/media/learning/impl/BUILD.gn b/chromium/media/learning/impl/BUILD.gn index 638014f0e39..27ed8c2dd8b 100644 --- a/chromium/media/learning/impl/BUILD.gn +++ b/chromium/media/learning/impl/BUILD.gn @@ -2,10 +2,13 @@ # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. +import("//testing/libfuzzer/fuzzer_test.gni") + component("impl") { output_name = "learning_impl" visibility = [ "//media/learning/impl:unit_tests", + "//media/learning/impl:learning_fuzzer", # Actual clients. "//content/browser", @@ -35,8 +38,8 @@ component("impl") { "random_number_generator.h", "random_tree_trainer.cc", "random_tree_trainer.h", - "target_distribution.cc", - "target_distribution.h", + "target_histogram.cc", + "target_histogram.h", "training_algorithm.h", "voting_ensemble.cc", "voting_ensemble.h", @@ -69,7 +72,7 @@ source_set("unit_tests") { "one_hot_unittest.cc", "random_number_generator_unittest.cc", "random_tree_trainer_unittest.cc", - "target_distribution_unittest.cc", + "target_histogram_unittest.cc", "test_random_number_generator.cc", "test_random_number_generator.h", ] @@ -82,3 +85,14 @@ source_set("unit_tests") { "//testing/gtest", ] } + +fuzzer_test("learning_fuzzer") { + sources = [ + "learning_fuzzertest.cc", + ] + deps = [ + ":impl", + "//base", + "//base/test:test_support", + ] +} diff --git a/chromium/media/learning/impl/distribution_reporter.cc b/chromium/media/learning/impl/distribution_reporter.cc index e657ed3d9a2..a42f2c504a9 100644 --- a/chromium/media/learning/impl/distribution_reporter.cc +++ b/chromium/media/learning/impl/distribution_reporter.cc @@ -10,70 +10,170 @@ namespace media { namespace learning { -// Low order bit is "observed", second bit is "predicted", third bit is "could -// not make a prediction". +// UMA histogram base names. +static const char* kAggregateBase = "Media.Learning.BinaryThreshold.Aggregate."; +static const char* kByTrainingWeightBase = + "Media.Learning.BinaryThreshold.ByTrainingWeight."; +static const char* kByFeatureBase = "Media.Learning.BinaryThreshold.ByFeature."; + +enum /* not class */ Bits { + // These are meant to be bitwise-or'd together, so both false cases just mean + // "don't set any bits". + PredictedFalse = 0x00, + ObservedFalse = 0x00, + ObservedTrue = 0x01, + PredictedTrue = 0x02, + // Special value to mean that no prediction was made. + PredictedNothing = 0x04, +}; + +// Low order bit is "observed", second bit is "predicted", third bit is +// "could not make a prediction". enum class ConfusionMatrix { - TrueNegative = 0, // predicted == observed == false - FalseNegative = 1, // predicted == false, observed == true - FalsePositive = 2, // predicted == true, observed == false - TruePositive = 3, // predicted == observed == true - SkippedNegative = 4, // predicted == N/A, observed == false - SkippedPositive = 5, // predicted == N/A, observed == true + TrueNegative = Bits::PredictedFalse | Bits::ObservedFalse, + FalseNegative = Bits::PredictedFalse | Bits::ObservedTrue, + FalsePositive = Bits::PredictedTrue | Bits::ObservedFalse, + TruePositive = Bits::PredictedTrue | Bits::ObservedTrue, + SkippedNegative = Bits::PredictedNothing | Bits::ObservedFalse, + SkippedPositive = Bits::PredictedNothing | Bits::ObservedTrue, kMaxValue = SkippedPositive }; // TODO(liberato): Currently, this implementation is a hack to collect some // sanity-checking data for local learning with MediaCapabilities. We assume // that the prediction is the "percentage of dropped frames". -// -// Please see https://chromium-review.googlesource.com/c/chromium/src/+/1385107 -// for an actual UKM-based implementation. -class RegressionReporter : public DistributionReporter { +class UmaRegressionReporter : public DistributionReporter { public: - RegressionReporter(const LearningTask& task) : DistributionReporter(task) {} + UmaRegressionReporter(const LearningTask& task) + : DistributionReporter(task) {} - void OnPrediction(TargetDistribution observed, - TargetDistribution predicted) override { + void OnPrediction(const PredictionInfo& info, + TargetHistogram predicted) override { DCHECK_EQ(task().target_description.ordering, LearningTask::Ordering::kNumeric); - DCHECK(!task().uma_hacky_confusion_matrix.empty()); // As a complete hack, record accuracy with a fixed threshold. The average // is the observed / predicted percentage of dropped frames. - bool observed_smooth = observed.Average() <= task().smoothness_threshold; + bool observed_smooth = info.observed.value() <= task().smoothness_threshold; // See if we made a prediction. - int predicted_bits = 4; // N/A + int prediction_bit_mask = Bits::PredictedNothing; if (predicted.total_counts() != 0) { bool predicted_smooth = predicted.Average() <= task().smoothness_threshold; DVLOG(2) << "Learning: " << task().name << ": predicted: " << predicted_smooth << " (" << predicted.Average() << ") observed: " << observed_smooth - << " (" << observed.Average() << ")"; - predicted_bits = predicted_smooth ? 2 : 0; + << " (" << info.observed.value() << ")"; + prediction_bit_mask = + predicted_smooth ? Bits::PredictedTrue : Bits::PredictedFalse; } else { DVLOG(2) << "Learning: " << task().name << ": predicted: N/A observed: " << observed_smooth << " (" - << observed.Average() << ")"; + << info.observed.value() << ")"; } - // Convert to a bucket from which we can get the confusion matrix. - ConfusionMatrix uma_bucket = static_cast<ConfusionMatrix>( - (observed_smooth ? 1 : 0) | predicted_bits); - base::UmaHistogramEnumeration(task().uma_hacky_confusion_matrix, - uma_bucket); + // Figure out the ConfusionMatrix enum value. + ConfusionMatrix confusion_matrix_value = static_cast<ConfusionMatrix>( + (observed_smooth ? Bits::ObservedTrue : Bits::ObservedFalse) | + prediction_bit_mask); + + // |uma_bucket_number| is the bucket number that we'll fill in with this + // count. It ranges from 0 to |max_buckets-1|, inclusive. Each bucket is + // is separated from the start of the previous bucket by |uma_bucket_size|. + int uma_bucket_number = 0; + constexpr int matrix_size = + static_cast<int>(ConfusionMatrix::kMaxValue) + 1; + + // The enum.xml entries separate the buckets by 10, to make it easy to see + // by inspection what bucket number we're in (e.g., x-axis position 23 is + // bucket 2 * 10 + PredictedTrue|ObservedTrue). The label in enum.xml for + // MegaConfusionMatrix also provides the bucket number for easy reading. + constexpr int uma_bucket_size = 10; + DCHECK_LE(matrix_size, uma_bucket_size); + + // Maximum number of buckets defined in enums.xml, numbered from 0. + constexpr int max_buckets = 16; + + // Sparse histograms can technically go past 100 exactly-stored elements, + // but we limit it anyway. Note that we don't care about |uma_bucket_size|, + // since it's a sparse histogram. Only |matrix_size| elements are used in + // each bucket. + DCHECK_LE(max_buckets * matrix_size, 100); + + // If we're splitting by feature, then record it and stop. The others + // aren't meaningful to record if we're using random feature subsets. + if (task().uma_hacky_by_feature_subset_confusion_matrix && + feature_indices() && feature_indices()->size() == 1) { + // The bucket number is just the feature number that was selected. + uma_bucket_number = + std::min(*feature_indices()->begin(), max_buckets - 1); + + std::string base(kByFeatureBase); + base::UmaHistogramSparse(base + task().name, + static_cast<int>(confusion_matrix_value) + + uma_bucket_number * uma_bucket_size); + + // Early return since no other measurements are meaningful when we're + // using feature subsets. + return; + } + + // If we're selecting a feature subset that's bigger than one but smaller + // than all of them, then we don't know how to report that. + if (feature_indices() && + feature_indices()->size() != task().feature_descriptions.size()) { + return; + } + + // Do normal reporting. + + // Record the aggregate confusion matrix. + if (task().uma_hacky_aggregate_confusion_matrix) { + std::string base(kAggregateBase); + base::UmaHistogramEnumeration(base + task().name, confusion_matrix_value); + } + + if (task().uma_hacky_by_training_weight_confusion_matrix) { + // Adjust |uma_bucket_offset| by the training weight, and store the + // results in that bucket in the ByTrainingWeight histogram. + // + // This will bucket from 0 in even steps, with the last bucket holding + // |max_reporting_weight+1| and everything above it. + + const int n_buckets = task().num_reporting_weight_buckets; + DCHECK_LE(n_buckets, max_buckets); + + // We use one fewer buckets, to save one for the overflow. Buckets are + // numbered from 0 to |n_buckets-1|, inclusive. In other words, when the + // training weight is equal to |max_reporting_weight|, we still want to + // be in bucket |n_buckets - 2|. That's why we add one to the max before + // we divide; only things over the max go into the last bucket. + uma_bucket_number = + std::min<int>((n_buckets - 1) * info.total_training_weight / + (task().max_reporting_weight + 1), + n_buckets - 1); + + std::string base(kByTrainingWeightBase); + base::UmaHistogramSparse(base + task().name, + static_cast<int>(confusion_matrix_value) + + uma_bucket_number * uma_bucket_size); + } } }; std::unique_ptr<DistributionReporter> DistributionReporter::Create( const LearningTask& task) { - // Hacky reporting is the only thing we know how to report. - if (task.uma_hacky_confusion_matrix.empty()) + // We only know how to report regression tasks right now. + if (task.target_description.ordering != LearningTask::Ordering::kNumeric) return nullptr; - if (task.target_description.ordering == LearningTask::Ordering::kNumeric) - return std::make_unique<RegressionReporter>(task); + if (task.uma_hacky_aggregate_confusion_matrix || + task.uma_hacky_by_training_weight_confusion_matrix || + task.uma_hacky_by_feature_subset_confusion_matrix) { + return std::make_unique<UmaRegressionReporter>(task); + } + return nullptr; } @@ -83,9 +183,14 @@ DistributionReporter::DistributionReporter(const LearningTask& task) DistributionReporter::~DistributionReporter() = default; Model::PredictionCB DistributionReporter::GetPredictionCallback( - TargetDistribution observed) { + const PredictionInfo& info) { return base::BindOnce(&DistributionReporter::OnPrediction, - weak_factory_.GetWeakPtr(), observed); + weak_factory_.GetWeakPtr(), info); +} + +void DistributionReporter::SetFeatureSubset( + const std::set<int>& feature_indices) { + feature_indices_ = feature_indices; } } // namespace learning diff --git a/chromium/media/learning/impl/distribution_reporter.h b/chromium/media/learning/impl/distribution_reporter.h index 78b22e65c93..99310aa0ed8 100644 --- a/chromium/media/learning/impl/distribution_reporter.h +++ b/chromium/media/learning/impl/distribution_reporter.h @@ -5,13 +5,16 @@ #ifndef MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_ #define MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_ +#include <set> + #include "base/callback.h" #include "base/component_export.h" #include "base/macros.h" #include "base/memory/weak_ptr.h" +#include "base/optional.h" #include "media/learning/common/learning_task.h" #include "media/learning/impl/model.h" -#include "media/learning/impl/target_distribution.h" +#include "media/learning/impl/target_histogram.h" namespace media { namespace learning { @@ -21,15 +24,39 @@ namespace learning { // specific learning task. class COMPONENT_EXPORT(LEARNING_IMPL) DistributionReporter { public: + // Extra information provided to the reporter for each prediction. + struct PredictionInfo { + // What value was observed? + TargetValue observed; + + // Total weight of the training data used to create this model. + double total_training_weight = 0.; + + // Total number of examples (unweighted) in the training set. + size_t total_training_examples = 0u; + + // TODO(liberato): Move the feature subset here. + }; + // Create a DistributionReporter that's suitable for |task|. static std::unique_ptr<DistributionReporter> Create(const LearningTask& task); virtual ~DistributionReporter(); - // Returns a prediction CB that will be compared to |observed|. |observed| is - // the total number of counts that we observed. + // Returns a prediction CB that will be compared to |prediction_info.observed| + // TODO(liberato): This is too complicated. Skip the callback and just call + // us with the predicted value. virtual Model::PredictionCB GetPredictionCallback( - TargetDistribution observed); + const PredictionInfo& prediction_info); + + // Set the subset of features that is being used to train the model. This is + // used for feature importance measuremnts. + // + // For example, sending in the set [0, 3, 7] would indicate that the model was + // trained with task().feature_descriptions[0, 3, 7] only. + // + // Note that UMA reporting only supports single feature subsets. + void SetFeatureSubset(const std::set<int>& feature_indices); protected: DistributionReporter(const LearningTask& task); @@ -37,12 +64,20 @@ class COMPONENT_EXPORT(LEARNING_IMPL) DistributionReporter { const LearningTask& task() const { return task_; } // Implemented by subclasses to report a prediction. - virtual void OnPrediction(TargetDistribution observed, - TargetDistribution predicted) = 0; + virtual void OnPrediction(const PredictionInfo& prediction_info, + TargetHistogram predicted) = 0; + + const base::Optional<std::set<int>>& feature_indices() const { + return feature_indices_; + } private: LearningTask task_; + // If provided, then these are the features that are used to train the model. + // Otherwise, we assume that all features are used. + base::Optional<std::set<int>> feature_indices_; + base::WeakPtrFactory<DistributionReporter> weak_factory_; DISALLOW_COPY_AND_ASSIGN(DistributionReporter); diff --git a/chromium/media/learning/impl/distribution_reporter_unittest.cc b/chromium/media/learning/impl/distribution_reporter_unittest.cc index 3815d5c4842..43679be9a14 100644 --- a/chromium/media/learning/impl/distribution_reporter_unittest.cc +++ b/chromium/media/learning/impl/distribution_reporter_unittest.cc @@ -16,6 +16,12 @@ namespace learning { class DistributionReporterTest : public testing::Test { public: + DistributionReporterTest() { + task_.name = "TaskName"; + // UMA reporting requires a numeric target. + task_.target_description.ordering = LearningTask::Ordering::kNumeric; + } + base::test::ScopedTaskEnvironment scoped_task_environment_; LearningTask task_; @@ -25,35 +31,27 @@ class DistributionReporterTest : public testing::Test { TEST_F(DistributionReporterTest, DistributionReporterDoesNotCrash) { // Make sure that we request some sort of reporting. - task_.target_description.ordering = LearningTask::Ordering::kNumeric; - task_.uma_hacky_confusion_matrix = "test"; + task_.uma_hacky_aggregate_confusion_matrix = true; reporter_ = DistributionReporter::Create(task_); EXPECT_NE(reporter_, nullptr); + // Observe an average of 2 / 3. + DistributionReporter::PredictionInfo info; + info.observed = TargetValue(2.0 / 3.0); + auto cb = reporter_->GetPredictionCallback(info); + + TargetHistogram predicted; const TargetValue Zero(0); const TargetValue One(1); - TargetDistribution observed; - // Observe an average of 2 / 3. - observed[Zero] = 100; - observed[One] = 200; - auto cb = reporter_->GetPredictionCallback(observed); - - TargetDistribution predicted; // Predict an average of 5 / 9. predicted[Zero] = 40; predicted[One] = 50; std::move(cb).Run(predicted); - - // TODO(liberato): When we switch to ukm, use a TestUkmRecorder to make sure - // that it fills in the right stuff. - // https://chromium-review.googlesource.com/c/chromium/src/+/1385107 . } -TEST_F(DistributionReporterTest, DistributionReporterNeedsUmaName) { +TEST_F(DistributionReporterTest, DistributionReporterMustBeRequested) { // Make sure that we don't get a reporter if we don't request any reporting. - task_.target_description.ordering = LearningTask::Ordering::kNumeric; - task_.uma_hacky_confusion_matrix = ""; reporter_ = DistributionReporter::Create(task_); EXPECT_EQ(reporter_, nullptr); } @@ -62,10 +60,28 @@ TEST_F(DistributionReporterTest, DistributionReporterHackyConfusionMatrixNeedsRegression) { // Hacky confusion matrix reporting only works with regression. task_.target_description.ordering = LearningTask::Ordering::kUnordered; - task_.uma_hacky_confusion_matrix = "test"; + task_.uma_hacky_aggregate_confusion_matrix = true; reporter_ = DistributionReporter::Create(task_); EXPECT_EQ(reporter_, nullptr); } +TEST_F(DistributionReporterTest, ProvidesAggregateReporter) { + task_.uma_hacky_aggregate_confusion_matrix = true; + reporter_ = DistributionReporter::Create(task_); + EXPECT_NE(reporter_, nullptr); +} + +TEST_F(DistributionReporterTest, ProvidesByTrainingWeightReporter) { + task_.uma_hacky_by_training_weight_confusion_matrix = true; + reporter_ = DistributionReporter::Create(task_); + EXPECT_NE(reporter_, nullptr); +} + +TEST_F(DistributionReporterTest, ProvidesByFeatureSubsetReporter) { + task_.uma_hacky_by_feature_subset_confusion_matrix = true; + reporter_ = DistributionReporter::Create(task_); + EXPECT_NE(reporter_, nullptr); +} + } // namespace learning } // namespace media diff --git a/chromium/media/learning/impl/extra_trees_trainer_unittest.cc b/chromium/media/learning/impl/extra_trees_trainer_unittest.cc index ad07000ba87..d9e18970ced 100644 --- a/chromium/media/learning/impl/extra_trees_trainer_unittest.cc +++ b/chromium/media/learning/impl/extra_trees_trainer_unittest.cc @@ -55,7 +55,7 @@ TEST_P(ExtraTreesTest, EmptyTrainingDataWorks) { TrainingData empty; auto model = Train(task_, empty); EXPECT_NE(model.get(), nullptr); - EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetDistribution()); + EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetHistogram()); } TEST_P(ExtraTreesTest, FisherIrisDataset) { @@ -67,8 +67,7 @@ TEST_P(ExtraTreesTest, FisherIrisDataset) { // Verify predictions on the training set, just for sanity. size_t num_correct = 0; for (const LabelledExample& example : training_data) { - TargetDistribution distribution = - model->PredictDistribution(example.features); + TargetHistogram distribution = model->PredictDistribution(example.features); TargetValue predicted_value; if (distribution.FindSingularMax(&predicted_value) && predicted_value == example.target_value) { @@ -102,8 +101,7 @@ TEST_P(ExtraTreesTest, WeightedTrainingSetIsSupported) { auto model = Train(task_, training_data); // The singular max should be example_1. - TargetDistribution distribution = - model->PredictDistribution(example_1.features); + TargetHistogram distribution = model->PredictDistribution(example_1.features); TargetValue predicted_value; EXPECT_TRUE(distribution.FindSingularMax(&predicted_value)); EXPECT_EQ(predicted_value, example_1.target_value); @@ -135,8 +133,7 @@ TEST_P(ExtraTreesTest, RegressionWorks) { auto model = Train(task_, training_data); // Make sure that the results are in the right range. - TargetDistribution distribution = - model->PredictDistribution(example_1.features); + TargetHistogram distribution = model->PredictDistribution(example_1.features); EXPECT_GT(distribution.Average(), example_1.target_value.value() * 0.95); EXPECT_LT(distribution.Average(), example_1.target_value.value() * 1.05); distribution = model->PredictDistribution(example_2.features); @@ -194,10 +191,10 @@ TEST_P(ExtraTreesTest, RegressionVsBinaryClassification) { // the data is separable, it probably should be exact. for (auto& r_example : r_examples) { const FeatureVector& fv = r_example.features; - TargetDistribution c_dist = c_model->PredictDistribution(fv); + TargetHistogram c_dist = c_model->PredictDistribution(fv); EXPECT_LE(c_dist.Average(), r_example.target_value.value() * 1.05); EXPECT_GE(c_dist.Average(), r_example.target_value.value() * 0.95); - TargetDistribution r_dist = r_model->PredictDistribution(fv); + TargetHistogram r_dist = r_model->PredictDistribution(fv); EXPECT_LE(r_dist.Average(), r_example.target_value.value() * 1.05); EXPECT_GE(r_dist.Average(), r_example.target_value.value() * 0.95); } diff --git a/chromium/media/learning/impl/learning_fuzzertest.cc b/chromium/media/learning/impl/learning_fuzzertest.cc new file mode 100644 index 00000000000..f886ffb9bef --- /dev/null +++ b/chromium/media/learning/impl/learning_fuzzertest.cc @@ -0,0 +1,74 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "base/test/fuzzed_data_provider.h" +#include "base/test/scoped_task_environment.h" +#include "media/learning/impl/learning_task_controller_impl.h" + +using media::learning::FeatureValue; +using media::learning::FeatureVector; +using media::learning::LearningTask; +using ValueDescription = media::learning::LearningTask::ValueDescription; +using media::learning::LearningTaskControllerImpl; +using media::learning::ObservationCompletion; +using media::learning::TargetValue; + +ValueDescription ConsumeValueDescription(base::FuzzedDataProvider* provider) { + ValueDescription desc; + desc.name = provider->ConsumeRandomLengthString(100); + desc.ordering = provider->ConsumeEnum<LearningTask::Ordering>(); + desc.privacy_mode = provider->ConsumeEnum<LearningTask::PrivacyMode>(); + return desc; +} + +double ConsumeDouble(base::FuzzedDataProvider* provider) { + std::vector<uint8_t> v = provider->ConsumeBytes(sizeof(double)); + if (v.size() == sizeof(double)) + return reinterpret_cast<double*>(v.data())[0]; + + return 0; +} + +FeatureVector ConsumeFeatureVector(base::FuzzedDataProvider* provider) { + FeatureVector features; + int n = provider->ConsumeIntegralInRange(0, 100); + while (n-- > 0) + features.push_back(FeatureValue(ConsumeDouble(provider))); + + return features; +} + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + base::test::ScopedTaskEnvironment scoped_task_environment; + base::FuzzedDataProvider provider(data, size); + + LearningTask task; + task.name = provider.ConsumeRandomLengthString(100); + task.model = provider.ConsumeEnum<LearningTask::Model>(); + task.use_one_hot_conversion = provider.ConsumeBool(); + task.uma_hacky_aggregate_confusion_matrix = provider.ConsumeBool(); + task.uma_hacky_by_training_weight_confusion_matrix = provider.ConsumeBool(); + task.uma_hacky_by_feature_subset_confusion_matrix = provider.ConsumeBool(); + int n_features = provider.ConsumeIntegralInRange(0, 100); + int subset_size = provider.ConsumeIntegralInRange<uint8_t>(0, n_features); + if (subset_size) + task.feature_subset_size = subset_size; + for (int i = 0; i < n_features; i++) + task.feature_descriptions.push_back(ConsumeValueDescription(&provider)); + task.target_description = ConsumeValueDescription(&provider); + + LearningTaskControllerImpl controller(task); + + // Build random examples. + while (provider.remaining_bytes() > 0) { + base::UnguessableToken id = base::UnguessableToken::Create(); + controller.BeginObservation(id, ConsumeFeatureVector(&provider)); + controller.CompleteObservation( + id, ObservationCompletion(TargetValue(ConsumeDouble(&provider)), + ConsumeDouble(&provider))); + scoped_task_environment.RunUntilIdle(); + } + + return 0; +} diff --git a/chromium/media/learning/impl/learning_session_impl.cc b/chromium/media/learning/impl/learning_session_impl.cc index 4ad7de46f74..39d7e30fe6d 100644 --- a/chromium/media/learning/impl/learning_session_impl.cc +++ b/chromium/media/learning/impl/learning_session_impl.cc @@ -4,6 +4,7 @@ #include "media/learning/impl/learning_session_impl.h" +#include <set> #include <utility> #include "base/bind.h" @@ -14,15 +15,74 @@ namespace media { namespace learning { -LearningSessionImpl::LearningSessionImpl() - : controller_factory_( - base::BindRepeating([](const LearningTask& task, - SequenceBoundFeatureProvider feature_provider) - -> std::unique_ptr<LearningTaskController> { - return std::make_unique<LearningTaskControllerImpl>( - task, DistributionReporter::Create(task), +// Allow multiple clients to own an LTC that points to the same underlying LTC. +// Since we don't own the LTC, we also keep track of in-flight observations and +// explicitly cancel them on destruction, since dropping an LTC implies that. +class WeakLearningTaskController : public LearningTaskController { + public: + WeakLearningTaskController( + base::WeakPtr<LearningSessionImpl> weak_session, + base::SequenceBound<LearningTaskController>* controller) + : weak_session_(std::move(weak_session)), controller_(controller) {} + + ~WeakLearningTaskController() override { + if (!weak_session_) + return; + + // Cancel any outstanding observations. + for (auto& id : outstanding_ids_) { + controller_->Post(FROM_HERE, &LearningTaskController::CancelObservation, + id); + } + } + + void BeginObservation(base::UnguessableToken id, + const FeatureVector& features) override { + if (!weak_session_) + return; + + outstanding_ids_.insert(id); + controller_->Post(FROM_HERE, &LearningTaskController::BeginObservation, id, + features); + } + + void CompleteObservation(base::UnguessableToken id, + const ObservationCompletion& completion) override { + if (!weak_session_) + return; + outstanding_ids_.erase(id); + controller_->Post(FROM_HERE, &LearningTaskController::CompleteObservation, + id, completion); + } + + void CancelObservation(base::UnguessableToken id) override { + if (!weak_session_) + return; + outstanding_ids_.erase(id); + controller_->Post(FROM_HERE, &LearningTaskController::CancelObservation, + id); + } + + base::WeakPtr<LearningSessionImpl> weak_session_; + base::SequenceBound<LearningTaskController>* controller_; + + // Set of ids that have been started but not completed / cancelled yet. + std::set<base::UnguessableToken> outstanding_ids_; +}; + +LearningSessionImpl::LearningSessionImpl( + scoped_refptr<base::SequencedTaskRunner> task_runner) + : task_runner_(std::move(task_runner)), + controller_factory_(base::BindRepeating( + [](scoped_refptr<base::SequencedTaskRunner> task_runner, + const LearningTask& task, + SequenceBoundFeatureProvider feature_provider) + -> base::SequenceBound<LearningTaskController> { + return base::SequenceBound<LearningTaskControllerImpl>( + task_runner, task, DistributionReporter::Create(task), std::move(feature_provider)); - })) {} + })), + weak_factory_(this) {} LearningSessionImpl::~LearningSessionImpl() = default; @@ -31,25 +91,25 @@ void LearningSessionImpl::SetTaskControllerFactoryCBForTesting( controller_factory_ = std::move(cb); } -void LearningSessionImpl::AddExample(const std::string& task_name, - const LabelledExample& example) { +std::unique_ptr<LearningTaskController> LearningSessionImpl::GetController( + const std::string& task_name) { auto iter = task_map_.find(task_name); - if (iter != task_map_.end()) { - // TODO(liberato): We shouldn't be adding examples. We should provide the - // LearningTaskController instead, although ownership gets a bit weird. - LearningTaskController::ObservationId id = 1; - iter->second->BeginObservation(id, example.features); - iter->second->CompleteObservation( - id, ObservationCompletion(example.target_value, example.weight)); - } + if (iter == task_map_.end()) + return nullptr; + + // If there were any way to replace / destroy a controller other than when we + // destroy |this|, then this wouldn't be such a good idea. + return std::make_unique<WeakLearningTaskController>( + weak_factory_.GetWeakPtr(), &iter->second); } void LearningSessionImpl::RegisterTask( const LearningTask& task, SequenceBoundFeatureProvider feature_provider) { DCHECK(task_map_.count(task.name) == 0); - task_map_.emplace(task.name, - controller_factory_.Run(task, std::move(feature_provider))); + task_map_.emplace( + task.name, + controller_factory_.Run(task_runner_, task, std::move(feature_provider))); } } // namespace learning diff --git a/chromium/media/learning/impl/learning_session_impl.h b/chromium/media/learning/impl/learning_session_impl.h index f4693dfe0a4..01402f5d163 100644 --- a/chromium/media/learning/impl/learning_session_impl.h +++ b/chromium/media/learning/impl/learning_session_impl.h @@ -8,6 +8,8 @@ #include <map> #include "base/component_export.h" +#include "base/memory/weak_ptr.h" +#include "base/sequenced_task_runner.h" #include "base/threading/sequence_bound.h" #include "media/learning/common/learning_session.h" #include "media/learning/common/learning_task_controller.h" @@ -21,19 +23,22 @@ namespace learning { class COMPONENT_EXPORT(LEARNING_IMPL) LearningSessionImpl : public LearningSession { public: - LearningSessionImpl(); + // We will create LearningTaskControllers that run on |task_runner|. + LearningSessionImpl(scoped_refptr<base::SequencedTaskRunner> task_runner); ~LearningSessionImpl() override; + // Create a SequenceBound controller for |task| on |task_runner|. using CreateTaskControllerCB = - base::RepeatingCallback<std::unique_ptr<LearningTaskController>( + base::RepeatingCallback<base::SequenceBound<LearningTaskController>( + scoped_refptr<base::SequencedTaskRunner>, const LearningTask&, SequenceBoundFeatureProvider)>; void SetTaskControllerFactoryCBForTesting(CreateTaskControllerCB cb); // LearningSession - void AddExample(const std::string& task_name, - const LabelledExample& example) override; + std::unique_ptr<LearningTaskController> GetController( + const std::string& task_name) override; // Registers |task|, so that calls to AddExample with |task.name| will work. // This will create a new controller for the task. @@ -42,12 +47,17 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningSessionImpl SequenceBoundFeatureProvider()); private: + // Task runner on which we'll create controllers. + scoped_refptr<base::SequencedTaskRunner> task_runner_; + // [task_name] = task controller. using LearningTaskMap = - std::map<std::string, std::unique_ptr<LearningTaskController>>; + std::map<std::string, base::SequenceBound<LearningTaskController>>; LearningTaskMap task_map_; CreateTaskControllerCB controller_factory_; + + base::WeakPtrFactory<LearningSessionImpl> weak_factory_; }; } // namespace learning diff --git a/chromium/media/learning/impl/learning_session_impl_unittest.cc b/chromium/media/learning/impl/learning_session_impl_unittest.cc index 5c55d66dc6a..2cb878a19d6 100644 --- a/chromium/media/learning/impl/learning_session_impl_unittest.cc +++ b/chromium/media/learning/impl/learning_session_impl_unittest.cc @@ -18,11 +18,19 @@ namespace learning { class LearningSessionImplTest : public testing::Test { public: + class FakeLearningTaskController; + using ControllerVector = std::vector<FakeLearningTaskController*>; + using TaskRunnerVector = std::vector<base::SequencedTaskRunner*>; + class FakeLearningTaskController : public LearningTaskController { public: - FakeLearningTaskController(const LearningTask& task, + // Send ControllerVector* as void*, else it complains that args can't be + // forwarded. Adding base::Unretained() doesn't help. + FakeLearningTaskController(void* controllers, + const LearningTask& task, SequenceBoundFeatureProvider feature_provider) : feature_provider_(std::move(feature_provider)) { + static_cast<ControllerVector*>(controllers)->push_back(this); // As a complete hack, call the only public method on fp so that // we can verify that it was given to us by the session. if (!feature_provider_.is_null()) { @@ -32,13 +40,13 @@ class LearningSessionImplTest : public testing::Test { } } - void BeginObservation(ObservationId id, + void BeginObservation(base::UnguessableToken id, const FeatureVector& features) override { id_ = id; features_ = features; } - void CompleteObservation(ObservationId id, + void CompleteObservation(base::UnguessableToken id, const ObservationCompletion& completion) override { EXPECT_EQ(id_, id); example_.features = std::move(features_); @@ -46,12 +54,17 @@ class LearningSessionImplTest : public testing::Test { example_.weight = completion.weight; } - void CancelObservation(ObservationId id) override { ASSERT_TRUE(false); } + void CancelObservation(base::UnguessableToken id) override { + cancelled_id_ = id; + } SequenceBoundFeatureProvider feature_provider_; - ObservationId id_ = 0; + base::UnguessableToken id_; FeatureVector features_; LabelledExample example_; + + // Most recently cancelled id. + base::UnguessableToken cancelled_id_; }; class FakeFeatureProvider : public FeatureProvider { @@ -67,20 +80,21 @@ class LearningSessionImplTest : public testing::Test { bool* flag_ptr_ = nullptr; }; - using ControllerVector = std::vector<FakeLearningTaskController*>; - LearningSessionImplTest() { - session_ = std::make_unique<LearningSessionImpl>(); + task_runner_ = base::SequencedTaskRunnerHandle::Get(); + session_ = std::make_unique<LearningSessionImpl>(task_runner_); session_->SetTaskControllerFactoryCBForTesting(base::BindRepeating( - [](ControllerVector* controllers, const LearningTask& task, + [](ControllerVector* controllers, TaskRunnerVector* task_runners, + scoped_refptr<base::SequencedTaskRunner> task_runner, + const LearningTask& task, SequenceBoundFeatureProvider feature_provider) - -> std::unique_ptr<LearningTaskController> { - auto controller = std::make_unique<FakeLearningTaskController>( - task, std::move(feature_provider)); - controllers->push_back(controller.get()); - return controller; + -> base::SequenceBound<LearningTaskController> { + task_runners->push_back(task_runner.get()); + return base::SequenceBound<FakeLearningTaskController>( + task_runner, static_cast<void*>(controllers), task, + std::move(feature_provider)); }, - &task_controllers_)); + &task_controllers_, &task_runners_)); task_0_.name = "task_0"; task_1_.name = "task_1"; @@ -88,48 +102,108 @@ class LearningSessionImplTest : public testing::Test { base::test::ScopedTaskEnvironment scoped_task_environment_; + scoped_refptr<base::SequencedTaskRunner> task_runner_; + std::unique_ptr<LearningSessionImpl> session_; LearningTask task_0_; LearningTask task_1_; ControllerVector task_controllers_; + TaskRunnerVector task_runners_; }; TEST_F(LearningSessionImplTest, RegisteringTasksCreatesControllers) { EXPECT_EQ(task_controllers_.size(), 0u); + EXPECT_EQ(task_runners_.size(), 0u); + session_->RegisterTask(task_0_); + scoped_task_environment_.RunUntilIdle(); EXPECT_EQ(task_controllers_.size(), 1u); + EXPECT_EQ(task_runners_.size(), 1u); + EXPECT_EQ(task_runners_[0], task_runner_.get()); + session_->RegisterTask(task_1_); + scoped_task_environment_.RunUntilIdle(); EXPECT_EQ(task_controllers_.size(), 2u); + EXPECT_EQ(task_runners_.size(), 2u); + EXPECT_EQ(task_runners_[1], task_runner_.get()); } TEST_F(LearningSessionImplTest, ExamplesAreForwardedToCorrectTask) { session_->RegisterTask(task_0_); session_->RegisterTask(task_1_); + base::UnguessableToken id = base::UnguessableToken::Create(); + LabelledExample example_0({FeatureValue(123), FeatureValue(456)}, TargetValue(1234)); - session_->AddExample(task_0_.name, example_0); + std::unique_ptr<LearningTaskController> ltc_0 = + session_->GetController(task_0_.name); + ltc_0->BeginObservation(id, example_0.features); + ltc_0->CompleteObservation( + id, ObservationCompletion(example_0.target_value, example_0.weight)); LabelledExample example_1({FeatureValue(321), FeatureValue(654)}, TargetValue(4321)); - session_->AddExample(task_1_.name, example_1); + + std::unique_ptr<LearningTaskController> ltc_1 = + session_->GetController(task_1_.name); + ltc_1->BeginObservation(id, example_1.features); + ltc_1->CompleteObservation( + id, ObservationCompletion(example_1.target_value, example_1.weight)); + + scoped_task_environment_.RunUntilIdle(); EXPECT_EQ(task_controllers_[0]->example_, example_0); EXPECT_EQ(task_controllers_[1]->example_, example_1); } +TEST_F(LearningSessionImplTest, ControllerLifetimeScopedToSession) { + session_->RegisterTask(task_0_); + + std::unique_ptr<LearningTaskController> controller = + session_->GetController(task_0_.name); + + // Destroy the session. |controller| should still be usable, though it won't + // forward requests anymore. + session_.reset(); + scoped_task_environment_.RunUntilIdle(); + + // Should not crash. + controller->BeginObservation(base::UnguessableToken::Create(), + FeatureVector()); +} + TEST_F(LearningSessionImplTest, FeatureProviderIsForwarded) { // Verify that a FeatureProvider actually gets forwarded to the LTC. bool flag = false; - session_->RegisterTask(task_0_, - base::SequenceBound<FakeFeatureProvider>( - base::SequencedTaskRunnerHandle::Get(), &flag)); + session_->RegisterTask( + task_0_, base::SequenceBound<FakeFeatureProvider>(task_runner_, &flag)); scoped_task_environment_.RunUntilIdle(); // Registering the task should create a FakeLearningTaskController, which will // call AddFeatures on the fake FeatureProvider. EXPECT_TRUE(flag); } +TEST_F(LearningSessionImplTest, DestroyingControllerCancelsObservations) { + session_->RegisterTask(task_0_); + + std::unique_ptr<LearningTaskController> controller = + session_->GetController(task_0_.name); + scoped_task_environment_.RunUntilIdle(); + + // Start an observation and verify that it starts. + base::UnguessableToken id = base::UnguessableToken::Create(); + controller->BeginObservation(id, FeatureVector()); + scoped_task_environment_.RunUntilIdle(); + EXPECT_EQ(task_controllers_[0]->id_, id); + EXPECT_NE(task_controllers_[0]->cancelled_id_, id); + + // Should result in cancelling the observation. + controller.reset(); + scoped_task_environment_.RunUntilIdle(); + EXPECT_EQ(task_controllers_[0]->cancelled_id_, id); +} + } // namespace learning } // namespace media diff --git a/chromium/media/learning/impl/learning_task_controller_helper.cc b/chromium/media/learning/impl/learning_task_controller_helper.cc index d4d57b43ebe..01c8972e67b 100644 --- a/chromium/media/learning/impl/learning_task_controller_helper.cc +++ b/chromium/media/learning/impl/learning_task_controller_helper.cc @@ -24,7 +24,7 @@ LearningTaskControllerHelper::LearningTaskControllerHelper( LearningTaskControllerHelper::~LearningTaskControllerHelper() = default; -void LearningTaskControllerHelper::BeginObservation(ObservationId id, +void LearningTaskControllerHelper::BeginObservation(base::UnguessableToken id, FeatureVector features) { auto& pending_example = pending_examples_[id]; @@ -41,10 +41,11 @@ void LearningTaskControllerHelper::BeginObservation(ObservationId id, } void LearningTaskControllerHelper::CompleteObservation( - ObservationId id, + base::UnguessableToken id, const ObservationCompletion& completion) { auto iter = pending_examples_.find(id); - DCHECK(iter != pending_examples_.end()); + if (iter == pending_examples_.end()) + return; iter->second.example.target_value = completion.target_value; iter->second.example.weight = completion.weight; @@ -52,10 +53,11 @@ void LearningTaskControllerHelper::CompleteObservation( ProcessExampleIfFinished(std::move(iter)); } -void LearningTaskControllerHelper::CancelObservation(ObservationId id) { +void LearningTaskControllerHelper::CancelObservation( + base::UnguessableToken id) { auto iter = pending_examples_.find(id); - // If the example has already been completed, then we shouldn't be called. - DCHECK(iter != pending_examples_.end()); + if (iter == pending_examples_.end()) + return; // This would have to check for pending predictions, if we supported them, and // defer destruction until the features arrive. @@ -66,7 +68,7 @@ void LearningTaskControllerHelper::CancelObservation(ObservationId id) { void LearningTaskControllerHelper::OnFeaturesReadyTrampoline( scoped_refptr<base::SequencedTaskRunner> task_runner, base::WeakPtr<LearningTaskControllerHelper> weak_this, - ObservationId id, + base::UnguessableToken id, FeatureVector features) { // TODO(liberato): this would benefit from promises / deferred data. auto cb = base::BindOnce(&LearningTaskControllerHelper::OnFeaturesReady, @@ -78,7 +80,7 @@ void LearningTaskControllerHelper::OnFeaturesReadyTrampoline( } } -void LearningTaskControllerHelper::OnFeaturesReady(ObservationId id, +void LearningTaskControllerHelper::OnFeaturesReady(base::UnguessableToken id, FeatureVector features) { PendingExampleMap::iterator iter = pending_examples_.find(id); // It's possible that OnLabelCallbackDestroyed has already run. That's okay diff --git a/chromium/media/learning/impl/learning_task_controller_helper.h b/chromium/media/learning/impl/learning_task_controller_helper.h index 9370536d8ce..7e9bb6f5a7f 100644 --- a/chromium/media/learning/impl/learning_task_controller_helper.h +++ b/chromium/media/learning/impl/learning_task_controller_helper.h @@ -37,9 +37,6 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerHelper // Callback to add labelled examples as training data. using AddExampleCB = base::RepeatingCallback<void(LabelledExample)>; - // Convenience. - using ObservationId = LearningTaskController::ObservationId; - // TODO(liberato): Consider making the FP not optional. LearningTaskControllerHelper(const LearningTask& task, AddExampleCB add_example_cb, @@ -48,10 +45,10 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerHelper virtual ~LearningTaskControllerHelper(); // See LearningTaskController::BeginObservation. - void BeginObservation(ObservationId id, FeatureVector features); - void CompleteObservation(ObservationId id, + void BeginObservation(base::UnguessableToken id, FeatureVector features); + void CompleteObservation(base::UnguessableToken id, const ObservationCompletion& completion); - void CancelObservation(ObservationId id); + void CancelObservation(base::UnguessableToken id); private: // Record of an example that has been started by RecordObservedFeatures, but @@ -67,19 +64,20 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerHelper }; // [non-repeating int] = example - using PendingExampleMap = std::map<ObservationId, PendingExample>; + using PendingExampleMap = std::map<base::UnguessableToken, PendingExample>; // Called on any sequence when features are ready. Will call OnFeatureReady // if called on |task_runner|, or will post to |task_runner|. static void OnFeaturesReadyTrampoline( scoped_refptr<base::SequencedTaskRunner> task_runner, base::WeakPtr<LearningTaskControllerHelper> weak_this, - ObservationId id, + base::UnguessableToken id, FeatureVector features); // Called when a new feature vector has been finished by |feature_provider_|, // if needed, to actually add the example. - void OnFeaturesReady(ObservationId example_id, FeatureVector features); + void OnFeaturesReady(base::UnguessableToken example_id, + FeatureVector features); // If |example| is finished, then send it to the LearningSession and remove it // from the map. Otherwise, do nothing. diff --git a/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc b/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc index 806c43c05cc..de756f0cf48 100644 --- a/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc +++ b/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc @@ -43,6 +43,8 @@ class LearningTaskControllerHelperTest : public testing::Test { example_.features.push_back(FeatureValue(3)); example_.target_value = TargetValue(123); example_.weight = 100u; + + id_ = base::UnguessableToken::Create(); } void CreateClient(bool include_fp) { @@ -87,7 +89,7 @@ class LearningTaskControllerHelperTest : public testing::Test { LearningTask task_; - LearningTaskController::ObservationId id_ = 1; + base::UnguessableToken id_; LabelledExample example_; }; diff --git a/chromium/media/learning/impl/learning_task_controller_impl.cc b/chromium/media/learning/impl/learning_task_controller_impl.cc index 84092c939da..544218aaef6 100644 --- a/chromium/media/learning/impl/learning_task_controller_impl.cc +++ b/chromium/media/learning/impl/learning_task_controller_impl.cc @@ -6,8 +6,10 @@ #include <memory> #include <utility> +#include <vector> #include "base/bind.h" +#include "media/learning/impl/distribution_reporter.h" #include "media/learning/impl/extra_trees_trainer.h" #include "media/learning/impl/lookup_table_trainer.h" @@ -25,7 +27,14 @@ LearningTaskControllerImpl::LearningTaskControllerImpl( task, base::BindRepeating(&LearningTaskControllerImpl::AddFinishedExample, AsWeakPtr()), - std::move(feature_provider))) { + std::move(feature_provider))), + expected_feature_count_(task_.feature_descriptions.size()) { + // Note that |helper_| uses the full set of features. + + // TODO(liberato): Make this compositional. FeatureSubsetTaskController? + if (task_.feature_subset_size) + DoFeatureSubsetSelection(); + switch (task_.model) { case LearningTask::Model::kExtraTrees: trainer_ = std::make_unique<ExtraTreesTrainer>(); @@ -39,22 +48,51 @@ LearningTaskControllerImpl::LearningTaskControllerImpl( LearningTaskControllerImpl::~LearningTaskControllerImpl() = default; void LearningTaskControllerImpl::BeginObservation( - ObservationId id, + base::UnguessableToken id, const FeatureVector& features) { + // TODO(liberato): Should we enforce that the right number of features are + // present here? Right now, we allow it to be shorter, so that features from + // a FeatureProvider may be omitted. Of course, they have to be at the end in + // that case. If we start enforcing it here, make sure that LearningHelper + // starts adding the placeholder features. + if (!trainer_) + return; + helper_->BeginObservation(id, features); } void LearningTaskControllerImpl::CompleteObservation( - ObservationId id, + base::UnguessableToken id, const ObservationCompletion& completion) { + if (!trainer_) + return; helper_->CompleteObservation(id, completion); } -void LearningTaskControllerImpl::CancelObservation(ObservationId id) { +void LearningTaskControllerImpl::CancelObservation(base::UnguessableToken id) { + if (!trainer_) + return; helper_->CancelObservation(id); } void LearningTaskControllerImpl::AddFinishedExample(LabelledExample example) { + // Verify that we have a trainer and that we got the right number of features. + // We don't compare to |task_.feature_descriptions.size()| since that has been + // adjusted to the subset size already. We expect the original count. + if (!trainer_ || example.features.size() != expected_feature_count_) + return; + + // Now that we have the whole set of features, select the subset we want. + FeatureVector new_features; + if (task_.feature_subset_size) { + for (auto& iter : feature_indices_) + new_features.push_back(example.features[iter]); + example.features = std::move(new_features); + } // else use them all. + + // The features should now match the task. + DCHECK_EQ(example.features.size(), task_.feature_descriptions.size()); + if (training_data_->size() >= task_.max_data_set_size) { // Replace a random example. We don't necessarily want to replace the // oldest, since we don't necessarily want to enforce an ad-hoc recency @@ -68,12 +106,13 @@ void LearningTaskControllerImpl::AddFinishedExample(LabelledExample example) { // Once we have a model, see if we'd get |example| correct. if (model_ && reporter_) { - TargetDistribution predicted = - model_->PredictDistribution(example.features); + TargetHistogram predicted = model_->PredictDistribution(example.features); - TargetDistribution observed; - observed += example.target_value; - reporter_->GetPredictionCallback(observed).Run(predicted); + DistributionReporter::PredictionInfo info; + info.observed = example.target_value; + info.total_training_weight = last_training_weight_; + info.total_training_examples = last_training_size_; + reporter_->GetPredictionCallback(info).Run(predicted); } // Can't train more than one model concurrently. @@ -89,7 +128,8 @@ void LearningTaskControllerImpl::AddFinishedExample(LabelledExample example) { num_untrained_examples_ = 0; TrainedModelCB model_cb = - base::BindOnce(&LearningTaskControllerImpl::OnModelTrained, AsWeakPtr()); + base::BindOnce(&LearningTaskControllerImpl::OnModelTrained, AsWeakPtr(), + training_data_->total_weight(), training_data_->size()); training_is_in_progress_ = true; // Note that this copies the training data, so it's okay if we add more // examples to our copy before this returns. @@ -99,10 +139,15 @@ void LearningTaskControllerImpl::AddFinishedExample(LabelledExample example) { trainer_->Train(task_, *training_data_, std::move(model_cb)); } -void LearningTaskControllerImpl::OnModelTrained(std::unique_ptr<Model> model) { +void LearningTaskControllerImpl::OnModelTrained(double training_weight, + int training_size, + std::unique_ptr<Model> model) { DCHECK(training_is_in_progress_); training_is_in_progress_ = false; model_ = std::move(model); + // Record these for metrics. + last_training_weight_ = training_weight; + last_training_size_ = training_size; } void LearningTaskControllerImpl::SetTrainerForTesting( @@ -110,5 +155,39 @@ void LearningTaskControllerImpl::SetTrainerForTesting( trainer_ = std::move(trainer); } +void LearningTaskControllerImpl::DoFeatureSubsetSelection() { + // Choose a random feature, and trim the descriptions to match. + std::vector<size_t> features; + for (size_t i = 0; i < task_.feature_descriptions.size(); i++) + features.push_back(i); + + for (int i = 0; i < *task_.feature_subset_size; i++) { + // Pick an element from |i| to the end of the list, inclusive. + // TODO(liberato): For tests, this will happen before any rng is provided + // by the test; we'll use an actual rng. + int r = rng()->Generate(features.size() - i) + i; + // Swap them. + std::swap(features[i], features[r]); + } + + // Construct the feature subset from the first few elements. Also adjust the + // task's descriptions to match. We do this in two steps so that the + // descriptions are added via iterating over |feature_indices_|, so that the + // enumeration order is the same as when we adjust the feature values of + // incoming examples. In both cases, we iterate over |feature_indicies_|, + // which might (will) re-order them with respect to |features|. + for (int i = 0; i < *task_.feature_subset_size; i++) + feature_indices_.insert(features[i]); + + std::vector<LearningTask::ValueDescription> adjusted_descriptions; + for (auto& iter : feature_indices_) + adjusted_descriptions.push_back(task_.feature_descriptions[iter]); + + task_.feature_descriptions = adjusted_descriptions; + + if (reporter_) + reporter_->SetFeatureSubset(feature_indices_); +} + } // namespace learning } // namespace media diff --git a/chromium/media/learning/impl/learning_task_controller_impl.h b/chromium/media/learning/impl/learning_task_controller_impl.h index 2866da17f81..eae031be307 100644 --- a/chromium/media/learning/impl/learning_task_controller_impl.h +++ b/chromium/media/learning/impl/learning_task_controller_impl.h @@ -6,6 +6,7 @@ #define MEDIA_LEARNING_IMPL_LEARNING_TASK_CONTROLLER_IMPL_H_ #include <memory> +#include <set> #include "base/callback.h" #include "base/component_export.h" @@ -20,6 +21,7 @@ namespace media { namespace learning { +class DistributionReporter; class LearningTaskControllerImplTest; // Controller for a single learning task. Takes training examples, and forwards @@ -44,21 +46,27 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl ~LearningTaskControllerImpl() override; // LearningTaskController - void BeginObservation(ObservationId id, + void BeginObservation(base::UnguessableToken id, const FeatureVector& features) override; - void CompleteObservation(ObservationId id, + void CompleteObservation(base::UnguessableToken id, const ObservationCompletion& completion) override; - void CancelObservation(ObservationId id) override; + void CancelObservation(base::UnguessableToken id) override; private: // Add |example| to the training data, and process it. void AddFinishedExample(LabelledExample example); - // Called by |training_cb_| when the model is trained. - void OnModelTrained(std::unique_ptr<Model> model); + // Called by |training_cb_| when the model is trained. |training_weight| and + // |training_size| are the training set's total weight and number of examples. + void OnModelTrained(double training_weight, + int training_size, + std::unique_ptr<Model> model); void SetTrainerForTesting(std::unique_ptr<TrainingAlgorithm> trainer); + // Update |task_| to reflect a randomly chosen subset of features. + void DoFeatureSubsetSelection(); + LearningTask task_; // Current batch of examples. @@ -74,6 +82,10 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl // This helps us decide when to train a new model. int num_untrained_examples_ = 0; + // Total weight and number of examples in the most recently trained model. + double last_training_weight_ = 0.; + size_t last_training_size_ = 0u; + // Training algorithm that we'll use. std::unique_ptr<TrainingAlgorithm> trainer_; @@ -83,6 +95,13 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl // Helper that we use to handle deferred examples. std::unique_ptr<LearningTaskControllerHelper> helper_; + // If the task specifies feature importance measurement, then this is the + // randomly chosen subset of features. + std::set<int> feature_indices_; + + // Number of features that we expect in each observation. + size_t expected_feature_count_; + friend class LearningTaskControllerImplTest; }; diff --git a/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc b/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc index 7433919a62c..e1620c0858d 100644 --- a/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc +++ b/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc @@ -22,11 +22,18 @@ class LearningTaskControllerImplTest : public testing::Test { FakeDistributionReporter(const LearningTask& task) : DistributionReporter(task) {} + // protected => public + const base::Optional<std::set<int>>& feature_indices() const { + return DistributionReporter::feature_indices(); + } + protected: - void OnPrediction(TargetDistribution observed, - TargetDistribution predicted) override { + void OnPrediction(const PredictionInfo& info, + TargetHistogram predicted) override { num_reported_++; - if (observed == predicted) + TargetHistogram dist; + dist += info.observed; + if (dist == predicted) num_correct_++; } @@ -41,9 +48,9 @@ class LearningTaskControllerImplTest : public testing::Test { FakeModel(TargetValue target) : target_(target) {} // Model - TargetDistribution PredictDistribution( + TargetHistogram PredictDistribution( const FeatureVector& features) override { - TargetDistribution dist; + TargetHistogram dist; dist += target_; return dist; } @@ -64,14 +71,18 @@ class LearningTaskControllerImplTest : public testing::Test { void Train(const LearningTask& task, const TrainingData& training_data, TrainedModelCB model_cb) override { + task_ = task; (*num_models_)++; training_data_ = training_data; std::move(model_cb).Run(std::make_unique<FakeModel>(target_value_)); } + const LearningTask& task() const { return task_; } + const TrainingData& training_data() const { return training_data_; } private: + LearningTask task_; int* num_models_ = nullptr; TargetValue target_value_; @@ -113,7 +124,7 @@ class LearningTaskControllerImplTest : public testing::Test { } void AddExample(const LabelledExample& example) { - LearningTaskController::ObservationId id = 1; + base::UnguessableToken id = base::UnguessableToken::Create(); controller_->BeginObservation(id, example.features); controller_->CompleteObservation( id, ObservationCompletion(example.target_value, example.weight)); @@ -123,7 +134,6 @@ class LearningTaskControllerImplTest : public testing::Test { // Number of models that we trained. int num_models_ = 0; - FakeModel* last_model_ = nullptr; // Two distinct targets. const TargetValue predicted_target_; @@ -182,6 +192,7 @@ TEST_F(LearningTaskControllerImplTest, AddingExamplesTrainsModelAndReports) { TEST_F(LearningTaskControllerImplTest, FeatureProviderIsUsed) { // If a FeatureProvider factory is provided, make sure that it's used to // adjust new examples. + task_.feature_descriptions.push_back({"AddedByFeatureProvider"}); SequenceBoundFeatureProvider feature_provider = base::SequenceBound<FakeFeatureProvider>( base::SequencedTaskRunnerHandle::Get()); @@ -195,5 +206,50 @@ TEST_F(LearningTaskControllerImplTest, FeatureProviderIsUsed) { EXPECT_EQ(trainer_raw_->training_data()[0].weight, example.weight); } +TEST_F(LearningTaskControllerImplTest, FeatureSubsetsWork) { + const char* feature_names[] = { + "feature0", "feature1", "feature2", "feature3", "feature4", "feature5", + "feature6", "feature7", "feature8", "feature9", "feature10", "feature11", + }; + const int num_features = sizeof(feature_names) / sizeof(feature_names[0]); + for (int i = 0; i < num_features; i++) + task_.feature_descriptions.push_back({feature_names[i]}); + const size_t subset_size = 4; + task_.feature_subset_size = subset_size; + CreateController(); + + // Verify that the reporter is given a subset of the features. + auto subset = *reporter_raw_->feature_indices(); + EXPECT_EQ(subset.size(), subset_size); + + // Train a model. Each feature will have a unique value. + LabelledExample example; + for (int i = 0; i < num_features; i++) + example.features.push_back(FeatureValue(i)); + AddExample(example); + + // Verify that all feature names in |subset| are present in the task. + FeatureVector expected_features; + expected_features.resize(subset_size); + EXPECT_EQ(trainer_raw_->task().feature_descriptions.size(), subset_size); + for (auto& iter : subset) { + bool found = false; + for (size_t i = 0; i < subset_size; i++) { + if (trainer_raw_->task().feature_descriptions[i].name == + feature_names[iter]) { + // Also build a vector with the features in the expected order. + expected_features[i] = example.features[iter]; + found = true; + break; + } + } + EXPECT_TRUE(found); + } + + // Verify that the training data has the adjusted features. + EXPECT_EQ(trainer_raw_->training_data().size(), 1u); + EXPECT_EQ(trainer_raw_->training_data()[0].features, expected_features); +} + } // namespace learning } // namespace media diff --git a/chromium/media/learning/impl/lookup_table_trainer.cc b/chromium/media/learning/impl/lookup_table_trainer.cc index 4c698c71022..57ebdbbc0f6 100644 --- a/chromium/media/learning/impl/lookup_table_trainer.cc +++ b/chromium/media/learning/impl/lookup_table_trainer.cc @@ -19,17 +19,16 @@ class LookupTable : public Model { } // Model - TargetDistribution PredictDistribution( - const FeatureVector& instance) override { + TargetHistogram PredictDistribution(const FeatureVector& instance) override { auto iter = buckets_.find(instance); if (iter == buckets_.end()) - return TargetDistribution(); + return TargetHistogram(); return iter->second; } private: - std::map<FeatureVector, TargetDistribution> buckets_; + std::map<FeatureVector, TargetHistogram> buckets_; }; LookupTableTrainer::LookupTableTrainer() = default; diff --git a/chromium/media/learning/impl/lookup_table_trainer_unittest.cc b/chromium/media/learning/impl/lookup_table_trainer_unittest.cc index 323d69d471e..47618746617 100644 --- a/chromium/media/learning/impl/lookup_table_trainer_unittest.cc +++ b/chromium/media/learning/impl/lookup_table_trainer_unittest.cc @@ -37,7 +37,7 @@ TEST_F(LookupTableTrainerTest, EmptyTrainingDataWorks) { TrainingData empty; std::unique_ptr<Model> model = Train(task_, empty); EXPECT_NE(model.get(), nullptr); - EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetDistribution()); + EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetHistogram()); } TEST_F(LookupTableTrainerTest, UniformTrainingDataWorks) { @@ -51,8 +51,7 @@ TEST_F(LookupTableTrainerTest, UniformTrainingDataWorks) { // The tree should produce a distribution for one value (our target), which // has |n_examples| counts. - TargetDistribution distribution = - model->PredictDistribution(example.features); + TargetHistogram distribution = model->PredictDistribution(example.features); EXPECT_EQ(distribution.size(), 1u); EXPECT_EQ(distribution[example.target_value], n_examples); } @@ -66,8 +65,7 @@ TEST_F(LookupTableTrainerTest, SimpleSeparableTrainingData) { std::unique_ptr<Model> model = Train(task_, training_data); // Each value should have a distribution with one target value with one count. - TargetDistribution distribution = - model->PredictDistribution(example_1.features); + TargetHistogram distribution = model->PredictDistribution(example_1.features); EXPECT_NE(model.get(), nullptr); EXPECT_EQ(distribution.size(), 1u); EXPECT_EQ(distribution[example_1.target_value], 1u); @@ -104,8 +102,7 @@ TEST_F(LookupTableTrainerTest, ComplexSeparableTrainingData) { // Each example should have a distribution that selects the right value. for (const auto& example : training_data) { - TargetDistribution distribution = - model->PredictDistribution(example.features); + TargetHistogram distribution = model->PredictDistribution(example.features); TargetValue singular_max; EXPECT_TRUE(distribution.FindSingularMax(&singular_max)); EXPECT_EQ(singular_max, example.target_value); @@ -122,8 +119,7 @@ TEST_F(LookupTableTrainerTest, UnseparableTrainingData) { EXPECT_NE(model.get(), nullptr); // Each value should have a distribution with two targets with one count each. - TargetDistribution distribution = - model->PredictDistribution(example_1.features); + TargetHistogram distribution = model->PredictDistribution(example_1.features); EXPECT_EQ(distribution.size(), 2u); EXPECT_EQ(distribution[example_1.target_value], 1u); EXPECT_EQ(distribution[example_2.target_value], 1u); @@ -143,7 +139,7 @@ TEST_F(LookupTableTrainerTest, UnknownFeatureValueHandling) { training_data.push_back(example_2); std::unique_ptr<Model> model = Train(task_, training_data); - TargetDistribution distribution = + TargetHistogram distribution = model->PredictDistribution(FeatureVector({FeatureValue(789)})); // OOV data should return an empty distribution (nominal). EXPECT_EQ(distribution.size(), 0u); @@ -160,7 +156,7 @@ TEST_F(LookupTableTrainerTest, RegressionWithWeightedExamplesWorks) { training_data.push_back(example_2); std::unique_ptr<Model> model = Train(task_, training_data); - TargetDistribution distribution = + TargetHistogram distribution = model->PredictDistribution(FeatureVector({FeatureValue(123)})); double avg = distribution.Average(); const double expected = diff --git a/chromium/media/learning/impl/model.h b/chromium/media/learning/impl/model.h index 0950b6c9713..42366875b65 100644 --- a/chromium/media/learning/impl/model.h +++ b/chromium/media/learning/impl/model.h @@ -8,7 +8,7 @@ #include "base/component_export.h" #include "media/learning/common/labelled_example.h" #include "media/learning/impl/model.h" -#include "media/learning/impl/target_distribution.h" +#include "media/learning/impl/target_histogram.h" namespace media { namespace learning { @@ -19,11 +19,11 @@ namespace learning { class COMPONENT_EXPORT(LEARNING_IMPL) Model { public: // Callback for asynchronous predictions. - using PredictionCB = base::OnceCallback<void(TargetDistribution predicted)>; + using PredictionCB = base::OnceCallback<void(TargetHistogram predicted)>; virtual ~Model() = default; - virtual TargetDistribution PredictDistribution( + virtual TargetHistogram PredictDistribution( const FeatureVector& instance) = 0; // TODO(liberato): Consider adding an async prediction helper. diff --git a/chromium/media/learning/impl/one_hot.cc b/chromium/media/learning/impl/one_hot.cc index b8dab81e142..c3e82855eda 100644 --- a/chromium/media/learning/impl/one_hot.cc +++ b/chromium/media/learning/impl/one_hot.cc @@ -111,7 +111,7 @@ ConvertingModel::ConvertingModel(std::unique_ptr<OneHotConverter> converter, : converter_(std::move(converter)), model_(std::move(model)) {} ConvertingModel::~ConvertingModel() = default; -TargetDistribution ConvertingModel::PredictDistribution( +TargetHistogram ConvertingModel::PredictDistribution( const FeatureVector& instance) { FeatureVector converted_instance = converter_->Convert(instance); return model_->PredictDistribution(converted_instance); diff --git a/chromium/media/learning/impl/one_hot.h b/chromium/media/learning/impl/one_hot.h index 0a3f479b721..b48b66ec924 100644 --- a/chromium/media/learning/impl/one_hot.h +++ b/chromium/media/learning/impl/one_hot.h @@ -67,8 +67,7 @@ class COMPONENT_EXPORT(LEARNING_IMPL) ConvertingModel : public Model { ~ConvertingModel() override; // Model - TargetDistribution PredictDistribution( - const FeatureVector& instance) override; + TargetHistogram PredictDistribution(const FeatureVector& instance) override; private: std::unique_ptr<OneHotConverter> converter_; diff --git a/chromium/media/learning/impl/random_tree_trainer.cc b/chromium/media/learning/impl/random_tree_trainer.cc index 88e1b97dd19..c45b757716e 100644 --- a/chromium/media/learning/impl/random_tree_trainer.cc +++ b/chromium/media/learning/impl/random_tree_trainer.cc @@ -40,8 +40,7 @@ struct InteriorNode : public Model { split_point_(split_point) {} // Model - TargetDistribution PredictDistribution( - const FeatureVector& features) override { + TargetHistogram PredictDistribution(const FeatureVector& features) override { // Figure out what feature value we should use for the split. FeatureValue f; switch (ordering_) { @@ -59,16 +58,16 @@ struct InteriorNode : public Model { // If we've never seen this feature value, then return nothing. if (iter == children_.end()) - return TargetDistribution(); + return TargetHistogram(); return iter->second->PredictDistribution(features); } - TargetDistribution PredictDistributionWithMissingValues( + TargetHistogram PredictDistributionWithMissingValues( const FeatureVector& features) { - TargetDistribution total; + TargetHistogram total; for (auto& child_pair : children_) { - TargetDistribution predicted = + TargetHistogram predicted = child_pair.second->PredictDistribution(features); // TODO(liberato): Normalize? Weight? total += predicted; @@ -102,19 +101,26 @@ struct LeafNode : public Model { for (size_t idx : training_idx) distribution_ += training_data[idx]; - // Note that we don't treat numeric targets any differently. We want to - // weight the leaf by the number of examples, so replacing it with an - // average would just introduce rounding errors. One might as well take the - // average of the final distribution. + // Each leaf gets one vote. + // See https://en.wikipedia.org/wiki/Bootstrap_aggregating . TL;DR: the + // individual trees should average (regression) or vote (classification). + // + // TODO(liberato): It's unclear that a leaf should get to vote with an + // entire distribution; we might want to take the max for kUnordered here. + // If so, then we might also want to Average() for kNumeric targets, though + // in that case, the results would be the same anyway. That's not, of + // course, guaranteed for all methods of converting |distribution_| into a + // numeric prediction. In general, we should provide a single estimate. + distribution_.Normalize(); } // TreeNode - TargetDistribution PredictDistribution(const FeatureVector&) override { + TargetHistogram PredictDistribution(const FeatureVector&) override { return distribution_; } private: - TargetDistribution distribution_; + TargetHistogram distribution_; }; RandomTreeTrainer::RandomTreeTrainer(RandomNumberGenerator* rng) @@ -297,7 +303,8 @@ RandomTreeTrainer::Split RandomTreeTrainer::ConstructSplit( // Find the split's feature values and construct the training set for each. // I think we want to iterate on the underlying vector, and look up the int in - // the training data directly. + // the training data directly. |total_weight| will hold the total weight of + // all examples that come into this node. double total_weight = 0.; for (size_t idx : training_idx) { const LabelledExample& example = training_data[idx]; @@ -324,7 +331,7 @@ RandomTreeTrainer::Split RandomTreeTrainer::ConstructSplit( Split::BranchInfo& branch_info = iter->second; branch_info.training_idx.push_back(idx); - branch_info.target_distribution += example; + branch_info.target_histogram += example; } // Figure out how good / bad this split is. @@ -340,18 +347,22 @@ RandomTreeTrainer::Split RandomTreeTrainer::ConstructSplit( return split; } -void RandomTreeTrainer::ComputeSplitScore_Nominal(Split* split, - double total_weight) { +void RandomTreeTrainer::ComputeSplitScore_Nominal( + Split* split, + double total_incoming_weight) { // Compute the nats given that we're at this node. split->nats_remaining = 0; for (auto& info_iter : split->branch_infos) { Split::BranchInfo& branch_info = info_iter.second; - const double total_counts = branch_info.target_distribution.total_counts(); + // |weight_along_branch| is the total weight of examples that would follow + // this branch in the tree. + const double weight_along_branch = + branch_info.target_histogram.total_counts(); // |p_branch| is the probability of following this branch. - const double p_branch = total_counts / total_weight; - for (auto& iter : branch_info.target_distribution) { - double p = iter.second / total_counts; + const double p_branch = weight_along_branch / total_incoming_weight; + for (auto& iter : branch_info.target_histogram) { + double p = iter.second / total_incoming_weight; // p*log(p) is the expected nats if the answer is |iter|. We multiply // that by the probability of being in this bucket at all. split->nats_remaining -= (p * log(p)) * p_branch; @@ -359,25 +370,29 @@ void RandomTreeTrainer::ComputeSplitScore_Nominal(Split* split, } } -void RandomTreeTrainer::ComputeSplitScore_Numeric(Split* split, - double total_weight) { +void RandomTreeTrainer::ComputeSplitScore_Numeric( + Split* split, + double total_incoming_weight) { // Compute the nats given that we're at this node. split->nats_remaining = 0; for (auto& info_iter : split->branch_infos) { Split::BranchInfo& branch_info = info_iter.second; - const double total_counts = branch_info.target_distribution.total_counts(); + // |weight_along_branch| is the total weight of examples that would follow + // this branch in the tree. + const double weight_along_branch = + branch_info.target_histogram.total_counts(); // |p_branch| is the probability of following this branch. - const double p_branch = total_counts / total_weight; + const double p_branch = weight_along_branch / total_incoming_weight; // Compute the average at this node. Note that we have no idea if the leaf // node would actually use an average, but really it should match. It would - // be really nice if we could compute the value (or TargetDistribution) as + // be really nice if we could compute the value (or TargetHistogram) as // part of computing the split, and have somebody just hand that target // distribution to the leaf if it ends up as one. - double average = branch_info.target_distribution.Average(); + double average = branch_info.target_histogram.Average(); - for (auto& iter : branch_info.target_distribution) { + for (auto& iter : branch_info.target_histogram) { // Compute the squared error for all |iter.second| counts that each have a // value of |iter.first|, when this leaf approximates them as |average|. double sq_err = (iter.first.value() - average) * diff --git a/chromium/media/learning/impl/random_tree_trainer.h b/chromium/media/learning/impl/random_tree_trainer.h index 3383e55fa11..816469832ea 100644 --- a/chromium/media/learning/impl/random_tree_trainer.h +++ b/chromium/media/learning/impl/random_tree_trainer.h @@ -20,7 +20,7 @@ namespace media { namespace learning { -// Trains RandomTree decision tree classifier (doesn't handle regression). +// Trains RandomTree decision tree classifier / regressor. // // Decision trees, including RandomTree, classify instances as follows. Each // non-leaf node is marked with a feature number |i|. The value of the |i|-th @@ -71,9 +71,9 @@ namespace learning { // See https://en.wikipedia.org/wiki/Random_forest for information. Note that // this is just a single tree, not the whole forest. // -// TODO(liberato): Right now, it not-so-randomly selects from the entire set. -// TODO(liberato): consider PRF or other simplified approximations. -// TODO(liberato): separate Model and TrainingAlgorithm. This is the latter. +// Note that this variant chooses split points randomly, as described by the +// ExtraTrees algorithm. This is slightly different than RandomForest, which +// chooses split points to improve the split's score. class COMPONENT_EXPORT(LEARNING_IMPL) RandomTreeTrainer : public TrainingAlgorithm, public HasRandomNumberGenerator { @@ -135,7 +135,7 @@ class COMPONENT_EXPORT(LEARNING_IMPL) RandomTreeTrainer // branch of the split. // This is a flat_map since we're likely to have a very small (e.g., // "true / "false") number of targets. - TargetDistribution target_distribution; + TargetHistogram target_histogram; }; // [feature value at this split] = info about which examples take this @@ -158,12 +158,13 @@ class COMPONENT_EXPORT(LEARNING_IMPL) RandomTreeTrainer const std::vector<size_t>& training_idx, int index); - // Fill in |nats_remaining| for |split| for a nominal target. |total_weight| - // is the total weight of all instances coming into this split. - void ComputeSplitScore_Nominal(Split* split, double total_weight); + // Fill in |nats_remaining| for |split| for a nominal target. + // |total_incoming_weight| is the total weight of all instances coming into + // the node that we're splitting. + void ComputeSplitScore_Nominal(Split* split, double total_incoming_weight); // Fill in |nats_remaining| for |split| for a numeric target. - void ComputeSplitScore_Numeric(Split* split, double total_weight); + void ComputeSplitScore_Numeric(Split* split, double total_incoming_weight); // Compute the split point for |training_data| for a nominal feature. FeatureValue FindSplitPoint_Nominal(size_t index, diff --git a/chromium/media/learning/impl/random_tree_trainer_unittest.cc b/chromium/media/learning/impl/random_tree_trainer_unittest.cc index 6af9e94b49c..f9face03115 100644 --- a/chromium/media/learning/impl/random_tree_trainer_unittest.cc +++ b/chromium/media/learning/impl/random_tree_trainer_unittest.cc @@ -55,7 +55,7 @@ TEST_P(RandomTreeTest, EmptyTrainingDataWorks) { TrainingData empty; std::unique_ptr<Model> model = Train(task_, empty); EXPECT_NE(model.get(), nullptr); - EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetDistribution()); + EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetHistogram()); } TEST_P(RandomTreeTest, UniformTrainingDataWorks) { @@ -69,11 +69,10 @@ TEST_P(RandomTreeTest, UniformTrainingDataWorks) { std::unique_ptr<Model> model = Train(task_, training_data); // The tree should produce a distribution for one value (our target), which - // has |n_examples| counts. - TargetDistribution distribution = - model->PredictDistribution(example.features); + // has one count. + TargetHistogram distribution = model->PredictDistribution(example.features); EXPECT_EQ(distribution.size(), 1u); - EXPECT_EQ(distribution[example.target_value], n_examples); + EXPECT_EQ(distribution[example.target_value], 1.0); } TEST_P(RandomTreeTest, SimpleSeparableTrainingData) { @@ -86,8 +85,7 @@ TEST_P(RandomTreeTest, SimpleSeparableTrainingData) { std::unique_ptr<Model> model = Train(task_, training_data); // Each value should have a distribution with one target value with one count. - TargetDistribution distribution = - model->PredictDistribution(example_1.features); + TargetHistogram distribution = model->PredictDistribution(example_1.features); EXPECT_NE(model.get(), nullptr); EXPECT_EQ(distribution.size(), 1u); EXPECT_EQ(distribution[example_1.target_value], 1u); @@ -129,8 +127,7 @@ TEST_P(RandomTreeTest, ComplexSeparableTrainingData) { // Each example should have a distribution that selects the right value. for (const LabelledExample& example : training_data) { - TargetDistribution distribution = - model->PredictDistribution(example.features); + TargetHistogram distribution = model->PredictDistribution(example.features); TargetValue singular_max; EXPECT_TRUE(distribution.FindSingularMax(&singular_max)); EXPECT_EQ(singular_max, example.target_value); @@ -147,17 +144,16 @@ TEST_P(RandomTreeTest, UnseparableTrainingData) { std::unique_ptr<Model> model = Train(task_, training_data); EXPECT_NE(model.get(), nullptr); - // Each value should have a distribution with two targets with one count each. - TargetDistribution distribution = - model->PredictDistribution(example_1.features); + // Each value should have a distribution with two targets with equal counts. + TargetHistogram distribution = model->PredictDistribution(example_1.features); EXPECT_EQ(distribution.size(), 2u); - EXPECT_EQ(distribution[example_1.target_value], 1u); - EXPECT_EQ(distribution[example_2.target_value], 1u); + EXPECT_EQ(distribution[example_1.target_value], 0.5); + EXPECT_EQ(distribution[example_2.target_value], 0.5); distribution = model->PredictDistribution(example_2.features); EXPECT_EQ(distribution.size(), 2u); - EXPECT_EQ(distribution[example_1.target_value], 1u); - EXPECT_EQ(distribution[example_2.target_value], 1u); + EXPECT_EQ(distribution[example_1.target_value], 0.5); + EXPECT_EQ(distribution[example_2.target_value], 0.5); } TEST_P(RandomTreeTest, UnknownFeatureValueHandling) { @@ -202,7 +198,7 @@ TEST_P(RandomTreeTest, NumericFeaturesSplitMultipleTimes) { std::unique_ptr<Model> model = Train(task_, training_data); for (size_t i = 0; i < 4; i++) { // Get a prediction for the |i|-th feature value. - TargetDistribution distribution = model->PredictDistribution( + TargetHistogram distribution = model->PredictDistribution( FeatureVector({FeatureValue(i * feature_mult)})); // The distribution should have one count that should be correct. If // the feature isn't split four times, then some feature value will have too diff --git a/chromium/media/learning/impl/target_distribution.h b/chromium/media/learning/impl/target_distribution.h deleted file mode 100644 index 3f4a9cd7274..00000000000 --- a/chromium/media/learning/impl/target_distribution.h +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MEDIA_LEARNING_IMPL_TARGET_DISTRIBUTION_H_ -#define MEDIA_LEARNING_IMPL_TARGET_DISTRIBUTION_H_ - -#include <ostream> -#include <string> - -#include "base/component_export.h" -#include "base/containers/flat_map.h" -#include "base/macros.h" -#include "media/learning/common/labelled_example.h" -#include "media/learning/common/value.h" - -namespace media { -namespace learning { - -// TargetDistribution of target values. -class COMPONENT_EXPORT(LEARNING_IMPL) TargetDistribution { - private: - // We use a flat_map since this will often have only one or two TargetValues, - // such as "true" or "false". - using DistributionMap = base::flat_map<TargetValue, size_t>; - - public: - TargetDistribution(); - TargetDistribution(const TargetDistribution& rhs); - TargetDistribution(TargetDistribution&& rhs); - ~TargetDistribution(); - - TargetDistribution& operator=(const TargetDistribution& rhs); - TargetDistribution& operator=(TargetDistribution&& rhs); - - bool operator==(const TargetDistribution& rhs) const; - - // Add |rhs| to our counts. - TargetDistribution& operator+=(const TargetDistribution& rhs); - - // Increment |rhs| by one. - TargetDistribution& operator+=(const TargetValue& rhs); - - // Increment the distribution by |example|'s target value and weight. - TargetDistribution& operator+=(const LabelledExample& example); - - // Return the number of counts for |value|. - size_t operator[](const TargetValue& value) const; - size_t& operator[](const TargetValue& value); - - // Return the total counts in the map. - size_t total_counts() const { - size_t total = 0.; - for (auto& entry : counts_) - total += entry.second; - return total; - } - - DistributionMap::const_iterator begin() const { return counts_.begin(); } - - DistributionMap::const_iterator end() const { return counts_.end(); } - - // Return the number of buckets in the distribution. - // TODO(liberato): Do we want this? - size_t size() const { return counts_.size(); } - - // Find the singular value with the highest counts, and copy it into - // |value_out| and (optionally) |counts_out|. Returns true if there is a - // singular maximum, else returns false with the out params undefined. - bool FindSingularMax(TargetValue* value_out, - size_t* counts_out = nullptr) const; - - // Return the average value of the entries in this distribution. Of course, - // this only makes sense if the TargetValues can be interpreted as numeric. - double Average() const; - - std::string ToString() const; - - private: - const DistributionMap& counts() const { return counts_; } - - // [value] == counts - DistributionMap counts_; - - // Allow copy and assign. -}; - -COMPONENT_EXPORT(LEARNING_IMPL) -std::ostream& operator<<(std::ostream& out, const TargetDistribution& dist); - -} // namespace learning -} // namespace media - -#endif // MEDIA_LEARNING_IMPL_TARGET_DISTRIBUTION_H_ diff --git a/chromium/media/learning/impl/target_distribution_unittest.cc b/chromium/media/learning/impl/target_distribution_unittest.cc deleted file mode 100644 index 1c7564aa5e3..00000000000 --- a/chromium/media/learning/impl/target_distribution_unittest.cc +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "media/learning/impl/target_distribution.h" - -#include "testing/gtest/include/gtest/gtest.h" - -namespace media { -namespace learning { - -class TargetDistributionTest : public testing::Test { - public: - TargetDistributionTest() : value_1(123), value_2(456), value_3(789) {} - - TargetDistribution distribution_; - - TargetValue value_1; - const size_t counts_1 = 100; - - TargetValue value_2; - const size_t counts_2 = 10; - - TargetValue value_3; -}; - -TEST_F(TargetDistributionTest, EmptyTargetDistributionHasZeroCounts) { - EXPECT_EQ(distribution_.total_counts(), 0u); -} - -TEST_F(TargetDistributionTest, AddingCountsWorks) { - distribution_[value_1] = counts_1; - EXPECT_EQ(distribution_.total_counts(), counts_1); - EXPECT_EQ(distribution_[value_1], counts_1); - distribution_[value_1] += counts_1; - EXPECT_EQ(distribution_.total_counts(), counts_1 * 2u); - EXPECT_EQ(distribution_[value_1], counts_1 * 2u); -} - -TEST_F(TargetDistributionTest, MultipleValuesAreSeparate) { - distribution_[value_1] = counts_1; - distribution_[value_2] = counts_2; - EXPECT_EQ(distribution_.total_counts(), counts_1 + counts_2); - EXPECT_EQ(distribution_[value_1], counts_1); - EXPECT_EQ(distribution_[value_2], counts_2); -} - -TEST_F(TargetDistributionTest, AddingTargetValues) { - distribution_ += value_1; - EXPECT_EQ(distribution_.total_counts(), 1u); - EXPECT_EQ(distribution_[value_1], 1u); - EXPECT_EQ(distribution_[value_2], 0u); - - distribution_ += value_1; - EXPECT_EQ(distribution_.total_counts(), 2u); - EXPECT_EQ(distribution_[value_1], 2u); - EXPECT_EQ(distribution_[value_2], 0u); - - distribution_ += value_2; - EXPECT_EQ(distribution_.total_counts(), 3u); - EXPECT_EQ(distribution_[value_1], 2u); - EXPECT_EQ(distribution_[value_2], 1u); -} - -TEST_F(TargetDistributionTest, AddingTargetDistributions) { - distribution_[value_1] = counts_1; - - TargetDistribution rhs; - rhs[value_2] = counts_2; - - distribution_ += rhs; - - EXPECT_EQ(distribution_.total_counts(), counts_1 + counts_2); - EXPECT_EQ(distribution_[value_1], counts_1); - EXPECT_EQ(distribution_[value_2], counts_2); -} - -TEST_F(TargetDistributionTest, FindSingularMaxFindsTheSingularMax) { - distribution_[value_1] = counts_1; - distribution_[value_2] = counts_2; - ASSERT_TRUE(counts_1 > counts_2); - - TargetValue max_value(0); - size_t max_counts = 0; - EXPECT_TRUE(distribution_.FindSingularMax(&max_value, &max_counts)); - EXPECT_EQ(max_value, value_1); - EXPECT_EQ(max_counts, counts_1); -} - -TEST_F(TargetDistributionTest, - FindSingularMaxFindsTheSingularMaxAlternateOrder) { - // Switch the order, to handle sorting in different directions. - distribution_[value_1] = counts_2; - distribution_[value_2] = counts_1; - ASSERT_TRUE(counts_1 > counts_2); - - TargetValue max_value(0); - size_t max_counts = 0; - EXPECT_TRUE(distribution_.FindSingularMax(&max_value, &max_counts)); - EXPECT_EQ(max_value, value_2); - EXPECT_EQ(max_counts, counts_1); -} - -TEST_F(TargetDistributionTest, FindSingularMaxReturnsFalsForNonSingularMax) { - distribution_[value_1] = counts_1; - distribution_[value_2] = counts_1; - - TargetValue max_value(0); - size_t max_counts = 0; - EXPECT_FALSE(distribution_.FindSingularMax(&max_value, &max_counts)); -} - -TEST_F(TargetDistributionTest, FindSingularMaxIgnoresNonSingularNonMax) { - distribution_[value_1] = counts_1; - // |value_2| and |value_3| are tied, but not the max. - distribution_[value_2] = counts_2; - distribution_[value_3] = counts_2; - ASSERT_TRUE(counts_1 > counts_2); - - TargetValue max_value(0); - size_t max_counts = 0; - EXPECT_TRUE(distribution_.FindSingularMax(&max_value, &max_counts)); - EXPECT_EQ(max_value, value_1); - EXPECT_EQ(max_counts, counts_1); -} - -TEST_F(TargetDistributionTest, FindSingularMaxDoesntRequireCounts) { - distribution_[value_1] = counts_1; - - TargetValue max_value(0); - EXPECT_TRUE(distribution_.FindSingularMax(&max_value)); - EXPECT_EQ(max_value, value_1); -} - -TEST_F(TargetDistributionTest, EqualDistributionsCompareAsEqual) { - distribution_[value_1] = counts_1; - TargetDistribution distribution_2; - distribution_2[value_1] = counts_1; - - EXPECT_TRUE(distribution_ == distribution_2); -} - -TEST_F(TargetDistributionTest, UnequalDistributionsCompareAsNotEqual) { - distribution_[value_1] = counts_1; - TargetDistribution distribution_2; - distribution_2[value_2] = counts_2; - - EXPECT_FALSE(distribution_ == distribution_2); -} - -TEST_F(TargetDistributionTest, WeightedLabelledExamplesCountCorrectly) { - LabelledExample example = {{}, value_1}; - example.weight = counts_1; - distribution_ += example; - - TargetDistribution distribution_2; - for (size_t i = 0; i < counts_1; i++) - distribution_2 += value_1; - - EXPECT_EQ(distribution_, distribution_2); -} - -} // namespace learning -} // namespace media diff --git a/chromium/media/learning/impl/target_distribution.cc b/chromium/media/learning/impl/target_histogram.cc index 2fe271733bc..ad1a1f2112a 100644 --- a/chromium/media/learning/impl/target_distribution.cc +++ b/chromium/media/learning/impl/target_histogram.cc @@ -2,51 +2,48 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "media/learning/impl/target_distribution.h" +#include "media/learning/impl/target_histogram.h" #include <sstream> namespace media { namespace learning { -TargetDistribution::TargetDistribution() = default; +TargetHistogram::TargetHistogram() = default; -TargetDistribution::TargetDistribution(const TargetDistribution& rhs) = default; +TargetHistogram::TargetHistogram(const TargetHistogram& rhs) = default; -TargetDistribution::TargetDistribution(TargetDistribution&& rhs) = default; +TargetHistogram::TargetHistogram(TargetHistogram&& rhs) = default; -TargetDistribution::~TargetDistribution() = default; +TargetHistogram::~TargetHistogram() = default; -TargetDistribution& TargetDistribution::operator=( - const TargetDistribution& rhs) = default; - -TargetDistribution& TargetDistribution::operator=(TargetDistribution&& rhs) = +TargetHistogram& TargetHistogram::operator=(const TargetHistogram& rhs) = default; -bool TargetDistribution::operator==(const TargetDistribution& rhs) const { +TargetHistogram& TargetHistogram::operator=(TargetHistogram&& rhs) = default; + +bool TargetHistogram::operator==(const TargetHistogram& rhs) const { return rhs.total_counts() == total_counts() && rhs.counts_ == counts_; } -TargetDistribution& TargetDistribution::operator+=( - const TargetDistribution& rhs) { +TargetHistogram& TargetHistogram::operator+=(const TargetHistogram& rhs) { for (auto& rhs_pair : rhs.counts()) counts_[rhs_pair.first] += rhs_pair.second; return *this; } -TargetDistribution& TargetDistribution::operator+=(const TargetValue& rhs) { +TargetHistogram& TargetHistogram::operator+=(const TargetValue& rhs) { counts_[rhs]++; return *this; } -TargetDistribution& TargetDistribution::operator+=( - const LabelledExample& example) { +TargetHistogram& TargetHistogram::operator+=(const LabelledExample& example) { counts_[example.target_value] += example.weight; return *this; } -size_t TargetDistribution::operator[](const TargetValue& value) const { +double TargetHistogram::operator[](const TargetValue& value) const { auto iter = counts_.find(value); if (iter == counts_.end()) return 0; @@ -54,16 +51,16 @@ size_t TargetDistribution::operator[](const TargetValue& value) const { return iter->second; } -size_t& TargetDistribution::operator[](const TargetValue& value) { +double& TargetHistogram::operator[](const TargetValue& value) { return counts_[value]; } -bool TargetDistribution::FindSingularMax(TargetValue* value_out, - size_t* counts_out) const { +bool TargetHistogram::FindSingularMax(TargetValue* value_out, + double* counts_out) const { if (!counts_.size()) return false; - size_t unused_counts; + double unused_counts; if (!counts_out) counts_out = &unused_counts; @@ -85,9 +82,9 @@ bool TargetDistribution::FindSingularMax(TargetValue* value_out, return singular_max; } -double TargetDistribution::Average() const { +double TargetHistogram::Average() const { double total_value = 0.; - size_t total_counts = 0; + double total_counts = 0; for (auto& iter : counts_) { total_value += iter.first.value() * iter.second; total_counts += iter.second; @@ -99,7 +96,13 @@ double TargetDistribution::Average() const { return total_value / total_counts; } -std::string TargetDistribution::ToString() const { +void TargetHistogram::Normalize() { + double total = total_counts(); + for (auto& iter : counts_) + iter.second /= total; +} + +std::string TargetHistogram::ToString() const { std::ostringstream ss; ss << "["; for (auto& entry : counts_) @@ -110,7 +113,7 @@ std::string TargetDistribution::ToString() const { } std::ostream& operator<<(std::ostream& out, - const media::learning::TargetDistribution& dist) { + const media::learning::TargetHistogram& dist) { return out << dist.ToString(); } diff --git a/chromium/media/learning/impl/target_histogram.h b/chromium/media/learning/impl/target_histogram.h new file mode 100644 index 00000000000..cb8de2b625e --- /dev/null +++ b/chromium/media/learning/impl/target_histogram.h @@ -0,0 +1,98 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MEDIA_LEARNING_IMPL_TARGET_HISTOGRAM_H_ +#define MEDIA_LEARNING_IMPL_TARGET_HISTOGRAM_H_ + +#include <ostream> +#include <string> + +#include "base/component_export.h" +#include "base/containers/flat_map.h" +#include "base/macros.h" +#include "media/learning/common/labelled_example.h" +#include "media/learning/common/value.h" + +namespace media { +namespace learning { + +// Histogram of target values that allows fractional counts. +class COMPONENT_EXPORT(LEARNING_IMPL) TargetHistogram { + private: + // We use a flat_map since this will often have only one or two TargetValues, + // such as "true" or "false". + using CountMap = base::flat_map<TargetValue, double>; + + public: + TargetHistogram(); + TargetHistogram(const TargetHistogram& rhs); + TargetHistogram(TargetHistogram&& rhs); + ~TargetHistogram(); + + TargetHistogram& operator=(const TargetHistogram& rhs); + TargetHistogram& operator=(TargetHistogram&& rhs); + + bool operator==(const TargetHistogram& rhs) const; + + // Add |rhs| to our counts. + TargetHistogram& operator+=(const TargetHistogram& rhs); + + // Increment |rhs| by one. + TargetHistogram& operator+=(const TargetValue& rhs); + + // Increment the histogram by |example|'s target value and weight. + TargetHistogram& operator+=(const LabelledExample& example); + + // Return the number of counts for |value|. + double operator[](const TargetValue& value) const; + double& operator[](const TargetValue& value); + + // Return the total counts in the map. + double total_counts() const { + double total = 0.; + for (auto& entry : counts_) + total += entry.second; + return total; + } + + CountMap::const_iterator begin() const { return counts_.begin(); } + + CountMap::const_iterator end() const { return counts_.end(); } + + // Return the number of buckets in the histogram. + // TODO(liberato): Do we want this? + size_t size() const { return counts_.size(); } + + // Find the singular value with the highest counts, and copy it into + // |value_out| and (optionally) |counts_out|. Returns true if there is a + // singular maximum, else returns false with the out params undefined. + bool FindSingularMax(TargetValue* value_out, + double* counts_out = nullptr) const; + + // Return the average value of the entries in this histogram. Of course, + // this only makes sense if the TargetValues can be interpreted as numeric. + double Average() const; + + // Normalize the histogram so that it has one total count, unless it's + // empty. It will continue to have zero in that case. + void Normalize(); + + std::string ToString() const; + + private: + const CountMap& counts() const { return counts_; } + + // [value] == counts + CountMap counts_; + + // Allow copy and assign. +}; + +COMPONENT_EXPORT(LEARNING_IMPL) +std::ostream& operator<<(std::ostream& out, const TargetHistogram& dist); + +} // namespace learning +} // namespace media + +#endif // MEDIA_LEARNING_IMPL_TARGET_HISTOGRAM_H_ diff --git a/chromium/media/learning/impl/target_histogram_unittest.cc b/chromium/media/learning/impl/target_histogram_unittest.cc new file mode 100644 index 00000000000..5ba36e0ed78 --- /dev/null +++ b/chromium/media/learning/impl/target_histogram_unittest.cc @@ -0,0 +1,179 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "media/learning/impl/target_histogram.h" + +#include "testing/gtest/include/gtest/gtest.h" + +namespace media { +namespace learning { + +class TargetHistogramTest : public testing::Test { + public: + TargetHistogramTest() : value_1(123), value_2(456), value_3(789) {} + + TargetHistogram histogram_; + + TargetValue value_1; + const size_t counts_1 = 100; + + TargetValue value_2; + const size_t counts_2 = 10; + + TargetValue value_3; +}; + +TEST_F(TargetHistogramTest, EmptyTargetHistogramHasZeroCounts) { + EXPECT_EQ(histogram_.total_counts(), 0u); +} + +TEST_F(TargetHistogramTest, AddingCountsWorks) { + histogram_[value_1] = counts_1; + EXPECT_EQ(histogram_.total_counts(), counts_1); + EXPECT_EQ(histogram_[value_1], counts_1); + histogram_[value_1] += counts_1; + EXPECT_EQ(histogram_.total_counts(), counts_1 * 2u); + EXPECT_EQ(histogram_[value_1], counts_1 * 2u); +} + +TEST_F(TargetHistogramTest, MultipleValuesAreSeparate) { + histogram_[value_1] = counts_1; + histogram_[value_2] = counts_2; + EXPECT_EQ(histogram_.total_counts(), counts_1 + counts_2); + EXPECT_EQ(histogram_[value_1], counts_1); + EXPECT_EQ(histogram_[value_2], counts_2); +} + +TEST_F(TargetHistogramTest, AddingTargetValues) { + histogram_ += value_1; + EXPECT_EQ(histogram_.total_counts(), 1u); + EXPECT_EQ(histogram_[value_1], 1u); + EXPECT_EQ(histogram_[value_2], 0u); + + histogram_ += value_1; + EXPECT_EQ(histogram_.total_counts(), 2u); + EXPECT_EQ(histogram_[value_1], 2u); + EXPECT_EQ(histogram_[value_2], 0u); + + histogram_ += value_2; + EXPECT_EQ(histogram_.total_counts(), 3u); + EXPECT_EQ(histogram_[value_1], 2u); + EXPECT_EQ(histogram_[value_2], 1u); +} + +TEST_F(TargetHistogramTest, AddingTargetHistograms) { + histogram_[value_1] = counts_1; + + TargetHistogram rhs; + rhs[value_2] = counts_2; + + histogram_ += rhs; + + EXPECT_EQ(histogram_.total_counts(), counts_1 + counts_2); + EXPECT_EQ(histogram_[value_1], counts_1); + EXPECT_EQ(histogram_[value_2], counts_2); +} + +TEST_F(TargetHistogramTest, FindSingularMaxFindsTheSingularMax) { + histogram_[value_1] = counts_1; + histogram_[value_2] = counts_2; + ASSERT_TRUE(counts_1 > counts_2); + + TargetValue max_value(0); + double max_counts = 0; + EXPECT_TRUE(histogram_.FindSingularMax(&max_value, &max_counts)); + EXPECT_EQ(max_value, value_1); + EXPECT_EQ(max_counts, counts_1); +} + +TEST_F(TargetHistogramTest, FindSingularMaxFindsTheSingularMaxAlternateOrder) { + // Switch the order, to handle sorting in different directions. + histogram_[value_1] = counts_2; + histogram_[value_2] = counts_1; + ASSERT_TRUE(counts_1 > counts_2); + + TargetValue max_value(0); + double max_counts = 0; + EXPECT_TRUE(histogram_.FindSingularMax(&max_value, &max_counts)); + EXPECT_EQ(max_value, value_2); + EXPECT_EQ(max_counts, counts_1); +} + +TEST_F(TargetHistogramTest, FindSingularMaxReturnsFalsForNonSingularMax) { + histogram_[value_1] = counts_1; + histogram_[value_2] = counts_1; + + TargetValue max_value(0); + double max_counts = 0; + EXPECT_FALSE(histogram_.FindSingularMax(&max_value, &max_counts)); +} + +TEST_F(TargetHistogramTest, FindSingularMaxIgnoresNonSingularNonMax) { + histogram_[value_1] = counts_1; + // |value_2| and |value_3| are tied, but not the max. + histogram_[value_2] = counts_2; + histogram_[value_3] = counts_2; + ASSERT_TRUE(counts_1 > counts_2); + + TargetValue max_value(0); + double max_counts = 0; + EXPECT_TRUE(histogram_.FindSingularMax(&max_value, &max_counts)); + EXPECT_EQ(max_value, value_1); + EXPECT_EQ(max_counts, counts_1); +} + +TEST_F(TargetHistogramTest, FindSingularMaxDoesntRequireCounts) { + histogram_[value_1] = counts_1; + + TargetValue max_value(0); + EXPECT_TRUE(histogram_.FindSingularMax(&max_value)); + EXPECT_EQ(max_value, value_1); +} + +TEST_F(TargetHistogramTest, EqualDistributionsCompareAsEqual) { + histogram_[value_1] = counts_1; + TargetHistogram histogram_2; + histogram_2[value_1] = counts_1; + + EXPECT_TRUE(histogram_ == histogram_2); +} + +TEST_F(TargetHistogramTest, UnequalDistributionsCompareAsNotEqual) { + histogram_[value_1] = counts_1; + TargetHistogram histogram_2; + histogram_2[value_2] = counts_2; + + EXPECT_FALSE(histogram_ == histogram_2); +} + +TEST_F(TargetHistogramTest, WeightedLabelledExamplesCountCorrectly) { + LabelledExample example = {{}, value_1}; + example.weight = counts_1; + histogram_ += example; + + TargetHistogram histogram_2; + for (size_t i = 0; i < counts_1; i++) + histogram_2 += value_1; + + EXPECT_EQ(histogram_, histogram_2); +} + +TEST_F(TargetHistogramTest, Normalize) { + histogram_[value_1] = counts_1; + histogram_[value_2] = counts_2; + histogram_.Normalize(); + EXPECT_EQ(histogram_[value_1], + counts_1 / static_cast<double>(counts_1 + counts_2)); + EXPECT_EQ(histogram_[value_2], + counts_2 / static_cast<double>(counts_1 + counts_2)); +} + +TEST_F(TargetHistogramTest, NormalizeEmptyDistribution) { + // Normalizing an empty distribution should result in an empty distribution. + histogram_.Normalize(); + EXPECT_EQ(histogram_.total_counts(), 0); +} + +} // namespace learning +} // namespace media diff --git a/chromium/media/learning/impl/voting_ensemble.cc b/chromium/media/learning/impl/voting_ensemble.cc index 667e739975d..ef07b1cb731 100644 --- a/chromium/media/learning/impl/voting_ensemble.cc +++ b/chromium/media/learning/impl/voting_ensemble.cc @@ -12,9 +12,9 @@ VotingEnsemble::VotingEnsemble(std::vector<std::unique_ptr<Model>> models) VotingEnsemble::~VotingEnsemble() = default; -TargetDistribution VotingEnsemble::PredictDistribution( +TargetHistogram VotingEnsemble::PredictDistribution( const FeatureVector& instance) { - TargetDistribution distribution; + TargetHistogram distribution; for (auto iter = models_.begin(); iter != models_.end(); iter++) distribution += (*iter)->PredictDistribution(instance); diff --git a/chromium/media/learning/impl/voting_ensemble.h b/chromium/media/learning/impl/voting_ensemble.h index 2b4bf11a3fa..7e0cba7b59a 100644 --- a/chromium/media/learning/impl/voting_ensemble.h +++ b/chromium/media/learning/impl/voting_ensemble.h @@ -23,8 +23,7 @@ class COMPONENT_EXPORT(LEARNING_IMPL) VotingEnsemble : public Model { ~VotingEnsemble() override; // Model - TargetDistribution PredictDistribution( - const FeatureVector& instance) override; + TargetHistogram PredictDistribution(const FeatureVector& instance) override; private: std::vector<std::unique_ptr<Model>> models_; diff --git a/chromium/media/learning/mojo/BUILD.gn b/chromium/media/learning/mojo/BUILD.gn index bfe45d11359..d8c92955fdc 100644 --- a/chromium/media/learning/mojo/BUILD.gn +++ b/chromium/media/learning/mojo/BUILD.gn @@ -8,8 +8,8 @@ import("//testing/test.gni") component("impl") { output_name = "media_learning_mojo_impl" sources = [ - "mojo_learning_session_impl.cc", - "mojo_learning_session_impl.h", + "mojo_learning_task_controller_service.cc", + "mojo_learning_task_controller_service.h", ] defines = [ "IS_MEDIA_LEARNING_MOJO_IMPL" ] @@ -35,7 +35,7 @@ source_set("unit_tests") { testonly = true sources = [ - "mojo_learning_session_impl_unittest.cc", + "mojo_learning_task_controller_service_unittest.cc", ] deps = [ diff --git a/chromium/media/learning/mojo/mojo_learning_session_impl.cc b/chromium/media/learning/mojo/mojo_learning_session_impl.cc deleted file mode 100644 index 3cfb9789ff6..00000000000 --- a/chromium/media/learning/mojo/mojo_learning_session_impl.cc +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "media/learning/mojo/mojo_learning_session_impl.h" - -#include "media/learning/common/learning_session.h" - -namespace media { -namespace learning { - -MojoLearningSessionImpl::MojoLearningSessionImpl( - std::unique_ptr<::media::learning::LearningSession> impl) - : impl_(std::move(impl)) {} - -MojoLearningSessionImpl::~MojoLearningSessionImpl() = default; - -void MojoLearningSessionImpl::Bind(mojom::LearningSessionRequest request) { - self_bindings_.AddBinding(this, std::move(request)); -} - -void MojoLearningSessionImpl::AddExample(mojom::LearningTaskType task_type, - const LabelledExample& example) { - // TODO(liberato): Convert |task_type| into a task name. - std::string task_name("no_task"); - - impl_->AddExample(task_name, example); -} - -} // namespace learning -} // namespace media diff --git a/chromium/media/learning/mojo/mojo_learning_session_impl.h b/chromium/media/learning/mojo/mojo_learning_session_impl.h deleted file mode 100644 index 83d0eb72018..00000000000 --- a/chromium/media/learning/mojo/mojo_learning_session_impl.h +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MEDIA_LEARNING_MOJO_MOJO_LEARNING_SESSION_IMPL_H_ -#define MEDIA_LEARNING_MOJO_MOJO_LEARNING_SESSION_IMPL_H_ - -#include <memory> - -#include "base/component_export.h" -#include "base/macros.h" -#include "media/learning/mojo/public/mojom/learning_session.mojom.h" -#include "mojo/public/cpp/bindings/binding_set.h" - -namespace media { -namespace learning { - -class LearningSession; -class MojoLearningSessionImplTest; - -// Mojo service that talks to a local LearningSession. -class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningSessionImpl - : public mojom::LearningSession { - public: - ~MojoLearningSessionImpl() override; - - // Bind |request| to this instance. - void Bind(mojom::LearningSessionRequest request); - - // mojom::LearningSession - void AddExample(mojom::LearningTaskType task_type, - const LabelledExample& example) override; - - protected: - explicit MojoLearningSessionImpl( - std::unique_ptr<::media::learning::LearningSession> impl); - - // Underlying session to which we proxy calls. - std::unique_ptr<::media::learning::LearningSession> impl_; - - // We own our own bindings. If we're ever not a singleton, then this should - // move to our owner. - mojo::BindingSet<mojom::LearningSession> self_bindings_; - - friend class MojoLearningSessionImplTest; - - DISALLOW_COPY_AND_ASSIGN(MojoLearningSessionImpl); -}; - -} // namespace learning -} // namespace media - -#endif // MEDIA_LEARNING_MOJO_MOJO_LEARNING_SESSION_IMPL_H_ diff --git a/chromium/media/learning/mojo/mojo_learning_session_impl_unittest.cc b/chromium/media/learning/mojo/mojo_learning_session_impl_unittest.cc deleted file mode 100644 index 5399d20ca16..00000000000 --- a/chromium/media/learning/mojo/mojo_learning_session_impl_unittest.cc +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include <memory> - -#include "base/macros.h" -#include "base/memory/ptr_util.h" -#include "media/learning/common/learning_session.h" -#include "media/learning/mojo/mojo_learning_session_impl.h" -#include "media/learning/mojo/public/mojom/learning_types.mojom.h" -#include "testing/gtest/include/gtest/gtest.h" - -namespace media { -namespace learning { - -class MojoLearningSessionImplTest : public ::testing::Test { - public: - class FakeLearningSession : public ::media::learning::LearningSession { - public: - void AddExample(const std::string& task_name, - const LabelledExample& example) override { - most_recent_task_name_ = task_name; - most_recent_example_ = example; - } - - std::string most_recent_task_name_; - LabelledExample most_recent_example_; - }; - - public: - MojoLearningSessionImplTest() = default; - ~MojoLearningSessionImplTest() override = default; - - void SetUp() override { - // Create a mojo learner that forwards to a fake learner. - std::unique_ptr<FakeLearningSession> provider = - std::make_unique<FakeLearningSession>(); - fake_learning_session_ = provider.get(); - learning_session_impl_ = base::WrapUnique<MojoLearningSessionImpl>( - new MojoLearningSessionImpl(std::move(provider))); - } - - FakeLearningSession* fake_learning_session_ = nullptr; - - const mojom::LearningTaskType task_type_ = - mojom::LearningTaskType::kPlaceHolderTask; - - // The learner provider under test. - std::unique_ptr<MojoLearningSessionImpl> learning_session_impl_; -}; - -TEST_F(MojoLearningSessionImplTest, FeaturesAndTargetValueAreCopied) { - mojom::LabelledExamplePtr example_ptr = mojom::LabelledExample::New(); - const LabelledExample example = {{Value(123), Value(456), Value(890)}, - TargetValue(1234)}; - - learning_session_impl_->AddExample(task_type_, example); - EXPECT_EQ(example, fake_learning_session_->most_recent_example_); -} - -} // namespace learning -} // namespace media diff --git a/chromium/media/learning/mojo/mojo_learning_task_controller_service.cc b/chromium/media/learning/mojo/mojo_learning_task_controller_service.cc new file mode 100644 index 00000000000..acd3f319447 --- /dev/null +++ b/chromium/media/learning/mojo/mojo_learning_task_controller_service.cc @@ -0,0 +1,65 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "media/learning/mojo/mojo_learning_task_controller_service.h" + +#include <utility> + +#include "media/learning/common/learning_task_controller.h" + +namespace media { +namespace learning { + +// Somewhat arbitrary upper limit on the number of in-flight observations that +// we'll allow a client to have. +static const size_t kMaxInFlightObservations = 16; + +MojoLearningTaskControllerService::MojoLearningTaskControllerService( + const LearningTask& task, + std::unique_ptr<::media::learning::LearningTaskController> impl) + : task_(task), impl_(std::move(impl)) {} + +MojoLearningTaskControllerService::~MojoLearningTaskControllerService() = + default; + +void MojoLearningTaskControllerService::BeginObservation( + const base::UnguessableToken& id, + const FeatureVector& features) { + // Drop the observation if it doesn't match the feature description size. + if (features.size() != task_.feature_descriptions.size()) + return; + + // Don't allow the client to send too many in-flight observations. + if (in_flight_observations_.size() >= kMaxInFlightObservations) + return; + in_flight_observations_.insert(id); + + // Since we own |impl_|, we don't need to keep track of in-flight + // observations. We'll release |impl_| on destruction, which cancels them. + impl_->BeginObservation(id, features); +} + +void MojoLearningTaskControllerService::CompleteObservation( + const base::UnguessableToken& id, + const ObservationCompletion& completion) { + auto iter = in_flight_observations_.find(id); + if (iter == in_flight_observations_.end()) + return; + in_flight_observations_.erase(iter); + + impl_->CompleteObservation(id, completion); +} + +void MojoLearningTaskControllerService::CancelObservation( + const base::UnguessableToken& id) { + auto iter = in_flight_observations_.find(id); + if (iter == in_flight_observations_.end()) + return; + in_flight_observations_.erase(iter); + + impl_->CancelObservation(id); +} + +} // namespace learning +} // namespace media diff --git a/chromium/media/learning/mojo/mojo_learning_task_controller_service.h b/chromium/media/learning/mojo/mojo_learning_task_controller_service.h new file mode 100644 index 00000000000..cdf8721adca --- /dev/null +++ b/chromium/media/learning/mojo/mojo_learning_task_controller_service.h @@ -0,0 +1,51 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MEDIA_LEARNING_MOJO_MOJO_LEARNING_TASK_CONTROLLER_SERVICE_H_ +#define MEDIA_LEARNING_MOJO_MOJO_LEARNING_TASK_CONTROLLER_SERVICE_H_ + +#include <memory> +#include <set> + +#include "base/component_export.h" +#include "base/macros.h" +#include "media/learning/mojo/public/mojom/learning_task_controller.mojom.h" + +namespace media { +namespace learning { + +class LearningTaskController; + +// Mojo service that talks to a local LearningTaskController. +class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskControllerService + : public mojom::LearningTaskController { + public: + // |impl| is the underlying controller that we'll send requests to. + explicit MojoLearningTaskControllerService( + const LearningTask& task, + std::unique_ptr<::media::learning::LearningTaskController> impl); + ~MojoLearningTaskControllerService() override; + + // mojom::LearningTaskController + void BeginObservation(const base::UnguessableToken& id, + const FeatureVector& features) override; + void CompleteObservation(const base::UnguessableToken& id, + const ObservationCompletion& completion) override; + void CancelObservation(const base::UnguessableToken& id) override; + + protected: + const LearningTask task_; + + // Underlying controller to which we proxy calls. + std::unique_ptr<::media::learning::LearningTaskController> impl_; + + std::set<base::UnguessableToken> in_flight_observations_; + + DISALLOW_COPY_AND_ASSIGN(MojoLearningTaskControllerService); +}; + +} // namespace learning +} // namespace media + +#endif // MEDIA_LEARNING_MOJO_MOJO_LEARNING_TASK_CONTROLLER_SERVICE_H_ diff --git a/chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc b/chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc new file mode 100644 index 00000000000..99e8a391269 --- /dev/null +++ b/chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc @@ -0,0 +1,143 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <memory> +#include <utility> + +#include "base/bind.h" +#include "base/macros.h" +#include "base/memory/ptr_util.h" +#include "base/test/scoped_task_environment.h" +#include "base/threading/thread.h" +#include "media/learning/mojo/mojo_learning_task_controller_service.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace media { +namespace learning { + +class MojoLearningTaskControllerServiceTest : public ::testing::Test { + public: + class FakeLearningTaskController : public LearningTaskController { + public: + void BeginObservation(base::UnguessableToken id, + const FeatureVector& features) override { + begin_args_.id_ = id; + begin_args_.features_ = features; + } + + void CompleteObservation(base::UnguessableToken id, + const ObservationCompletion& completion) override { + complete_args_.id_ = id; + complete_args_.completion_ = completion; + } + + void CancelObservation(base::UnguessableToken id) override { + cancel_args_.id_ = id; + } + + struct { + base::UnguessableToken id_; + FeatureVector features_; + } begin_args_; + + struct { + base::UnguessableToken id_; + ObservationCompletion completion_; + } complete_args_; + + struct { + base::UnguessableToken id_; + } cancel_args_; + }; + + public: + MojoLearningTaskControllerServiceTest() = default; + ~MojoLearningTaskControllerServiceTest() override = default; + + void SetUp() override { + std::unique_ptr<FakeLearningTaskController> controller = + std::make_unique<FakeLearningTaskController>(); + controller_raw_ = controller.get(); + + // Add two features. + task_.feature_descriptions.push_back({}); + task_.feature_descriptions.push_back({}); + + // Tell |learning_controller_| to forward to the fake learner impl. + service_ = std::make_unique<MojoLearningTaskControllerService>( + task_, std::move(controller)); + } + + LearningTask task_; + + // Mojo stuff. + base::test::ScopedTaskEnvironment scoped_task_environment_; + + FakeLearningTaskController* controller_raw_ = nullptr; + + // The learner under test. + std::unique_ptr<MojoLearningTaskControllerService> service_; +}; + +TEST_F(MojoLearningTaskControllerServiceTest, BeginComplete) { + base::UnguessableToken id = base::UnguessableToken::Create(); + FeatureVector features = {FeatureValue(123), FeatureValue(456)}; + service_->BeginObservation(id, features); + EXPECT_EQ(id, controller_raw_->begin_args_.id_); + EXPECT_EQ(features, controller_raw_->begin_args_.features_); + + ObservationCompletion completion(TargetValue(1234)); + service_->CompleteObservation(id, completion); + + EXPECT_EQ(id, controller_raw_->complete_args_.id_); + EXPECT_EQ(completion.target_value, + controller_raw_->complete_args_.completion_.target_value); +} + +TEST_F(MojoLearningTaskControllerServiceTest, BeginCancel) { + base::UnguessableToken id = base::UnguessableToken::Create(); + FeatureVector features = {FeatureValue(123), FeatureValue(456)}; + service_->BeginObservation(id, features); + EXPECT_EQ(id, controller_raw_->begin_args_.id_); + EXPECT_EQ(features, controller_raw_->begin_args_.features_); + + service_->CancelObservation(id); + + EXPECT_EQ(id, controller_raw_->cancel_args_.id_); +} + +TEST_F(MojoLearningTaskControllerServiceTest, TooFewFeaturesIsIgnored) { + // A FeatureVector with too few elements should be ignored. + base::UnguessableToken id = base::UnguessableToken::Create(); + FeatureVector short_features = {FeatureValue(123)}; + service_->BeginObservation(id, short_features); + EXPECT_NE(id, controller_raw_->begin_args_.id_); + EXPECT_EQ(controller_raw_->begin_args_.features_.size(), 0u); +} + +TEST_F(MojoLearningTaskControllerServiceTest, TooManyFeaturesIsIgnored) { + // A FeatureVector with too many elements should be ignored. + base::UnguessableToken id = base::UnguessableToken::Create(); + FeatureVector long_features = {FeatureValue(123), FeatureValue(456), + FeatureValue(789)}; + service_->BeginObservation(id, long_features); + EXPECT_NE(id, controller_raw_->begin_args_.id_); + EXPECT_EQ(controller_raw_->begin_args_.features_.size(), 0u); +} + +TEST_F(MojoLearningTaskControllerServiceTest, CompleteWithoutBeginFails) { + base::UnguessableToken id = base::UnguessableToken::Create(); + ObservationCompletion completion(TargetValue(1234)); + service_->CompleteObservation(id, completion); + EXPECT_NE(id, controller_raw_->complete_args_.id_); +} + +TEST_F(MojoLearningTaskControllerServiceTest, CancelWithoutBeginFails) { + base::UnguessableToken id = base::UnguessableToken::Create(); + service_->CancelObservation(id); + EXPECT_NE(id, controller_raw_->cancel_args_.id_); +} + +} // namespace learning +} // namespace media diff --git a/chromium/media/learning/mojo/public/cpp/BUILD.gn b/chromium/media/learning/mojo/public/cpp/BUILD.gn index 2157282389e..ce6b135639f 100644 --- a/chromium/media/learning/mojo/public/cpp/BUILD.gn +++ b/chromium/media/learning/mojo/public/cpp/BUILD.gn @@ -9,8 +9,8 @@ source_set("cpp") { ] sources = [ - "mojo_learning_session.cc", - "mojo_learning_session.h", + "mojo_learning_task_controller.cc", + "mojo_learning_task_controller.h", ] defines = [ "IS_MEDIA_LEARNING_MOJO_IMPL" ] @@ -26,7 +26,7 @@ source_set("unit_tests") { testonly = true sources = [ - "mojo_learning_session_unittest.cc", + "mojo_learning_task_controller_unittest.cc", ] deps = [ diff --git a/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.cc b/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.cc index aa308d9de9f..12841f9df32 100644 --- a/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.cc +++ b/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.cc @@ -39,4 +39,14 @@ bool StructTraits<media::learning::mojom::TargetValueDataView, return true; } +// static +bool StructTraits<media::learning::mojom::ObservationCompletionDataView, + media::learning::ObservationCompletion>:: + Read(media::learning::mojom::ObservationCompletionDataView data, + media::learning::ObservationCompletion* out_observation_completion) { + if (!data.ReadTargetValue(&out_observation_completion->target_value)) + return false; + out_observation_completion->weight = data.weight(); + return true; +} } // namespace mojo diff --git a/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.h b/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.h index 932a5cb7d4a..52f8ed5a86c 100644 --- a/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.h +++ b/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.h @@ -7,6 +7,7 @@ #include <vector> +#include "media/learning/common/learning_task_controller.h" #include "media/learning/common/value.h" #include "media/learning/mojo/public/mojom/learning_types.mojom.h" #include "mojo/public/cpp/bindings/struct_traits.h" @@ -52,6 +53,23 @@ class StructTraits<media::learning::mojom::TargetValueDataView, media::learning::TargetValue* out_target_value); }; +template <> +class StructTraits<media::learning::mojom::ObservationCompletionDataView, + media::learning::ObservationCompletion> { + public: + static media::learning::TargetValue target_value( + const media::learning::ObservationCompletion& e) { + return e.target_value; + } + static media::learning::WeightType weight( + const media::learning::ObservationCompletion& e) { + return e.weight; + } + static bool Read( + media::learning::mojom::ObservationCompletionDataView data, + media::learning::ObservationCompletion* out_observation_completion); +}; + } // namespace mojo #endif // MEDIA_LEARNING_MOJO_PUBLIC_CPP_LEARNING_MOJOM_TRAITS_H_ diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_session.cc b/chromium/media/learning/mojo/public/cpp/mojo_learning_session.cc deleted file mode 100644 index c67f642842a..00000000000 --- a/chromium/media/learning/mojo/public/cpp/mojo_learning_session.cc +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "media/learning/mojo/public/cpp/mojo_learning_session.h" - -#include "mojo/public/cpp/bindings/binding.h" - -namespace media { -namespace learning { - -MojoLearningSession::MojoLearningSession(mojom::LearningSessionPtr session_ptr) - : session_ptr_(std::move(session_ptr)) {} - -MojoLearningSession::~MojoLearningSession() = default; - -void MojoLearningSession::AddExample(const std::string& task_name, - const LabelledExample& example) { - // TODO(liberato): Convert from |task_name| to a task type. - session_ptr_->AddExample(mojom::LearningTaskType::kPlaceHolderTask, example); -} - -} // namespace learning -} // namespace media diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_session.h b/chromium/media/learning/mojo/public/cpp/mojo_learning_session.h deleted file mode 100644 index 0e8af2b6aca..00000000000 --- a/chromium/media/learning/mojo/public/cpp/mojo_learning_session.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MEDIA_LEARNING_MOJO_PUBLIC_CPP_MOJO_LEARNING_SESSION_H_ -#define MEDIA_LEARNING_MOJO_PUBLIC_CPP_MOJO_LEARNING_SESSION_H_ - -#include "base/component_export.h" -#include "base/macros.h" -#include "media/learning/common/learning_session.h" -#include "media/learning/mojo/public/mojom/learning_session.mojom.h" - -namespace media { -namespace learning { - -// LearningSession implementation to forward to a remote impl via mojo. -class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningSession - : public LearningSession { - public: - explicit MojoLearningSession(mojom::LearningSessionPtr session_ptr); - ~MojoLearningSession() override; - - // LearningSession - void AddExample(const std::string& task_name, - const LabelledExample& example) override; - - private: - mojom::LearningSessionPtr session_ptr_; - - DISALLOW_COPY_AND_ASSIGN(MojoLearningSession); -}; - -} // namespace learning -} // namespace media - -#endif // MEDIA_LEARNING_MOJO_PUBLIC_CPP_MOJO_LEARNING_SESSION_H_ diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_session_unittest.cc b/chromium/media/learning/mojo/public/cpp/mojo_learning_session_unittest.cc deleted file mode 100644 index 37cecb4db21..00000000000 --- a/chromium/media/learning/mojo/public/cpp/mojo_learning_session_unittest.cc +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include <memory> - -#include "base/bind.h" -#include "base/macros.h" -#include "base/memory/ptr_util.h" -#include "base/test/scoped_task_environment.h" -#include "base/threading/thread.h" -#include "media/learning/mojo/public/cpp/mojo_learning_session.h" -#include "mojo/public/cpp/bindings/binding.h" -#include "testing/gtest/include/gtest/gtest.h" - -namespace media { -namespace learning { - -class MojoLearningSessionTest : public ::testing::Test { - public: - // Impl of a mojom::LearningSession that remembers call arguments. - class FakeMojoLearningSessionImpl : public mojom::LearningSession { - public: - void AddExample(mojom::LearningTaskType task_type, - const LabelledExample& example) override { - task_type_ = std::move(task_type); - example_ = example; - } - - mojom::LearningTaskType task_type_; - LabelledExample example_; - }; - - public: - MojoLearningSessionTest() - : learning_session_binding_(&fake_learning_session_impl_) {} - ~MojoLearningSessionTest() override = default; - - void SetUp() override { - // Create a fake learner provider mojo impl. - mojom::LearningSessionPtr learning_session_ptr; - learning_session_binding_.Bind(mojo::MakeRequest(&learning_session_ptr)); - - // Tell |learning_session_| to forward to the fake learner impl. - learning_session_ = - std::make_unique<MojoLearningSession>(std::move(learning_session_ptr)); - } - - // Mojo stuff. - base::test::ScopedTaskEnvironment scoped_task_environment_; - - FakeMojoLearningSessionImpl fake_learning_session_impl_; - mojo::Binding<mojom::LearningSession> learning_session_binding_; - - // The learner under test. - std::unique_ptr<MojoLearningSession> learning_session_; -}; - -TEST_F(MojoLearningSessionTest, ExampleIsCopied) { - LabelledExample example({FeatureValue(123), FeatureValue(456)}, - TargetValue(1234)); - learning_session_->AddExample("unused task id", example); - learning_session_binding_.FlushForTesting(); - EXPECT_EQ(fake_learning_session_impl_.example_, example); -} - -} // namespace learning -} // namespace media diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc new file mode 100644 index 00000000000..f4648b38bcb --- /dev/null +++ b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc @@ -0,0 +1,39 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "media/learning/mojo/public/cpp/mojo_learning_task_controller.h" + +#include <utility> + +#include "mojo/public/cpp/bindings/binding.h" + +namespace media { +namespace learning { + +MojoLearningTaskController::MojoLearningTaskController( + mojom::LearningTaskControllerPtr controller_ptr) + : controller_ptr_(std::move(controller_ptr)) {} + +MojoLearningTaskController::~MojoLearningTaskController() = default; + +void MojoLearningTaskController::BeginObservation( + base::UnguessableToken id, + const FeatureVector& features) { + // We don't need to keep track of in-flight observations, since the service + // side handles it for us. + controller_ptr_->BeginObservation(id, features); +} + +void MojoLearningTaskController::CompleteObservation( + base::UnguessableToken id, + const ObservationCompletion& completion) { + controller_ptr_->CompleteObservation(id, completion); +} + +void MojoLearningTaskController::CancelObservation(base::UnguessableToken id) { + controller_ptr_->CancelObservation(id); +} + +} // namespace learning +} // namespace media diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h new file mode 100644 index 00000000000..893d1f0890e --- /dev/null +++ b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h @@ -0,0 +1,42 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MEDIA_LEARNING_MOJO_PUBLIC_CPP_MOJO_LEARNING_TASK_CONTROLLER_H_ +#define MEDIA_LEARNING_MOJO_PUBLIC_CPP_MOJO_LEARNING_TASK_CONTROLLER_H_ + +#include <utility> + +#include "base/component_export.h" +#include "base/macros.h" +#include "media/learning/common/learning_task_controller.h" +#include "media/learning/mojo/public/mojom/learning_task_controller.mojom.h" + +namespace media { +namespace learning { + +// LearningTaskController implementation to forward to a remote impl via mojo. +class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskController + : public LearningTaskController { + public: + explicit MojoLearningTaskController( + mojom::LearningTaskControllerPtr controller_ptr); + ~MojoLearningTaskController() override; + + // LearningTaskController + void BeginObservation(base::UnguessableToken id, + const FeatureVector& features) override; + void CompleteObservation(base::UnguessableToken id, + const ObservationCompletion& completion) override; + void CancelObservation(base::UnguessableToken id) override; + + private: + mojom::LearningTaskControllerPtr controller_ptr_; + + DISALLOW_COPY_AND_ASSIGN(MojoLearningTaskController); +}; + +} // namespace learning +} // namespace media + +#endif // MEDIA_LEARNING_MOJO_PUBLIC_CPP_MOJO_LEARNING_TASK_CONTROLLER_H_ diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc new file mode 100644 index 00000000000..546aae171b4 --- /dev/null +++ b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc @@ -0,0 +1,109 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <memory> +#include <utility> + +#include "base/bind.h" +#include "base/macros.h" +#include "base/memory/ptr_util.h" +#include "base/test/scoped_task_environment.h" +#include "base/threading/thread.h" +#include "media/learning/mojo/public/cpp/mojo_learning_task_controller.h" +#include "mojo/public/cpp/bindings/binding.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace media { +namespace learning { + +class MojoLearningTaskControllerTest : public ::testing::Test { + public: + // Impl of a mojom::LearningTaskController that remembers call arguments. + class FakeMojoLearningTaskController : public mojom::LearningTaskController { + public: + void BeginObservation(const base::UnguessableToken& id, + const FeatureVector& features) override { + begin_args_.id_ = id; + begin_args_.features_ = features; + } + + void CompleteObservation(const base::UnguessableToken& id, + const ObservationCompletion& completion) override { + complete_args_.id_ = id; + complete_args_.completion_ = completion; + } + + void CancelObservation(const base::UnguessableToken& id) override { + cancel_args_.id_ = id; + } + + struct { + base::UnguessableToken id_; + FeatureVector features_; + } begin_args_; + + struct { + base::UnguessableToken id_; + ObservationCompletion completion_; + } complete_args_; + + struct { + base::UnguessableToken id_; + } cancel_args_; + }; + + public: + MojoLearningTaskControllerTest() + : learning_controller_binding_(&fake_learning_controller_) {} + ~MojoLearningTaskControllerTest() override = default; + + void SetUp() override { + // Create a fake learner provider mojo impl. + mojom::LearningTaskControllerPtr learning_controller_ptr; + learning_controller_binding_.Bind( + mojo::MakeRequest(&learning_controller_ptr)); + + // Tell |learning_controller_| to forward to the fake learner impl. + learning_controller_ = std::make_unique<MojoLearningTaskController>( + std::move(learning_controller_ptr)); + } + + // Mojo stuff. + base::test::ScopedTaskEnvironment scoped_task_environment_; + + FakeMojoLearningTaskController fake_learning_controller_; + mojo::Binding<mojom::LearningTaskController> learning_controller_binding_; + + // The learner under test. + std::unique_ptr<MojoLearningTaskController> learning_controller_; +}; + +TEST_F(MojoLearningTaskControllerTest, Begin) { + base::UnguessableToken id = base::UnguessableToken::Create(); + FeatureVector features = {FeatureValue(123), FeatureValue(456)}; + learning_controller_->BeginObservation(id, features); + scoped_task_environment_.RunUntilIdle(); + EXPECT_EQ(id, fake_learning_controller_.begin_args_.id_); + EXPECT_EQ(features, fake_learning_controller_.begin_args_.features_); +} + +TEST_F(MojoLearningTaskControllerTest, Complete) { + base::UnguessableToken id = base::UnguessableToken::Create(); + ObservationCompletion completion(TargetValue(1234)); + learning_controller_->CompleteObservation(id, completion); + scoped_task_environment_.RunUntilIdle(); + EXPECT_EQ(id, fake_learning_controller_.complete_args_.id_); + EXPECT_EQ(completion.target_value, + fake_learning_controller_.complete_args_.completion_.target_value); +} + +TEST_F(MojoLearningTaskControllerTest, Cancel) { + base::UnguessableToken id = base::UnguessableToken::Create(); + learning_controller_->CancelObservation(id); + scoped_task_environment_.RunUntilIdle(); + EXPECT_EQ(id, fake_learning_controller_.cancel_args_.id_); +} + +} // namespace learning +} // namespace media diff --git a/chromium/media/learning/mojo/public/mojom/BUILD.gn b/chromium/media/learning/mojo/public/mojom/BUILD.gn index e0292358725..3d0a1f565b1 100644 --- a/chromium/media/learning/mojo/public/mojom/BUILD.gn +++ b/chromium/media/learning/mojo/public/mojom/BUILD.gn @@ -7,7 +7,7 @@ import("//mojo/public/tools/bindings/mojom.gni") mojom("mojom") { sources = [ - "learning_session.mojom", + "learning_task_controller.mojom", "learning_types.mojom", ] diff --git a/chromium/media/learning/mojo/public/mojom/learning_session.mojom b/chromium/media/learning/mojo/public/mojom/learning_session.mojom deleted file mode 100644 index f7a2b1d7b3f..00000000000 --- a/chromium/media/learning/mojo/public/mojom/learning_session.mojom +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -module media.learning.mojom; - -import "media/learning/mojo/public/mojom/learning_types.mojom"; - -// Learning tasks, to prevent sending the task name string in AcquireLearner. -enum LearningTaskType { - // There are no tasks yet. - kPlaceHolderTask, -}; - -// media/learning/public/learning_session.h -interface LearningSession { - // Add |example| to |task_type|. - AddExample(LearningTaskType task_type, LabelledExample example); -}; diff --git a/chromium/media/learning/mojo/public/mojom/learning_task_controller.mojom b/chromium/media/learning/mojo/public/mojom/learning_task_controller.mojom new file mode 100644 index 00000000000..8fb3135d438 --- /dev/null +++ b/chromium/media/learning/mojo/public/mojom/learning_task_controller.mojom @@ -0,0 +1,36 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +module media.learning.mojom; + +import "mojo/public/mojom/base/unguessable_token.mojom"; +import "media/learning/mojo/public/mojom/learning_types.mojom"; + +// Client for a single learning task. Intended to be the primary API for client +// code that generates FeatureVectors / requests predictions for a single task. +// The API supports sending in an observed FeatureVector without a target value, +// so that framework-provided features (FeatureProvider) can be snapshotted at +// the right time. One doesn't generally want to wait until the TargetValue is +// observed to do that. +// +// Typically, this interface will allow non-browser processes to communicate +// with the learning framework in the browser. +interface LearningTaskController { + // Start a new observation. Call this at the time one would try to predict + // the TargetValue. This lets the framework snapshot any framework-provided + // feature values at prediction time. Later, if you want to turn these + // features into an example for training a model, then call + // CompleteObservation with the same id and an ObservationCompletion. + // Otherwise, call CancelObservation with |id|. It's also okay to destroy the + // controller with outstanding observations; these will be cancelled. + BeginObservation(mojo_base.mojom.UnguessableToken id, + array<FeatureValue> features); + + // Complete observation |id| by providing |completion|. + CompleteObservation(mojo_base.mojom.UnguessableToken id, + ObservationCompletion completion); + + // Cancel observation |id|. Deleting |this| will do the same. + CancelObservation(mojo_base.mojom.UnguessableToken id); +}; diff --git a/chromium/media/learning/mojo/public/mojom/learning_types.mojom b/chromium/media/learning/mojo/public/mojom/learning_types.mojom index 9a51bd970c5..cc469185203 100644 --- a/chromium/media/learning/mojo/public/mojom/learning_types.mojom +++ b/chromium/media/learning/mojo/public/mojom/learning_types.mojom @@ -19,3 +19,9 @@ struct LabelledExample { array<FeatureValue> features; TargetValue target_value; }; + +// learning::ObservationCompletion (common/learning_task_controller.h) +struct ObservationCompletion { + TargetValue target_value; + uint64 weight = 1; +}; diff --git a/chromium/media/learning/mojo/public/mojom/learning_types.typemap b/chromium/media/learning/mojo/public/mojom/learning_types.typemap index 4e6a27b67dd..beaf1467335 100644 --- a/chromium/media/learning/mojo/public/mojom/learning_types.typemap +++ b/chromium/media/learning/mojo/public/mojom/learning_types.typemap @@ -1,6 +1,7 @@ mojom = "//media/learning/mojo/public/mojom/learning_types.mojom" public_headers = [ "//media/learning/common/labelled_example.h", + "//media/learning/common/learning_task_controller.h", "//media/learning/common/value.h", ] traits_headers = [ "//media/learning/mojo/public/cpp/learning_mojom_traits.h" ] @@ -15,4 +16,5 @@ type_mappings = [ "media.learning.mojom.LabelledExample=media::learning::LabelledExample", "media.learning.mojom.FeatureValue=media::learning::FeatureValue", "media.learning.mojom.TargetValue=media::learning::TargetValue", + "media.learning.mojom.ObservationCompletion=media::learning::ObservationCompletion", ] |