summaryrefslogtreecommitdiff
path: root/chromium/media/learning
diff options
context:
space:
mode:
authorAllan Sandfeld Jensen <allan.jensen@qt.io>2019-05-24 11:40:17 +0200
committerAllan Sandfeld Jensen <allan.jensen@qt.io>2019-05-24 12:42:11 +0000
commit5d87695f37678f96492b258bbab36486c59866b4 (patch)
treebe9783bbaf04fb930c4d74ca9c00b5e7954c8bc6 /chromium/media/learning
parent6c11fb357ec39bf087b8b632e2b1e375aef1b38b (diff)
downloadqtwebengine-chromium-5d87695f37678f96492b258bbab36486c59866b4.tar.gz
BASELINE: Update Chromium to 75.0.3770.56
Change-Id: I86d2007fd27a45d5797eee06f4c9369b8b50ac4f Reviewed-by: Alexandru Croitor <alexandru.croitor@qt.io>
Diffstat (limited to 'chromium/media/learning')
-rw-r--r--chromium/media/learning/common/learning_session.h12
-rw-r--r--chromium/media/learning/common/learning_task.h68
-rw-r--r--chromium/media/learning/common/learning_task_controller.h25
-rw-r--r--chromium/media/learning/common/value.cc2
-rw-r--r--chromium/media/learning/impl/BUILD.gn20
-rw-r--r--chromium/media/learning/impl/distribution_reporter.cc169
-rw-r--r--chromium/media/learning/impl/distribution_reporter.h47
-rw-r--r--chromium/media/learning/impl/distribution_reporter_unittest.cc50
-rw-r--r--chromium/media/learning/impl/extra_trees_trainer_unittest.cc15
-rw-r--r--chromium/media/learning/impl/learning_fuzzertest.cc74
-rw-r--r--chromium/media/learning/impl/learning_session_impl.cc100
-rw-r--r--chromium/media/learning/impl/learning_session_impl.h20
-rw-r--r--chromium/media/learning/impl/learning_session_impl_unittest.cc114
-rw-r--r--chromium/media/learning/impl/learning_task_controller_helper.cc18
-rw-r--r--chromium/media/learning/impl/learning_task_controller_helper.h16
-rw-r--r--chromium/media/learning/impl/learning_task_controller_helper_unittest.cc4
-rw-r--r--chromium/media/learning/impl/learning_task_controller_impl.cc101
-rw-r--r--chromium/media/learning/impl/learning_task_controller_impl.h29
-rw-r--r--chromium/media/learning/impl/learning_task_controller_impl_unittest.cc70
-rw-r--r--chromium/media/learning/impl/lookup_table_trainer.cc7
-rw-r--r--chromium/media/learning/impl/lookup_table_trainer_unittest.cc18
-rw-r--r--chromium/media/learning/impl/model.h6
-rw-r--r--chromium/media/learning/impl/one_hot.cc2
-rw-r--r--chromium/media/learning/impl/one_hot.h3
-rw-r--r--chromium/media/learning/impl/random_tree_trainer.cc69
-rw-r--r--chromium/media/learning/impl/random_tree_trainer.h19
-rw-r--r--chromium/media/learning/impl/random_tree_trainer_unittest.cc30
-rw-r--r--chromium/media/learning/impl/target_distribution.h94
-rw-r--r--chromium/media/learning/impl/target_distribution_unittest.cc164
-rw-r--r--chromium/media/learning/impl/target_histogram.cc (renamed from chromium/media/learning/impl/target_distribution.cc)51
-rw-r--r--chromium/media/learning/impl/target_histogram.h98
-rw-r--r--chromium/media/learning/impl/target_histogram_unittest.cc179
-rw-r--r--chromium/media/learning/impl/voting_ensemble.cc4
-rw-r--r--chromium/media/learning/impl/voting_ensemble.h3
-rw-r--r--chromium/media/learning/mojo/BUILD.gn6
-rw-r--r--chromium/media/learning/mojo/mojo_learning_session_impl.cc31
-rw-r--r--chromium/media/learning/mojo/mojo_learning_session_impl.h53
-rw-r--r--chromium/media/learning/mojo/mojo_learning_session_impl_unittest.cc63
-rw-r--r--chromium/media/learning/mojo/mojo_learning_task_controller_service.cc65
-rw-r--r--chromium/media/learning/mojo/mojo_learning_task_controller_service.h51
-rw-r--r--chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc143
-rw-r--r--chromium/media/learning/mojo/public/cpp/BUILD.gn6
-rw-r--r--chromium/media/learning/mojo/public/cpp/learning_mojom_traits.cc10
-rw-r--r--chromium/media/learning/mojo/public/cpp/learning_mojom_traits.h18
-rw-r--r--chromium/media/learning/mojo/public/cpp/mojo_learning_session.cc24
-rw-r--r--chromium/media/learning/mojo/public/cpp/mojo_learning_session.h36
-rw-r--r--chromium/media/learning/mojo/public/cpp/mojo_learning_session_unittest.cc68
-rw-r--r--chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc39
-rw-r--r--chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h42
-rw-r--r--chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc109
-rw-r--r--chromium/media/learning/mojo/public/mojom/BUILD.gn2
-rw-r--r--chromium/media/learning/mojo/public/mojom/learning_session.mojom19
-rw-r--r--chromium/media/learning/mojo/public/mojom/learning_task_controller.mojom36
-rw-r--r--chromium/media/learning/mojo/public/mojom/learning_types.mojom6
-rw-r--r--chromium/media/learning/mojo/public/mojom/learning_types.typemap2
55 files changed, 1692 insertions, 838 deletions
diff --git a/chromium/media/learning/common/learning_session.h b/chromium/media/learning/common/learning_session.h
index 22db890c2c4..f0fb9ba911e 100644
--- a/chromium/media/learning/common/learning_session.h
+++ b/chromium/media/learning/common/learning_session.h
@@ -5,6 +5,7 @@
#ifndef MEDIA_LEARNING_COMMON_LEARNING_SESSION_H_
#define MEDIA_LEARNING_COMMON_LEARNING_SESSION_H_
+#include <memory>
#include <string>
#include "base/component_export.h"
@@ -15,18 +16,17 @@
namespace media {
namespace learning {
+class LearningTaskController;
+
// Interface to provide a Learner given the task name.
class COMPONENT_EXPORT(LEARNING_COMMON) LearningSession {
public:
LearningSession();
virtual ~LearningSession();
- // Add an observed example |example| to the learning task |task_name|.
- // TODO(liberato): Consider making this an enum to match mojo.
- virtual void AddExample(const std::string& task_name,
- const LabelledExample& example) = 0;
-
- // TODO(liberato): Add prediction API.
+ // Return a LearningTaskController for the given task.
+ virtual std::unique_ptr<LearningTaskController> GetController(
+ const std::string& task_name) = 0;
private:
DISALLOW_COPY_AND_ASSIGN(LearningSession);
diff --git a/chromium/media/learning/common/learning_task.h b/chromium/media/learning/common/learning_task.h
index a7ca9b9f05b..ed5f4f76ac3 100644
--- a/chromium/media/learning/common/learning_task.h
+++ b/chromium/media/learning/common/learning_task.h
@@ -10,6 +10,7 @@
#include <vector>
#include "base/component_export.h"
+#include "base/optional.h"
#include "media/learning/common/value.h"
namespace media {
@@ -29,6 +30,9 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask {
enum class Model {
kExtraTrees,
kLookupTable,
+
+ // For the fuzzer.
+ kMaxValue = kLookupTable
};
enum class Ordering {
@@ -43,6 +47,9 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask {
// ints that represent the number of elapsed milliseconds are numerically
// ordered in a meaningful way.
kNumeric,
+
+ // For the fuzzer.
+ kMaxValue = kNumeric
};
enum class PrivacyMode {
@@ -52,6 +59,9 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask {
// Value does not represent private information, such as video width.
kPublic,
+
+ // For the fuzzer.
+ kMaxValue = kPublic
};
// Description of how a Value should be interpreted.
@@ -97,9 +107,11 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask {
// we currently have, which might be less than |max_data_set_size|.
double min_new_data_fraction = 0.1;
- // If set, then we'll record a confusion matrix hackily to UMA using this as
- // the histogram name.
- std::string uma_hacky_confusion_matrix;
+ // If provided, then we'll randomly select a |*feature_subset_size|-sized set
+ // of features to train the model with, to allow for feature importance
+ // measurement. Note that UMA reporting only supports subsets of size one, or
+ // the whole set.
+ base::Optional<int> feature_subset_size;
// RandomForest parameters
@@ -120,7 +132,57 @@ struct COMPONENT_EXPORT(LEARNING_COMMON) LearningTask {
//
// In particular, if the percentage of dropped frames is greater than this,
// then report "false" (not smooth), else we report true.
+ //
+ // A better, non-hacky approach would be to report the predictions and
+ // observations directly, and do offline analysis with whatever threshold we
+ // like. This would remove the thresholding requirement, and also permit
+ // additional types of analysis for general regression tasks, such as measuring
+ // the prediction error directly.
+ //
+ // The UKM reporter will support this.
double smoothness_threshold = 0.1;
+
+ // If set, then we'll record a confusion matrix (hackily, see
+ // |smoothness_threshold|, above, for what that means) to UMA for all
+ // predictions. Add this task's name to histograms.xml, in the histogram
+ // suffixes for "Media.Learning.BinaryThreshold.Aggregate". The threshold is
+ // chosen by |smoothness_threshold|.
+ //
+ // This option is ignored if feature subset selection is in use.
+ bool uma_hacky_aggregate_confusion_matrix = false;
+
+ // If set, then we'll record a histogram of many confusion matrices, split out
+ // by the total training data weight that was used to construct the model. Be
+ // sure to add this task's name to histograms.xml, in the histogram suffixes
+ // for "Media.Learning.BinaryThreshold.ByTrainingWeight". The threshold is
+ // chosen by |smoothness_threshold|.
+ //
+ // This option is ignored if feature subset selection is in use.
+ bool uma_hacky_by_training_weight_confusion_matrix = false;
+
+ // If set, then we'll record a histogram of many confusion matrices, split out
+ // by the (single) selected feature subset. This does nothing if we're not
+ // using feature subsets, or if the subset size isn't one. Be sure to add
+ // this task's name to histograms.xml, in the histogram suffixes for
+ // "Media.Learning.BinaryThreshold.ByFeature" too.
+ bool uma_hacky_by_feature_subset_confusion_matrix = false;
+
+ // Maximum training weight for UMA reporting. We'll report results offset
+ // into different confusion matrices in the same histogram, evenly spaced
+ // from 0 to |max_reporting_weight|, with one additional bucket for everything
+ // larger than that. The number of buckets is |num_reporting_weight_buckets|.
+ double max_reporting_weight = 99.;
+
+ // Number of buckets that we'll use to split out the confusion matrix by
+ // training weight. The last one is reserved for overflow, while the others are
+ // split evenly from 0 to |max_reporting_weight|, inclusive. One can select
+ // up to 15 buckets. We use 11 by default, so it breaks up the default weight
+ // into buckets of size 10.
+ //
+ // In other words, the defaults will make these buckets:
+ // [0-9] [10-19] ... [90-99] [100 and up]. This makes sense if the training
+ // set maximum size is the default of 100, and each example has a weight of 1.
+ int num_reporting_weight_buckets = 11;
};
} // namespace learning
diff --git a/chromium/media/learning/common/learning_task_controller.h b/chromium/media/learning/common/learning_task_controller.h
index 65cd91f6d44..1e224bde59e 100644
--- a/chromium/media/learning/common/learning_task_controller.h
+++ b/chromium/media/learning/common/learning_task_controller.h
@@ -8,6 +8,7 @@
#include "base/callback.h"
#include "base/component_export.h"
#include "base/macros.h"
+#include "base/unguessable_token.h"
#include "media/learning/common/labelled_example.h"
#include "media/learning/common/learning_task.h"
@@ -40,31 +41,27 @@ class COMPONENT_EXPORT(LEARNING_COMMON) LearningTaskController {
LearningTaskController() = default;
virtual ~LearningTaskController() = default;
- // TODO(liberato): what is the scope of this id? can it be local to whoever
- // owns the LTC? otherwise, consider making it an unguessable token.
- // TODO(liberato): consider making a special id that means "i will not send a
- // target value", to save a call to CancelObservation.
- using ObservationId = int32_t;
-
- // new example. Call this at the time one would try to predict the
- // TargetValue. This lets the framework snapshot any framework-provided
+ // Start a new observation. Call this at the time one would try to predict
+ // the TargetValue. This lets the framework snapshot any framework-provided
// feature values at prediction time. Later, if you want to turn these
- // features into an example for training a model, then call the returned CB
- // with the TargetValue and weight. Otherwise, you may discard the CB.
+ // features into an example for training a model, then call
+ // CompleteObservation with the same id and an ObservationCompletion.
+ // Otherwise, call CancelObservation with |id|. It's also okay to destroy the
+ // controller with outstanding observations; these will be cancelled.
// TODO(liberato): This should optionally take a callback to receive a
// prediction for the FeatureVector.
// TODO(liberato): See if this ends up generating smaller code with pass-by-
// value or with |FeatureVector&&|, once we have callers that can actually
// benefit from it.
- virtual void BeginObservation(ObservationId id,
+ virtual void BeginObservation(base::UnguessableToken id,
const FeatureVector& features) = 0;
// Complete an observation by sending a completion.
- virtual void CompleteObservation(ObservationId id,
+ virtual void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) = 0;
- // Notify the LTC that no completion will be sent.
- virtual void CancelObservation(ObservationId id) = 0;
+ // Notify the LearningTaskController that no completion will be sent.
+ virtual void CancelObservation(base::UnguessableToken id) = 0;
private:
DISALLOW_COPY_AND_ASSIGN(LearningTaskController);
diff --git a/chromium/media/learning/common/value.cc b/chromium/media/learning/common/value.cc
index 12ea399d24c..caf94e0eb0e 100644
--- a/chromium/media/learning/common/value.cc
+++ b/chromium/media/learning/common/value.cc
@@ -4,7 +4,7 @@
#include "media/learning/common/value.h"
-#include "base/hash.h"
+#include "base/hash/hash.h"
namespace media {
namespace learning {
diff --git a/chromium/media/learning/impl/BUILD.gn b/chromium/media/learning/impl/BUILD.gn
index 638014f0e39..27ed8c2dd8b 100644
--- a/chromium/media/learning/impl/BUILD.gn
+++ b/chromium/media/learning/impl/BUILD.gn
@@ -2,10 +2,13 @@
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
+import("//testing/libfuzzer/fuzzer_test.gni")
+
component("impl") {
output_name = "learning_impl"
visibility = [
"//media/learning/impl:unit_tests",
+ "//media/learning/impl:learning_fuzzer",
# Actual clients.
"//content/browser",
@@ -35,8 +38,8 @@ component("impl") {
"random_number_generator.h",
"random_tree_trainer.cc",
"random_tree_trainer.h",
- "target_distribution.cc",
- "target_distribution.h",
+ "target_histogram.cc",
+ "target_histogram.h",
"training_algorithm.h",
"voting_ensemble.cc",
"voting_ensemble.h",
@@ -69,7 +72,7 @@ source_set("unit_tests") {
"one_hot_unittest.cc",
"random_number_generator_unittest.cc",
"random_tree_trainer_unittest.cc",
- "target_distribution_unittest.cc",
+ "target_histogram_unittest.cc",
"test_random_number_generator.cc",
"test_random_number_generator.h",
]
@@ -82,3 +85,14 @@ source_set("unit_tests") {
"//testing/gtest",
]
}
+
+fuzzer_test("learning_fuzzer") {
+ sources = [
+ "learning_fuzzertest.cc",
+ ]
+ deps = [
+ ":impl",
+ "//base",
+ "//base/test:test_support",
+ ]
+}
diff --git a/chromium/media/learning/impl/distribution_reporter.cc b/chromium/media/learning/impl/distribution_reporter.cc
index e657ed3d9a2..a42f2c504a9 100644
--- a/chromium/media/learning/impl/distribution_reporter.cc
+++ b/chromium/media/learning/impl/distribution_reporter.cc
@@ -10,70 +10,170 @@
namespace media {
namespace learning {
-// Low order bit is "observed", second bit is "predicted", third bit is "could
-// not make a prediction".
+// UMA histogram base names.
+static const char* kAggregateBase = "Media.Learning.BinaryThreshold.Aggregate.";
+static const char* kByTrainingWeightBase =
+ "Media.Learning.BinaryThreshold.ByTrainingWeight.";
+static const char* kByFeatureBase = "Media.Learning.BinaryThreshold.ByFeature.";
+
+enum /* not class */ Bits {
+ // These are meant to be bitwise-or'd together, so both false cases just mean
+ // "don't set any bits".
+ PredictedFalse = 0x00,
+ ObservedFalse = 0x00,
+ ObservedTrue = 0x01,
+ PredictedTrue = 0x02,
+ // Special value to mean that no prediction was made.
+ PredictedNothing = 0x04,
+};
+
+// Low order bit is "observed", second bit is "predicted", third bit is
+// "could not make a prediction".
enum class ConfusionMatrix {
- TrueNegative = 0, // predicted == observed == false
- FalseNegative = 1, // predicted == false, observed == true
- FalsePositive = 2, // predicted == true, observed == false
- TruePositive = 3, // predicted == observed == true
- SkippedNegative = 4, // predicted == N/A, observed == false
- SkippedPositive = 5, // predicted == N/A, observed == true
+ TrueNegative = Bits::PredictedFalse | Bits::ObservedFalse,
+ FalseNegative = Bits::PredictedFalse | Bits::ObservedTrue,
+ FalsePositive = Bits::PredictedTrue | Bits::ObservedFalse,
+ TruePositive = Bits::PredictedTrue | Bits::ObservedTrue,
+ SkippedNegative = Bits::PredictedNothing | Bits::ObservedFalse,
+ SkippedPositive = Bits::PredictedNothing | Bits::ObservedTrue,
kMaxValue = SkippedPositive
};
// TODO(liberato): Currently, this implementation is a hack to collect some
// sanity-checking data for local learning with MediaCapabilities. We assume
// that the prediction is the "percentage of dropped frames".
-//
-// Please see https://chromium-review.googlesource.com/c/chromium/src/+/1385107
-// for an actual UKM-based implementation.
-class RegressionReporter : public DistributionReporter {
+class UmaRegressionReporter : public DistributionReporter {
public:
- RegressionReporter(const LearningTask& task) : DistributionReporter(task) {}
+ UmaRegressionReporter(const LearningTask& task)
+ : DistributionReporter(task) {}
- void OnPrediction(TargetDistribution observed,
- TargetDistribution predicted) override {
+ void OnPrediction(const PredictionInfo& info,
+ TargetHistogram predicted) override {
DCHECK_EQ(task().target_description.ordering,
LearningTask::Ordering::kNumeric);
- DCHECK(!task().uma_hacky_confusion_matrix.empty());
// As a complete hack, record accuracy with a fixed threshold. The average
// is the observed / predicted percentage of dropped frames.
- bool observed_smooth = observed.Average() <= task().smoothness_threshold;
+ bool observed_smooth = info.observed.value() <= task().smoothness_threshold;
// See if we made a prediction.
- int predicted_bits = 4; // N/A
+ int prediction_bit_mask = Bits::PredictedNothing;
if (predicted.total_counts() != 0) {
bool predicted_smooth =
predicted.Average() <= task().smoothness_threshold;
DVLOG(2) << "Learning: " << task().name
<< ": predicted: " << predicted_smooth << " ("
<< predicted.Average() << ") observed: " << observed_smooth
- << " (" << observed.Average() << ")";
- predicted_bits = predicted_smooth ? 2 : 0;
+ << " (" << info.observed.value() << ")";
+ prediction_bit_mask =
+ predicted_smooth ? Bits::PredictedTrue : Bits::PredictedFalse;
} else {
DVLOG(2) << "Learning: " << task().name
<< ": predicted: N/A observed: " << observed_smooth << " ("
- << observed.Average() << ")";
+ << info.observed.value() << ")";
}
- // Convert to a bucket from which we can get the confusion matrix.
- ConfusionMatrix uma_bucket = static_cast<ConfusionMatrix>(
- (observed_smooth ? 1 : 0) | predicted_bits);
- base::UmaHistogramEnumeration(task().uma_hacky_confusion_matrix,
- uma_bucket);
+ // Figure out the ConfusionMatrix enum value.
+ ConfusionMatrix confusion_matrix_value = static_cast<ConfusionMatrix>(
+ (observed_smooth ? Bits::ObservedTrue : Bits::ObservedFalse) |
+ prediction_bit_mask);
+
+ // |uma_bucket_number| is the bucket number that we'll fill in with this
+ // count. It ranges from 0 to |max_buckets-1|, inclusive. Each bucket is
+ // separated from the start of the previous bucket by |uma_bucket_size|.
+ int uma_bucket_number = 0;
+ constexpr int matrix_size =
+ static_cast<int>(ConfusionMatrix::kMaxValue) + 1;
+
+ // The enums.xml entries separate the buckets by 10, to make it easy to see
+ // by inspection what bucket number we're in (e.g., x-axis position 23 is
+ // bucket 2 * 10 + PredictedTrue|ObservedTrue). The label in enums.xml for
+ // MegaConfusionMatrix also provides the bucket number for easy reading.
+ constexpr int uma_bucket_size = 10;
+ DCHECK_LE(matrix_size, uma_bucket_size);
+
+ // Maximum number of buckets defined in enums.xml, numbered from 0.
+ constexpr int max_buckets = 16;
+
+ // Sparse histograms can technically go past 100 exactly-stored elements,
+ // but we limit it anyway. Note that we don't care about |uma_bucket_size|,
+ // since it's a sparse histogram. Only |matrix_size| elements are used in
+ // each bucket.
+ DCHECK_LE(max_buckets * matrix_size, 100);
+
+ // If we're splitting by feature, then record it and stop. The others
+ // aren't meaningful to record if we're using random feature subsets.
+ if (task().uma_hacky_by_feature_subset_confusion_matrix &&
+ feature_indices() && feature_indices()->size() == 1) {
+ // The bucket number is just the feature number that was selected.
+ uma_bucket_number =
+ std::min(*feature_indices()->begin(), max_buckets - 1);
+
+ std::string base(kByFeatureBase);
+ base::UmaHistogramSparse(base + task().name,
+ static_cast<int>(confusion_matrix_value) +
+ uma_bucket_number * uma_bucket_size);
+
+ // Early return since no other measurements are meaningful when we're
+ // using feature subsets.
+ return;
+ }
+
+ // If we're selecting a feature subset that's bigger than one but smaller
+ // than all of them, then we don't know how to report that.
+ if (feature_indices() &&
+ feature_indices()->size() != task().feature_descriptions.size()) {
+ return;
+ }
+
+ // Do normal reporting.
+
+ // Record the aggregate confusion matrix.
+ if (task().uma_hacky_aggregate_confusion_matrix) {
+ std::string base(kAggregateBase);
+ base::UmaHistogramEnumeration(base + task().name, confusion_matrix_value);
+ }
+
+ if (task().uma_hacky_by_training_weight_confusion_matrix) {
+ // Adjust |uma_bucket_number| by the training weight, and store the
+ // results in that bucket in the ByTrainingWeight histogram.
+ //
+ // This will bucket from 0 in even steps, with the last bucket holding
+ // |max_reporting_weight+1| and everything above it.
+
+ const int n_buckets = task().num_reporting_weight_buckets;
+ DCHECK_LE(n_buckets, max_buckets);
+
+ // We use one fewer buckets, to save one for the overflow. Buckets are
+ // numbered from 0 to |n_buckets-1|, inclusive. In other words, when the
+ // training weight is equal to |max_reporting_weight|, we still want to
+ // be in bucket |n_buckets - 2|. That's why we add one to the max before
+ // we divide; only things over the max go into the last bucket.
+ uma_bucket_number =
+ std::min<int>((n_buckets - 1) * info.total_training_weight /
+ (task().max_reporting_weight + 1),
+ n_buckets - 1);
+
+ std::string base(kByTrainingWeightBase);
+ base::UmaHistogramSparse(base + task().name,
+ static_cast<int>(confusion_matrix_value) +
+ uma_bucket_number * uma_bucket_size);
+ }
}
};
std::unique_ptr<DistributionReporter> DistributionReporter::Create(
const LearningTask& task) {
- // Hacky reporting is the only thing we know how to report.
- if (task.uma_hacky_confusion_matrix.empty())
+ // We only know how to report regression tasks right now.
+ if (task.target_description.ordering != LearningTask::Ordering::kNumeric)
return nullptr;
- if (task.target_description.ordering == LearningTask::Ordering::kNumeric)
- return std::make_unique<RegressionReporter>(task);
+ if (task.uma_hacky_aggregate_confusion_matrix ||
+ task.uma_hacky_by_training_weight_confusion_matrix ||
+ task.uma_hacky_by_feature_subset_confusion_matrix) {
+ return std::make_unique<UmaRegressionReporter>(task);
+ }
+
return nullptr;
}
@@ -83,9 +183,14 @@ DistributionReporter::DistributionReporter(const LearningTask& task)
DistributionReporter::~DistributionReporter() = default;
Model::PredictionCB DistributionReporter::GetPredictionCallback(
- TargetDistribution observed) {
+ const PredictionInfo& info) {
return base::BindOnce(&DistributionReporter::OnPrediction,
- weak_factory_.GetWeakPtr(), observed);
+ weak_factory_.GetWeakPtr(), info);
+}
+
+void DistributionReporter::SetFeatureSubset(
+ const std::set<int>& feature_indices) {
+ feature_indices_ = feature_indices;
}
} // namespace learning
diff --git a/chromium/media/learning/impl/distribution_reporter.h b/chromium/media/learning/impl/distribution_reporter.h
index 78b22e65c93..99310aa0ed8 100644
--- a/chromium/media/learning/impl/distribution_reporter.h
+++ b/chromium/media/learning/impl/distribution_reporter.h
@@ -5,13 +5,16 @@
#ifndef MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_
#define MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_
+#include <set>
+
#include "base/callback.h"
#include "base/component_export.h"
#include "base/macros.h"
#include "base/memory/weak_ptr.h"
+#include "base/optional.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/impl/model.h"
-#include "media/learning/impl/target_distribution.h"
+#include "media/learning/impl/target_histogram.h"
namespace media {
namespace learning {
@@ -21,15 +24,39 @@ namespace learning {
// specific learning task.
class COMPONENT_EXPORT(LEARNING_IMPL) DistributionReporter {
public:
+ // Extra information provided to the reporter for each prediction.
+ struct PredictionInfo {
+ // What value was observed?
+ TargetValue observed;
+
+ // Total weight of the training data used to create this model.
+ double total_training_weight = 0.;
+
+ // Total number of examples (unweighted) in the training set.
+ size_t total_training_examples = 0u;
+
+ // TODO(liberato): Move the feature subset here.
+ };
+
// Create a DistributionReporter that's suitable for |task|.
static std::unique_ptr<DistributionReporter> Create(const LearningTask& task);
virtual ~DistributionReporter();
- // Returns a prediction CB that will be compared to |observed|. |observed| is
- // the total number of counts that we observed.
+ // Returns a prediction CB that will be compared to |prediction_info.observed|.
+ // TODO(liberato): This is too complicated. Skip the callback and just call
+ // us with the predicted value.
virtual Model::PredictionCB GetPredictionCallback(
- TargetDistribution observed);
+ const PredictionInfo& prediction_info);
+
+ // Set the subset of features that is being used to train the model. This is
+ // used for feature importance measurements.
+ //
+ // For example, sending in the set [0, 3, 7] would indicate that the model was
+ // trained with task().feature_descriptions[0, 3, 7] only.
+ //
+ // Note that UMA reporting only supports single feature subsets.
+ void SetFeatureSubset(const std::set<int>& feature_indices);
protected:
DistributionReporter(const LearningTask& task);
@@ -37,12 +64,20 @@ class COMPONENT_EXPORT(LEARNING_IMPL) DistributionReporter {
const LearningTask& task() const { return task_; }
// Implemented by subclasses to report a prediction.
- virtual void OnPrediction(TargetDistribution observed,
- TargetDistribution predicted) = 0;
+ virtual void OnPrediction(const PredictionInfo& prediction_info,
+ TargetHistogram predicted) = 0;
+
+ const base::Optional<std::set<int>>& feature_indices() const {
+ return feature_indices_;
+ }
private:
LearningTask task_;
+ // If provided, then these are the features that are used to train the model.
+ // Otherwise, we assume that all features are used.
+ base::Optional<std::set<int>> feature_indices_;
+
base::WeakPtrFactory<DistributionReporter> weak_factory_;
DISALLOW_COPY_AND_ASSIGN(DistributionReporter);
diff --git a/chromium/media/learning/impl/distribution_reporter_unittest.cc b/chromium/media/learning/impl/distribution_reporter_unittest.cc
index 3815d5c4842..43679be9a14 100644
--- a/chromium/media/learning/impl/distribution_reporter_unittest.cc
+++ b/chromium/media/learning/impl/distribution_reporter_unittest.cc
@@ -16,6 +16,12 @@ namespace learning {
class DistributionReporterTest : public testing::Test {
public:
+ DistributionReporterTest() {
+ task_.name = "TaskName";
+ // UMA reporting requires a numeric target.
+ task_.target_description.ordering = LearningTask::Ordering::kNumeric;
+ }
+
base::test::ScopedTaskEnvironment scoped_task_environment_;
LearningTask task_;
@@ -25,35 +31,27 @@ class DistributionReporterTest : public testing::Test {
TEST_F(DistributionReporterTest, DistributionReporterDoesNotCrash) {
// Make sure that we request some sort of reporting.
- task_.target_description.ordering = LearningTask::Ordering::kNumeric;
- task_.uma_hacky_confusion_matrix = "test";
+ task_.uma_hacky_aggregate_confusion_matrix = true;
reporter_ = DistributionReporter::Create(task_);
EXPECT_NE(reporter_, nullptr);
+ // Observe an average of 2 / 3.
+ DistributionReporter::PredictionInfo info;
+ info.observed = TargetValue(2.0 / 3.0);
+ auto cb = reporter_->GetPredictionCallback(info);
+
+ TargetHistogram predicted;
const TargetValue Zero(0);
const TargetValue One(1);
- TargetDistribution observed;
- // Observe an average of 2 / 3.
- observed[Zero] = 100;
- observed[One] = 200;
- auto cb = reporter_->GetPredictionCallback(observed);
-
- TargetDistribution predicted;
// Predict an average of 5 / 9.
predicted[Zero] = 40;
predicted[One] = 50;
std::move(cb).Run(predicted);
-
- // TODO(liberato): When we switch to ukm, use a TestUkmRecorder to make sure
- // that it fills in the right stuff.
- // https://chromium-review.googlesource.com/c/chromium/src/+/1385107 .
}
-TEST_F(DistributionReporterTest, DistributionReporterNeedsUmaName) {
+TEST_F(DistributionReporterTest, DistributionReporterMustBeRequested) {
// Make sure that we don't get a reporter if we don't request any reporting.
- task_.target_description.ordering = LearningTask::Ordering::kNumeric;
- task_.uma_hacky_confusion_matrix = "";
reporter_ = DistributionReporter::Create(task_);
EXPECT_EQ(reporter_, nullptr);
}
@@ -62,10 +60,28 @@ TEST_F(DistributionReporterTest,
DistributionReporterHackyConfusionMatrixNeedsRegression) {
// Hacky confusion matrix reporting only works with regression.
task_.target_description.ordering = LearningTask::Ordering::kUnordered;
- task_.uma_hacky_confusion_matrix = "test";
+ task_.uma_hacky_aggregate_confusion_matrix = true;
reporter_ = DistributionReporter::Create(task_);
EXPECT_EQ(reporter_, nullptr);
}
+TEST_F(DistributionReporterTest, ProvidesAggregateReporter) {
+ task_.uma_hacky_aggregate_confusion_matrix = true;
+ reporter_ = DistributionReporter::Create(task_);
+ EXPECT_NE(reporter_, nullptr);
+}
+
+TEST_F(DistributionReporterTest, ProvidesByTrainingWeightReporter) {
+ task_.uma_hacky_by_training_weight_confusion_matrix = true;
+ reporter_ = DistributionReporter::Create(task_);
+ EXPECT_NE(reporter_, nullptr);
+}
+
+TEST_F(DistributionReporterTest, ProvidesByFeatureSubsetReporter) {
+ task_.uma_hacky_by_feature_subset_confusion_matrix = true;
+ reporter_ = DistributionReporter::Create(task_);
+ EXPECT_NE(reporter_, nullptr);
+}
+
} // namespace learning
} // namespace media
diff --git a/chromium/media/learning/impl/extra_trees_trainer_unittest.cc b/chromium/media/learning/impl/extra_trees_trainer_unittest.cc
index ad07000ba87..d9e18970ced 100644
--- a/chromium/media/learning/impl/extra_trees_trainer_unittest.cc
+++ b/chromium/media/learning/impl/extra_trees_trainer_unittest.cc
@@ -55,7 +55,7 @@ TEST_P(ExtraTreesTest, EmptyTrainingDataWorks) {
TrainingData empty;
auto model = Train(task_, empty);
EXPECT_NE(model.get(), nullptr);
- EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetDistribution());
+ EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetHistogram());
}
TEST_P(ExtraTreesTest, FisherIrisDataset) {
@@ -67,8 +67,7 @@ TEST_P(ExtraTreesTest, FisherIrisDataset) {
// Verify predictions on the training set, just for sanity.
size_t num_correct = 0;
for (const LabelledExample& example : training_data) {
- TargetDistribution distribution =
- model->PredictDistribution(example.features);
+ TargetHistogram distribution = model->PredictDistribution(example.features);
TargetValue predicted_value;
if (distribution.FindSingularMax(&predicted_value) &&
predicted_value == example.target_value) {
@@ -102,8 +101,7 @@ TEST_P(ExtraTreesTest, WeightedTrainingSetIsSupported) {
auto model = Train(task_, training_data);
// The singular max should be example_1.
- TargetDistribution distribution =
- model->PredictDistribution(example_1.features);
+ TargetHistogram distribution = model->PredictDistribution(example_1.features);
TargetValue predicted_value;
EXPECT_TRUE(distribution.FindSingularMax(&predicted_value));
EXPECT_EQ(predicted_value, example_1.target_value);
@@ -135,8 +133,7 @@ TEST_P(ExtraTreesTest, RegressionWorks) {
auto model = Train(task_, training_data);
// Make sure that the results are in the right range.
- TargetDistribution distribution =
- model->PredictDistribution(example_1.features);
+ TargetHistogram distribution = model->PredictDistribution(example_1.features);
EXPECT_GT(distribution.Average(), example_1.target_value.value() * 0.95);
EXPECT_LT(distribution.Average(), example_1.target_value.value() * 1.05);
distribution = model->PredictDistribution(example_2.features);
@@ -194,10 +191,10 @@ TEST_P(ExtraTreesTest, RegressionVsBinaryClassification) {
// the data is separable, it probably should be exact.
for (auto& r_example : r_examples) {
const FeatureVector& fv = r_example.features;
- TargetDistribution c_dist = c_model->PredictDistribution(fv);
+ TargetHistogram c_dist = c_model->PredictDistribution(fv);
EXPECT_LE(c_dist.Average(), r_example.target_value.value() * 1.05);
EXPECT_GE(c_dist.Average(), r_example.target_value.value() * 0.95);
- TargetDistribution r_dist = r_model->PredictDistribution(fv);
+ TargetHistogram r_dist = r_model->PredictDistribution(fv);
EXPECT_LE(r_dist.Average(), r_example.target_value.value() * 1.05);
EXPECT_GE(r_dist.Average(), r_example.target_value.value() * 0.95);
}
diff --git a/chromium/media/learning/impl/learning_fuzzertest.cc b/chromium/media/learning/impl/learning_fuzzertest.cc
new file mode 100644
index 00000000000..f886ffb9bef
--- /dev/null
+++ b/chromium/media/learning/impl/learning_fuzzertest.cc
@@ -0,0 +1,74 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <cstdint>
+#include <cstring>
+#include <vector>
+
+#include "base/test/fuzzed_data_provider.h"
+#include "base/test/scoped_task_environment.h"
+#include "media/learning/impl/learning_task_controller_impl.h"
+
+using media::learning::FeatureValue;
+using media::learning::FeatureVector;
+using media::learning::LearningTask;
+using ValueDescription = media::learning::LearningTask::ValueDescription;
+using media::learning::LearningTaskControllerImpl;
+using media::learning::ObservationCompletion;
+using media::learning::TargetValue;
+
+// Builds a ValueDescription whose name, ordering and privacy mode are taken,
+// in that order, from the fuzzer input.
+ValueDescription ConsumeValueDescription(base::FuzzedDataProvider* provider) {
+  ValueDescription description;
+  description.name = provider->ConsumeRandomLengthString(100);
+  description.ordering = provider->ConsumeEnum<LearningTask::Ordering>();
+  description.privacy_mode = provider->ConsumeEnum<LearningTask::PrivacyMode>();
+  return description;
+}
+
+// Takes sizeof(double) raw bytes from the fuzzer input and returns them as a
+// double, so arbitrary bit patterns can be produced.  Returns 0 when the input
+// does not have enough bytes left.
+double ConsumeDouble(base::FuzzedDataProvider* provider) {
+  const std::vector<uint8_t> bytes = provider->ConsumeBytes(sizeof(double));
+  if (bytes.size() != sizeof(double))
+    return 0;
+
+  // Copy the bytes instead of casting the buffer to double*: reading a uint8_t
+  // array through a double* violates strict aliasing and assumes the buffer is
+  // suitably aligned for double.
+  double value = 0;
+  std::memcpy(&value, bytes.data(), sizeof(value));
+  return value;
+}
+
+// Produces a FeatureVector whose length is drawn from
+// ConsumeIntegralInRange(0, 100) and whose entries are each built from one
+// ConsumeDouble() call.
+FeatureVector ConsumeFeatureVector(base::FuzzedDataProvider* provider) {
+  FeatureVector result;
+  const int count = provider->ConsumeIntegralInRange(0, 100);
+  for (int remaining = count; remaining > 0; --remaining)
+    result.push_back(FeatureValue(ConsumeDouble(provider)));
+
+  return result;
+}
+
+// Fuzzer entry point.  Builds a LearningTask from the fuzzer input, creates a
+// LearningTaskControllerImpl for it, and then feeds the controller
+// begin/complete observation pairs until the input is exhausted.
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
+  base::test::ScopedTaskEnvironment scoped_task_environment;
+  base::FuzzedDataProvider provider(data, size);
+
+  LearningTask task;
+  task.name = provider.ConsumeRandomLengthString(100);
+  task.model = provider.ConsumeEnum<LearningTask::Model>();
+  task.use_one_hot_conversion = provider.ConsumeBool();
+  task.uma_hacky_aggregate_confusion_matrix = provider.ConsumeBool();
+  task.uma_hacky_by_training_weight_confusion_matrix = provider.ConsumeBool();
+  task.uma_hacky_by_feature_subset_confusion_matrix = provider.ConsumeBool();
+  int n_features = provider.ConsumeIntegralInRange(0, 100);
+  // The subset size is requested from the range [0, n_features]; a value of
+  // zero leaves |feature_subset_size| unset.
+  int subset_size = provider.ConsumeIntegralInRange<uint8_t>(0, n_features);
+  if (subset_size)
+    task.feature_subset_size = subset_size;
+  for (int i = 0; i < n_features; i++)
+    task.feature_descriptions.push_back(ConsumeValueDescription(&provider));
+  task.target_description = ConsumeValueDescription(&provider);
+
+  LearningTaskControllerImpl controller(task);
+
+  // Build random examples.
+  while (provider.remaining_bytes() > 0) {
+    base::UnguessableToken id = base::UnguessableToken::Create();
+    controller.BeginObservation(id, ConsumeFeatureVector(&provider));
+
+    // Consume the target and the weight in separate statements.  As arguments
+    // of a single expression their evaluation order is unspecified, which
+    // would let the same input produce different examples on different
+    // compilers and make crashes harder to reproduce.
+    // NOTE(review): |weight| is passed as a double, as before; confirm how
+    // ObservationCompletion converts out-of-range or NaN weights.
+    const double target = ConsumeDouble(&provider);
+    const double weight = ConsumeDouble(&provider);
+    controller.CompleteObservation(
+        id, ObservationCompletion(TargetValue(target), weight));
+    scoped_task_environment.RunUntilIdle();
+  }
+
+  return 0;
+}
diff --git a/chromium/media/learning/impl/learning_session_impl.cc b/chromium/media/learning/impl/learning_session_impl.cc
index 4ad7de46f74..39d7e30fe6d 100644
--- a/chromium/media/learning/impl/learning_session_impl.cc
+++ b/chromium/media/learning/impl/learning_session_impl.cc
@@ -4,6 +4,7 @@
#include "media/learning/impl/learning_session_impl.h"
+#include <set>
#include <utility>
#include "base/bind.h"
@@ -14,15 +15,74 @@
namespace media {
namespace learning {
-LearningSessionImpl::LearningSessionImpl()
- : controller_factory_(
- base::BindRepeating([](const LearningTask& task,
- SequenceBoundFeatureProvider feature_provider)
- -> std::unique_ptr<LearningTaskController> {
- return std::make_unique<LearningTaskControllerImpl>(
- task, DistributionReporter::Create(task),
+// Allow multiple clients to own an LTC that points to the same underlying LTC.
+// Since we don't own the LTC, we also keep track of in-flight observations and
+// explicitly cancel them on destruction, since dropping an LTC implies that.
+// LearningTaskController handle that forwards every call to a SequenceBound
+// controller it does not own.  Calls are dropped once |weak_session_| has been
+// invalidated.  Observations that were begun through this handle but never
+// completed or cancelled are cancelled when the handle is destroyed.
+class WeakLearningTaskController : public LearningTaskController {
+ public:
+  WeakLearningTaskController(
+      base::WeakPtr<LearningSessionImpl> weak_session,
+      base::SequenceBound<LearningTaskController>* controller)
+      : weak_session_(std::move(weak_session)), controller_(controller) {}
+
+  ~WeakLearningTaskController() override {
+    // Only touch |controller_| while the session is still alive.
+    if (weak_session_) {
+      // Cancel whatever is still in flight.
+      for (auto& pending_id : outstanding_ids_) {
+        controller_->Post(FROM_HERE, &LearningTaskController::CancelObservation,
+                          pending_id);
+      }
+    }
+  }
+
+  // Remembers |id| as outstanding and forwards the request.
+  void BeginObservation(base::UnguessableToken id,
+                        const FeatureVector& features) override {
+    if (weak_session_) {
+      outstanding_ids_.insert(id);
+      controller_->Post(FROM_HERE, &LearningTaskController::BeginObservation,
+                        id, features);
+    }
+  }
+
+  // Forgets |id| and forwards the completion.
+  void CompleteObservation(base::UnguessableToken id,
+                           const ObservationCompletion& completion) override {
+    if (weak_session_) {
+      outstanding_ids_.erase(id);
+      controller_->Post(FROM_HERE, &LearningTaskController::CompleteObservation,
+                        id, completion);
+    }
+  }
+
+  // Forgets |id| and forwards the cancellation.
+  void CancelObservation(base::UnguessableToken id) override {
+    if (weak_session_) {
+      outstanding_ids_.erase(id);
+      controller_->Post(FROM_HERE, &LearningTaskController::CancelObservation,
+                        id);
+    }
+  }
+
+  // Used only to detect whether the session has gone away.
+  base::WeakPtr<LearningSessionImpl> weak_session_;
+
+  // Not owned.
+  base::SequenceBound<LearningTaskController>* controller_;
+
+  // Ids that have been begun but neither completed nor cancelled so far.
+  std::set<base::UnguessableToken> outstanding_ids_;
+};
+
+LearningSessionImpl::LearningSessionImpl(
+ scoped_refptr<base::SequencedTaskRunner> task_runner)
+ : task_runner_(std::move(task_runner)),
+ controller_factory_(base::BindRepeating(
+ [](scoped_refptr<base::SequencedTaskRunner> task_runner,
+ const LearningTask& task,
+ SequenceBoundFeatureProvider feature_provider)
+ -> base::SequenceBound<LearningTaskController> {
+ return base::SequenceBound<LearningTaskControllerImpl>(
+ task_runner, task, DistributionReporter::Create(task),
std::move(feature_provider));
- })) {}
+ })),
+ weak_factory_(this) {}
LearningSessionImpl::~LearningSessionImpl() = default;
@@ -31,25 +91,25 @@ void LearningSessionImpl::SetTaskControllerFactoryCBForTesting(
controller_factory_ = std::move(cb);
}
-void LearningSessionImpl::AddExample(const std::string& task_name,
- const LabelledExample& example) {
+std::unique_ptr<LearningTaskController> LearningSessionImpl::GetController(
+ const std::string& task_name) {
auto iter = task_map_.find(task_name);
- if (iter != task_map_.end()) {
- // TODO(liberato): We shouldn't be adding examples. We should provide the
- // LearningTaskController instead, although ownership gets a bit weird.
- LearningTaskController::ObservationId id = 1;
- iter->second->BeginObservation(id, example.features);
- iter->second->CompleteObservation(
- id, ObservationCompletion(example.target_value, example.weight));
- }
+ if (iter == task_map_.end())
+ return nullptr;
+
+ // If there were any way to replace / destroy a controller other than when we
+ // destroy |this|, then this wouldn't be such a good idea.
+ return std::make_unique<WeakLearningTaskController>(
+ weak_factory_.GetWeakPtr(), &iter->second);
}
void LearningSessionImpl::RegisterTask(
const LearningTask& task,
SequenceBoundFeatureProvider feature_provider) {
DCHECK(task_map_.count(task.name) == 0);
- task_map_.emplace(task.name,
- controller_factory_.Run(task, std::move(feature_provider)));
+ task_map_.emplace(
+ task.name,
+ controller_factory_.Run(task_runner_, task, std::move(feature_provider)));
}
} // namespace learning
diff --git a/chromium/media/learning/impl/learning_session_impl.h b/chromium/media/learning/impl/learning_session_impl.h
index f4693dfe0a4..01402f5d163 100644
--- a/chromium/media/learning/impl/learning_session_impl.h
+++ b/chromium/media/learning/impl/learning_session_impl.h
@@ -8,6 +8,8 @@
#include <map>
#include "base/component_export.h"
+#include "base/memory/weak_ptr.h"
+#include "base/sequenced_task_runner.h"
#include "base/threading/sequence_bound.h"
#include "media/learning/common/learning_session.h"
#include "media/learning/common/learning_task_controller.h"
@@ -21,19 +23,22 @@ namespace learning {
class COMPONENT_EXPORT(LEARNING_IMPL) LearningSessionImpl
: public LearningSession {
public:
- LearningSessionImpl();
+ // We will create LearningTaskControllers that run on |task_runner|.
+ LearningSessionImpl(scoped_refptr<base::SequencedTaskRunner> task_runner);
~LearningSessionImpl() override;
+ // Create a SequenceBound controller for |task| on |task_runner|.
using CreateTaskControllerCB =
- base::RepeatingCallback<std::unique_ptr<LearningTaskController>(
+ base::RepeatingCallback<base::SequenceBound<LearningTaskController>(
+ scoped_refptr<base::SequencedTaskRunner>,
const LearningTask&,
SequenceBoundFeatureProvider)>;
void SetTaskControllerFactoryCBForTesting(CreateTaskControllerCB cb);
// LearningSession
- void AddExample(const std::string& task_name,
- const LabelledExample& example) override;
+ std::unique_ptr<LearningTaskController> GetController(
+ const std::string& task_name) override;
// Registers |task|, so that calls to GetController with |task.name| will work.
// This will create a new controller for the task.
@@ -42,12 +47,17 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningSessionImpl
SequenceBoundFeatureProvider());
private:
+ // Task runner on which we'll create controllers.
+ scoped_refptr<base::SequencedTaskRunner> task_runner_;
+
// [task_name] = task controller.
using LearningTaskMap =
- std::map<std::string, std::unique_ptr<LearningTaskController>>;
+ std::map<std::string, base::SequenceBound<LearningTaskController>>;
LearningTaskMap task_map_;
CreateTaskControllerCB controller_factory_;
+
+ base::WeakPtrFactory<LearningSessionImpl> weak_factory_;
};
} // namespace learning
diff --git a/chromium/media/learning/impl/learning_session_impl_unittest.cc b/chromium/media/learning/impl/learning_session_impl_unittest.cc
index 5c55d66dc6a..2cb878a19d6 100644
--- a/chromium/media/learning/impl/learning_session_impl_unittest.cc
+++ b/chromium/media/learning/impl/learning_session_impl_unittest.cc
@@ -18,11 +18,19 @@ namespace learning {
class LearningSessionImplTest : public testing::Test {
public:
+ class FakeLearningTaskController;
+ using ControllerVector = std::vector<FakeLearningTaskController*>;
+ using TaskRunnerVector = std::vector<base::SequencedTaskRunner*>;
+
class FakeLearningTaskController : public LearningTaskController {
public:
- FakeLearningTaskController(const LearningTask& task,
+ // Send ControllerVector* as void*, else it complains that args can't be
+ // forwarded. Adding base::Unretained() doesn't help.
+ FakeLearningTaskController(void* controllers,
+ const LearningTask& task,
SequenceBoundFeatureProvider feature_provider)
: feature_provider_(std::move(feature_provider)) {
+ static_cast<ControllerVector*>(controllers)->push_back(this);
// As a complete hack, call the only public method on fp so that
// we can verify that it was given to us by the session.
if (!feature_provider_.is_null()) {
@@ -32,13 +40,13 @@ class LearningSessionImplTest : public testing::Test {
}
}
- void BeginObservation(ObservationId id,
+ void BeginObservation(base::UnguessableToken id,
const FeatureVector& features) override {
id_ = id;
features_ = features;
}
- void CompleteObservation(ObservationId id,
+ void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) override {
EXPECT_EQ(id_, id);
example_.features = std::move(features_);
@@ -46,12 +54,17 @@ class LearningSessionImplTest : public testing::Test {
example_.weight = completion.weight;
}
- void CancelObservation(ObservationId id) override { ASSERT_TRUE(false); }
+ void CancelObservation(base::UnguessableToken id) override {
+ cancelled_id_ = id;
+ }
SequenceBoundFeatureProvider feature_provider_;
- ObservationId id_ = 0;
+ base::UnguessableToken id_;
FeatureVector features_;
LabelledExample example_;
+
+ // Most recently cancelled id.
+ base::UnguessableToken cancelled_id_;
};
class FakeFeatureProvider : public FeatureProvider {
@@ -67,20 +80,21 @@ class LearningSessionImplTest : public testing::Test {
bool* flag_ptr_ = nullptr;
};
- using ControllerVector = std::vector<FakeLearningTaskController*>;
-
LearningSessionImplTest() {
- session_ = std::make_unique<LearningSessionImpl>();
+ task_runner_ = base::SequencedTaskRunnerHandle::Get();
+ session_ = std::make_unique<LearningSessionImpl>(task_runner_);
session_->SetTaskControllerFactoryCBForTesting(base::BindRepeating(
- [](ControllerVector* controllers, const LearningTask& task,
+ [](ControllerVector* controllers, TaskRunnerVector* task_runners,
+ scoped_refptr<base::SequencedTaskRunner> task_runner,
+ const LearningTask& task,
SequenceBoundFeatureProvider feature_provider)
- -> std::unique_ptr<LearningTaskController> {
- auto controller = std::make_unique<FakeLearningTaskController>(
- task, std::move(feature_provider));
- controllers->push_back(controller.get());
- return controller;
+ -> base::SequenceBound<LearningTaskController> {
+ task_runners->push_back(task_runner.get());
+ return base::SequenceBound<FakeLearningTaskController>(
+ task_runner, static_cast<void*>(controllers), task,
+ std::move(feature_provider));
},
- &task_controllers_));
+ &task_controllers_, &task_runners_));
task_0_.name = "task_0";
task_1_.name = "task_1";
@@ -88,48 +102,108 @@ class LearningSessionImplTest : public testing::Test {
base::test::ScopedTaskEnvironment scoped_task_environment_;
+ scoped_refptr<base::SequencedTaskRunner> task_runner_;
+
std::unique_ptr<LearningSessionImpl> session_;
LearningTask task_0_;
LearningTask task_1_;
ControllerVector task_controllers_;
+ TaskRunnerVector task_runners_;
};
TEST_F(LearningSessionImplTest, RegisteringTasksCreatesControllers) {
EXPECT_EQ(task_controllers_.size(), 0u);
+ EXPECT_EQ(task_runners_.size(), 0u);
+
session_->RegisterTask(task_0_);
+ scoped_task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_.size(), 1u);
+ EXPECT_EQ(task_runners_.size(), 1u);
+ EXPECT_EQ(task_runners_[0], task_runner_.get());
+
session_->RegisterTask(task_1_);
+ scoped_task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_.size(), 2u);
+ EXPECT_EQ(task_runners_.size(), 2u);
+ EXPECT_EQ(task_runners_[1], task_runner_.get());
}
TEST_F(LearningSessionImplTest, ExamplesAreForwardedToCorrectTask) {
session_->RegisterTask(task_0_);
session_->RegisterTask(task_1_);
+ base::UnguessableToken id = base::UnguessableToken::Create();
+
LabelledExample example_0({FeatureValue(123), FeatureValue(456)},
TargetValue(1234));
- session_->AddExample(task_0_.name, example_0);
+ std::unique_ptr<LearningTaskController> ltc_0 =
+ session_->GetController(task_0_.name);
+ ltc_0->BeginObservation(id, example_0.features);
+ ltc_0->CompleteObservation(
+ id, ObservationCompletion(example_0.target_value, example_0.weight));
LabelledExample example_1({FeatureValue(321), FeatureValue(654)},
TargetValue(4321));
- session_->AddExample(task_1_.name, example_1);
+
+ std::unique_ptr<LearningTaskController> ltc_1 =
+ session_->GetController(task_1_.name);
+ ltc_1->BeginObservation(id, example_1.features);
+ ltc_1->CompleteObservation(
+ id, ObservationCompletion(example_1.target_value, example_1.weight));
+
+ scoped_task_environment_.RunUntilIdle();
EXPECT_EQ(task_controllers_[0]->example_, example_0);
EXPECT_EQ(task_controllers_[1]->example_, example_1);
}
+// A controller handed out by the session must stay safe to call after the
+// session has been destroyed.
+TEST_F(LearningSessionImplTest, ControllerLifetimeScopedToSession) {
+  session_->RegisterTask(task_0_);
+  auto controller = session_->GetController(task_0_.name);
+
+  // Tear down the session.  |controller| stops forwarding requests, but it
+  // must remain usable.
+  session_.reset();
+  scoped_task_environment_.RunUntilIdle();
+
+  // Should not crash.
+  const base::UnguessableToken id = base::UnguessableToken::Create();
+  controller->BeginObservation(id, FeatureVector());
+}
+
TEST_F(LearningSessionImplTest, FeatureProviderIsForwarded) {
// Verify that a FeatureProvider actually gets forwarded to the LTC.
bool flag = false;
- session_->RegisterTask(task_0_,
- base::SequenceBound<FakeFeatureProvider>(
- base::SequencedTaskRunnerHandle::Get(), &flag));
+ session_->RegisterTask(
+ task_0_, base::SequenceBound<FakeFeatureProvider>(task_runner_, &flag));
scoped_task_environment_.RunUntilIdle();
// Registering the task should create a FakeLearningTaskController, which will
// call AddFeatures on the fake FeatureProvider.
EXPECT_TRUE(flag);
}
+// Dropping a controller while an observation is still in flight must cancel
+// that observation on the underlying controller.
+TEST_F(LearningSessionImplTest, DestroyingControllerCancelsObservations) {
+  session_->RegisterTask(task_0_);
+  auto controller = session_->GetController(task_0_.name);
+  scoped_task_environment_.RunUntilIdle();
+
+  // Begin an observation, and check that it reached the fake controller
+  // without being cancelled.
+  const base::UnguessableToken observation_id =
+      base::UnguessableToken::Create();
+  controller->BeginObservation(observation_id, FeatureVector());
+  scoped_task_environment_.RunUntilIdle();
+  EXPECT_EQ(task_controllers_[0]->id_, observation_id);
+  EXPECT_NE(task_controllers_[0]->cancelled_id_, observation_id);
+
+  // Destroying the handle should cancel the outstanding observation.
+  controller.reset();
+  scoped_task_environment_.RunUntilIdle();
+  EXPECT_EQ(task_controllers_[0]->cancelled_id_, observation_id);
+}
+
} // namespace learning
} // namespace media
diff --git a/chromium/media/learning/impl/learning_task_controller_helper.cc b/chromium/media/learning/impl/learning_task_controller_helper.cc
index d4d57b43ebe..01c8972e67b 100644
--- a/chromium/media/learning/impl/learning_task_controller_helper.cc
+++ b/chromium/media/learning/impl/learning_task_controller_helper.cc
@@ -24,7 +24,7 @@ LearningTaskControllerHelper::LearningTaskControllerHelper(
LearningTaskControllerHelper::~LearningTaskControllerHelper() = default;
-void LearningTaskControllerHelper::BeginObservation(ObservationId id,
+void LearningTaskControllerHelper::BeginObservation(base::UnguessableToken id,
FeatureVector features) {
auto& pending_example = pending_examples_[id];
@@ -41,10 +41,11 @@ void LearningTaskControllerHelper::BeginObservation(ObservationId id,
}
void LearningTaskControllerHelper::CompleteObservation(
- ObservationId id,
+ base::UnguessableToken id,
const ObservationCompletion& completion) {
auto iter = pending_examples_.find(id);
- DCHECK(iter != pending_examples_.end());
+ if (iter == pending_examples_.end())
+ return;
iter->second.example.target_value = completion.target_value;
iter->second.example.weight = completion.weight;
@@ -52,10 +53,11 @@ void LearningTaskControllerHelper::CompleteObservation(
ProcessExampleIfFinished(std::move(iter));
}
-void LearningTaskControllerHelper::CancelObservation(ObservationId id) {
+void LearningTaskControllerHelper::CancelObservation(
+ base::UnguessableToken id) {
auto iter = pending_examples_.find(id);
- // If the example has already been completed, then we shouldn't be called.
- DCHECK(iter != pending_examples_.end());
+ if (iter == pending_examples_.end())
+ return;
// This would have to check for pending predictions, if we supported them, and
// defer destruction until the features arrive.
@@ -66,7 +68,7 @@ void LearningTaskControllerHelper::CancelObservation(ObservationId id) {
void LearningTaskControllerHelper::OnFeaturesReadyTrampoline(
scoped_refptr<base::SequencedTaskRunner> task_runner,
base::WeakPtr<LearningTaskControllerHelper> weak_this,
- ObservationId id,
+ base::UnguessableToken id,
FeatureVector features) {
// TODO(liberato): this would benefit from promises / deferred data.
auto cb = base::BindOnce(&LearningTaskControllerHelper::OnFeaturesReady,
@@ -78,7 +80,7 @@ void LearningTaskControllerHelper::OnFeaturesReadyTrampoline(
}
}
-void LearningTaskControllerHelper::OnFeaturesReady(ObservationId id,
+void LearningTaskControllerHelper::OnFeaturesReady(base::UnguessableToken id,
FeatureVector features) {
PendingExampleMap::iterator iter = pending_examples_.find(id);
// It's possible that OnLabelCallbackDestroyed has already run. That's okay
diff --git a/chromium/media/learning/impl/learning_task_controller_helper.h b/chromium/media/learning/impl/learning_task_controller_helper.h
index 9370536d8ce..7e9bb6f5a7f 100644
--- a/chromium/media/learning/impl/learning_task_controller_helper.h
+++ b/chromium/media/learning/impl/learning_task_controller_helper.h
@@ -37,9 +37,6 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerHelper
// Callback to add labelled examples as training data.
using AddExampleCB = base::RepeatingCallback<void(LabelledExample)>;
- // Convenience.
- using ObservationId = LearningTaskController::ObservationId;
-
// TODO(liberato): Consider making the FP not optional.
LearningTaskControllerHelper(const LearningTask& task,
AddExampleCB add_example_cb,
@@ -48,10 +45,10 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerHelper
virtual ~LearningTaskControllerHelper();
// See LearningTaskController::BeginObservation.
- void BeginObservation(ObservationId id, FeatureVector features);
- void CompleteObservation(ObservationId id,
+ void BeginObservation(base::UnguessableToken id, FeatureVector features);
+ void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion);
- void CancelObservation(ObservationId id);
+ void CancelObservation(base::UnguessableToken id);
private:
// Record of an example that has been started by RecordObservedFeatures, but
@@ -67,19 +64,20 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerHelper
};
// [observation id] = example
- using PendingExampleMap = std::map<ObservationId, PendingExample>;
+ using PendingExampleMap = std::map<base::UnguessableToken, PendingExample>;
// Called on any sequence when features are ready. Will call OnFeatureReady
// if called on |task_runner|, or will post to |task_runner|.
static void OnFeaturesReadyTrampoline(
scoped_refptr<base::SequencedTaskRunner> task_runner,
base::WeakPtr<LearningTaskControllerHelper> weak_this,
- ObservationId id,
+ base::UnguessableToken id,
FeatureVector features);
// Called when a new feature vector has been finished by |feature_provider_|,
// if needed, to actually add the example.
- void OnFeaturesReady(ObservationId example_id, FeatureVector features);
+ void OnFeaturesReady(base::UnguessableToken example_id,
+ FeatureVector features);
// If |example| is finished, then send it to the LearningSession and remove it
// from the map. Otherwise, do nothing.
diff --git a/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc b/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc
index 806c43c05cc..de756f0cf48 100644
--- a/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc
+++ b/chromium/media/learning/impl/learning_task_controller_helper_unittest.cc
@@ -43,6 +43,8 @@ class LearningTaskControllerHelperTest : public testing::Test {
example_.features.push_back(FeatureValue(3));
example_.target_value = TargetValue(123);
example_.weight = 100u;
+
+ id_ = base::UnguessableToken::Create();
}
void CreateClient(bool include_fp) {
@@ -87,7 +89,7 @@ class LearningTaskControllerHelperTest : public testing::Test {
LearningTask task_;
- LearningTaskController::ObservationId id_ = 1;
+ base::UnguessableToken id_;
LabelledExample example_;
};
diff --git a/chromium/media/learning/impl/learning_task_controller_impl.cc b/chromium/media/learning/impl/learning_task_controller_impl.cc
index 84092c939da..544218aaef6 100644
--- a/chromium/media/learning/impl/learning_task_controller_impl.cc
+++ b/chromium/media/learning/impl/learning_task_controller_impl.cc
@@ -6,8 +6,10 @@
#include <memory>
#include <utility>
+#include <vector>
#include "base/bind.h"
+#include "media/learning/impl/distribution_reporter.h"
#include "media/learning/impl/extra_trees_trainer.h"
#include "media/learning/impl/lookup_table_trainer.h"
@@ -25,7 +27,14 @@ LearningTaskControllerImpl::LearningTaskControllerImpl(
task,
base::BindRepeating(&LearningTaskControllerImpl::AddFinishedExample,
AsWeakPtr()),
- std::move(feature_provider))) {
+ std::move(feature_provider))),
+ expected_feature_count_(task_.feature_descriptions.size()) {
+ // Note that |helper_| uses the full set of features.
+
+ // TODO(liberato): Make this compositional. FeatureSubsetTaskController?
+ if (task_.feature_subset_size)
+ DoFeatureSubsetSelection();
+
switch (task_.model) {
case LearningTask::Model::kExtraTrees:
trainer_ = std::make_unique<ExtraTreesTrainer>();
@@ -39,22 +48,51 @@ LearningTaskControllerImpl::LearningTaskControllerImpl(
LearningTaskControllerImpl::~LearningTaskControllerImpl() = default;
void LearningTaskControllerImpl::BeginObservation(
- ObservationId id,
+ base::UnguessableToken id,
const FeatureVector& features) {
+ // TODO(liberato): Should we enforce that the right number of features are
+ // present here? Right now, we allow it to be shorter, so that features from
+ // a FeatureProvider may be omitted. Of course, they have to be at the end in
+ // that case. If we start enforcing it here, make sure that LearningHelper
+ // starts adding the placeholder features.
+ if (!trainer_)
+ return;
+
helper_->BeginObservation(id, features);
}
void LearningTaskControllerImpl::CompleteObservation(
- ObservationId id,
+ base::UnguessableToken id,
const ObservationCompletion& completion) {
+ if (!trainer_)
+ return;
helper_->CompleteObservation(id, completion);
}
-void LearningTaskControllerImpl::CancelObservation(ObservationId id) {
+void LearningTaskControllerImpl::CancelObservation(base::UnguessableToken id) {
+ if (!trainer_)
+ return;
helper_->CancelObservation(id);
}
void LearningTaskControllerImpl::AddFinishedExample(LabelledExample example) {
+ // Verify that we have a trainer and that we got the right number of features.
+ // We don't compare to |task_.feature_descriptions.size()| since that has been
+ // adjusted to the subset size already. We expect the original count.
+ if (!trainer_ || example.features.size() != expected_feature_count_)
+ return;
+
+ // Now that we have the whole set of features, select the subset we want.
+ FeatureVector new_features;
+ if (task_.feature_subset_size) {
+ for (auto& iter : feature_indices_)
+ new_features.push_back(example.features[iter]);
+ example.features = std::move(new_features);
+ } // else use them all.
+
+ // The features should now match the task.
+ DCHECK_EQ(example.features.size(), task_.feature_descriptions.size());
+
if (training_data_->size() >= task_.max_data_set_size) {
// Replace a random example. We don't necessarily want to replace the
// oldest, since we don't necessarily want to enforce an ad-hoc recency
@@ -68,12 +106,13 @@ void LearningTaskControllerImpl::AddFinishedExample(LabelledExample example) {
// Once we have a model, see if we'd get |example| correct.
if (model_ && reporter_) {
- TargetDistribution predicted =
- model_->PredictDistribution(example.features);
+ TargetHistogram predicted = model_->PredictDistribution(example.features);
- TargetDistribution observed;
- observed += example.target_value;
- reporter_->GetPredictionCallback(observed).Run(predicted);
+ DistributionReporter::PredictionInfo info;
+ info.observed = example.target_value;
+ info.total_training_weight = last_training_weight_;
+ info.total_training_examples = last_training_size_;
+ reporter_->GetPredictionCallback(info).Run(predicted);
}
// Can't train more than one model concurrently.
@@ -89,7 +128,8 @@ void LearningTaskControllerImpl::AddFinishedExample(LabelledExample example) {
num_untrained_examples_ = 0;
TrainedModelCB model_cb =
- base::BindOnce(&LearningTaskControllerImpl::OnModelTrained, AsWeakPtr());
+ base::BindOnce(&LearningTaskControllerImpl::OnModelTrained, AsWeakPtr(),
+ training_data_->total_weight(), training_data_->size());
training_is_in_progress_ = true;
// Note that this copies the training data, so it's okay if we add more
// examples to our copy before this returns.
@@ -99,10 +139,15 @@ void LearningTaskControllerImpl::AddFinishedExample(LabelledExample example) {
trainer_->Train(task_, *training_data_, std::move(model_cb));
}
-void LearningTaskControllerImpl::OnModelTrained(std::unique_ptr<Model> model) {
+void LearningTaskControllerImpl::OnModelTrained(double training_weight,
+ int training_size,
+ std::unique_ptr<Model> model) {
DCHECK(training_is_in_progress_);
training_is_in_progress_ = false;
model_ = std::move(model);
+ // Record these for metrics.
+ last_training_weight_ = training_weight;
+ last_training_size_ = training_size;
}
void LearningTaskControllerImpl::SetTrainerForTesting(
@@ -110,5 +155,39 @@ void LearningTaskControllerImpl::SetTrainerForTesting(
trainer_ = std::move(trainer);
}
+// Chooses a random subset of the task's features, stores the chosen indices in
+// |feature_indices_|, trims |task_.feature_descriptions| to the chosen subset,
+// and tells |reporter_| (if there is one) which subset was picked.
+void LearningTaskControllerImpl::DoFeatureSubsetSelection() {
+  // Start from the list of all feature indices.
+  const size_t feature_count = task_.feature_descriptions.size();
+  std::vector<size_t> features;
+  features.reserve(feature_count);
+  for (size_t i = 0; i < feature_count; i++)
+    features.push_back(i);
+
+  // Never select more features than the task describes.  Otherwise
+  // |features.size() - i| below would wrap around and |features[i]| would
+  // index past the end of the list.  Non-positive requests select nothing.
+  const auto requested = *task_.feature_subset_size;
+  size_t subset_size = 0;
+  if (requested > 0) {
+    subset_size = static_cast<size_t>(requested);
+    if (subset_size > feature_count)
+      subset_size = feature_count;
+  }
+
+  for (size_t i = 0; i < subset_size; i++) {
+    // Pick an element from |i| to the end of the list, inclusive.
+    // TODO(liberato): For tests, this will happen before any rng is provided
+    // by the test; we'll use an actual rng.
+    const size_t r = rng()->Generate(features.size() - i) + i;
+    // Swap them, so that the first |i + 1| elements are the chosen ones.
+    std::swap(features[i], features[r]);
+  }
+
+  // Construct the feature subset from the first few elements.  Also adjust the
+  // task's descriptions to match.  We do this in two steps so that the
+  // descriptions are added via iterating over |feature_indices_|, so that the
+  // enumeration order is the same as when we adjust the feature values of
+  // incoming examples.  In both cases, we iterate over |feature_indices_|,
+  // which might (will) re-order them with respect to |features|.
+  for (size_t i = 0; i < subset_size; i++)
+    feature_indices_.insert(features[i]);
+
+  std::vector<LearningTask::ValueDescription> adjusted_descriptions;
+  for (const auto& index : feature_indices_)
+    adjusted_descriptions.push_back(task_.feature_descriptions[index]);
+
+  task_.feature_descriptions = std::move(adjusted_descriptions);
+
+  if (reporter_)
+    reporter_->SetFeatureSubset(feature_indices_);
+}
+
} // namespace learning
} // namespace media
diff --git a/chromium/media/learning/impl/learning_task_controller_impl.h b/chromium/media/learning/impl/learning_task_controller_impl.h
index 2866da17f81..eae031be307 100644
--- a/chromium/media/learning/impl/learning_task_controller_impl.h
+++ b/chromium/media/learning/impl/learning_task_controller_impl.h
@@ -6,6 +6,7 @@
#define MEDIA_LEARNING_IMPL_LEARNING_TASK_CONTROLLER_IMPL_H_
#include <memory>
+#include <set>
#include "base/callback.h"
#include "base/component_export.h"
@@ -20,6 +21,7 @@
namespace media {
namespace learning {
+class DistributionReporter;
class LearningTaskControllerImplTest;
// Controller for a single learning task. Takes training examples, and forwards
@@ -44,21 +46,27 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl
~LearningTaskControllerImpl() override;
// LearningTaskController
- void BeginObservation(ObservationId id,
+ void BeginObservation(base::UnguessableToken id,
const FeatureVector& features) override;
- void CompleteObservation(ObservationId id,
+ void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) override;
- void CancelObservation(ObservationId id) override;
+ void CancelObservation(base::UnguessableToken id) override;
private:
// Add |example| to the training data, and process it.
void AddFinishedExample(LabelledExample example);
- // Called by |training_cb_| when the model is trained.
- void OnModelTrained(std::unique_ptr<Model> model);
+ // Called by |training_cb_| when the model is trained. |training_weight| and
+ // |training_size| are the training set's total weight and number of examples.
+ void OnModelTrained(double training_weight,
+ int training_size,
+ std::unique_ptr<Model> model);
void SetTrainerForTesting(std::unique_ptr<TrainingAlgorithm> trainer);
+ // Update |task_| to reflect a randomly chosen subset of features.
+ void DoFeatureSubsetSelection();
+
LearningTask task_;
// Current batch of examples.
@@ -74,6 +82,10 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl
// This helps us decide when to train a new model.
int num_untrained_examples_ = 0;
+ // Total weight and number of examples in the most recently trained model.
+ double last_training_weight_ = 0.;
+ size_t last_training_size_ = 0u;
+
// Training algorithm that we'll use.
std::unique_ptr<TrainingAlgorithm> trainer_;
@@ -83,6 +95,13 @@ class COMPONENT_EXPORT(LEARNING_IMPL) LearningTaskControllerImpl
// Helper that we use to handle deferred examples.
std::unique_ptr<LearningTaskControllerHelper> helper_;
+ // If the task specifies feature importance measurement, then this is the
+ // randomly chosen subset of features.
+ std::set<int> feature_indices_;
+
+ // Number of features that we expect in each observation.
+ size_t expected_feature_count_;
+
friend class LearningTaskControllerImplTest;
};
diff --git a/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc b/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc
index 7433919a62c..e1620c0858d 100644
--- a/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc
+++ b/chromium/media/learning/impl/learning_task_controller_impl_unittest.cc
@@ -22,11 +22,18 @@ class LearningTaskControllerImplTest : public testing::Test {
FakeDistributionReporter(const LearningTask& task)
: DistributionReporter(task) {}
+ // protected => public
+ const base::Optional<std::set<int>>& feature_indices() const {
+ return DistributionReporter::feature_indices();
+ }
+
protected:
- void OnPrediction(TargetDistribution observed,
- TargetDistribution predicted) override {
+ void OnPrediction(const PredictionInfo& info,
+ TargetHistogram predicted) override {
num_reported_++;
- if (observed == predicted)
+ TargetHistogram dist;
+ dist += info.observed;
+ if (dist == predicted)
num_correct_++;
}
@@ -41,9 +48,9 @@ class LearningTaskControllerImplTest : public testing::Test {
FakeModel(TargetValue target) : target_(target) {}
// Model
- TargetDistribution PredictDistribution(
+ TargetHistogram PredictDistribution(
const FeatureVector& features) override {
- TargetDistribution dist;
+ TargetHistogram dist;
dist += target_;
return dist;
}
@@ -64,14 +71,18 @@ class LearningTaskControllerImplTest : public testing::Test {
void Train(const LearningTask& task,
const TrainingData& training_data,
TrainedModelCB model_cb) override {
+ task_ = task;
(*num_models_)++;
training_data_ = training_data;
std::move(model_cb).Run(std::make_unique<FakeModel>(target_value_));
}
+ const LearningTask& task() const { return task_; }
+
const TrainingData& training_data() const { return training_data_; }
private:
+ LearningTask task_;
int* num_models_ = nullptr;
TargetValue target_value_;
@@ -113,7 +124,7 @@ class LearningTaskControllerImplTest : public testing::Test {
}
void AddExample(const LabelledExample& example) {
- LearningTaskController::ObservationId id = 1;
+ base::UnguessableToken id = base::UnguessableToken::Create();
controller_->BeginObservation(id, example.features);
controller_->CompleteObservation(
id, ObservationCompletion(example.target_value, example.weight));
@@ -123,7 +134,6 @@ class LearningTaskControllerImplTest : public testing::Test {
// Number of models that we trained.
int num_models_ = 0;
- FakeModel* last_model_ = nullptr;
// Two distinct targets.
const TargetValue predicted_target_;
@@ -182,6 +192,7 @@ TEST_F(LearningTaskControllerImplTest, AddingExamplesTrainsModelAndReports) {
TEST_F(LearningTaskControllerImplTest, FeatureProviderIsUsed) {
// If a FeatureProvider factory is provided, make sure that it's used to
// adjust new examples.
+ task_.feature_descriptions.push_back({"AddedByFeatureProvider"});
SequenceBoundFeatureProvider feature_provider =
base::SequenceBound<FakeFeatureProvider>(
base::SequencedTaskRunnerHandle::Get());
@@ -195,5 +206,50 @@ TEST_F(LearningTaskControllerImplTest, FeatureProviderIsUsed) {
EXPECT_EQ(trainer_raw_->training_data()[0].weight, example.weight);
}
+TEST_F(LearningTaskControllerImplTest, FeatureSubsetsWork) {
+ const char* feature_names[] = {
+ "feature0", "feature1", "feature2", "feature3", "feature4", "feature5",
+ "feature6", "feature7", "feature8", "feature9", "feature10", "feature11",
+ };
+ const int num_features = sizeof(feature_names) / sizeof(feature_names[0]);
+ for (int i = 0; i < num_features; i++)
+ task_.feature_descriptions.push_back({feature_names[i]});
+ const size_t subset_size = 4;
+ task_.feature_subset_size = subset_size;
+ CreateController();
+
+ // Verify that the reporter is given a subset of the features.
+ auto subset = *reporter_raw_->feature_indices();
+ EXPECT_EQ(subset.size(), subset_size);
+
+ // Train a model. Each feature will have a unique value.
+ LabelledExample example;
+ for (int i = 0; i < num_features; i++)
+ example.features.push_back(FeatureValue(i));
+ AddExample(example);
+
+ // Verify that all feature names in |subset| are present in the task.
+ FeatureVector expected_features;
+ expected_features.resize(subset_size);
+ EXPECT_EQ(trainer_raw_->task().feature_descriptions.size(), subset_size);
+ for (auto& iter : subset) {
+ bool found = false;
+ for (size_t i = 0; i < subset_size; i++) {
+ if (trainer_raw_->task().feature_descriptions[i].name ==
+ feature_names[iter]) {
+ // Also build a vector with the features in the expected order.
+ expected_features[i] = example.features[iter];
+ found = true;
+ break;
+ }
+ }
+ EXPECT_TRUE(found);
+ }
+
+ // Verify that the training data has the adjusted features.
+ EXPECT_EQ(trainer_raw_->training_data().size(), 1u);
+ EXPECT_EQ(trainer_raw_->training_data()[0].features, expected_features);
+}
+
} // namespace learning
} // namespace media
diff --git a/chromium/media/learning/impl/lookup_table_trainer.cc b/chromium/media/learning/impl/lookup_table_trainer.cc
index 4c698c71022..57ebdbbc0f6 100644
--- a/chromium/media/learning/impl/lookup_table_trainer.cc
+++ b/chromium/media/learning/impl/lookup_table_trainer.cc
@@ -19,17 +19,16 @@ class LookupTable : public Model {
}
// Model
- TargetDistribution PredictDistribution(
- const FeatureVector& instance) override {
+ TargetHistogram PredictDistribution(const FeatureVector& instance) override {
auto iter = buckets_.find(instance);
if (iter == buckets_.end())
- return TargetDistribution();
+ return TargetHistogram();
return iter->second;
}
private:
- std::map<FeatureVector, TargetDistribution> buckets_;
+ std::map<FeatureVector, TargetHistogram> buckets_;
};
LookupTableTrainer::LookupTableTrainer() = default;
diff --git a/chromium/media/learning/impl/lookup_table_trainer_unittest.cc b/chromium/media/learning/impl/lookup_table_trainer_unittest.cc
index 323d69d471e..47618746617 100644
--- a/chromium/media/learning/impl/lookup_table_trainer_unittest.cc
+++ b/chromium/media/learning/impl/lookup_table_trainer_unittest.cc
@@ -37,7 +37,7 @@ TEST_F(LookupTableTrainerTest, EmptyTrainingDataWorks) {
TrainingData empty;
std::unique_ptr<Model> model = Train(task_, empty);
EXPECT_NE(model.get(), nullptr);
- EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetDistribution());
+ EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetHistogram());
}
TEST_F(LookupTableTrainerTest, UniformTrainingDataWorks) {
@@ -51,8 +51,7 @@ TEST_F(LookupTableTrainerTest, UniformTrainingDataWorks) {
// The tree should produce a distribution for one value (our target), which
// has |n_examples| counts.
- TargetDistribution distribution =
- model->PredictDistribution(example.features);
+ TargetHistogram distribution = model->PredictDistribution(example.features);
EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution[example.target_value], n_examples);
}
@@ -66,8 +65,7 @@ TEST_F(LookupTableTrainerTest, SimpleSeparableTrainingData) {
std::unique_ptr<Model> model = Train(task_, training_data);
// Each value should have a distribution with one target value with one count.
- TargetDistribution distribution =
- model->PredictDistribution(example_1.features);
+ TargetHistogram distribution = model->PredictDistribution(example_1.features);
EXPECT_NE(model.get(), nullptr);
EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution[example_1.target_value], 1u);
@@ -104,8 +102,7 @@ TEST_F(LookupTableTrainerTest, ComplexSeparableTrainingData) {
// Each example should have a distribution that selects the right value.
for (const auto& example : training_data) {
- TargetDistribution distribution =
- model->PredictDistribution(example.features);
+ TargetHistogram distribution = model->PredictDistribution(example.features);
TargetValue singular_max;
EXPECT_TRUE(distribution.FindSingularMax(&singular_max));
EXPECT_EQ(singular_max, example.target_value);
@@ -122,8 +119,7 @@ TEST_F(LookupTableTrainerTest, UnseparableTrainingData) {
EXPECT_NE(model.get(), nullptr);
// Each value should have a distribution with two targets with one count each.
- TargetDistribution distribution =
- model->PredictDistribution(example_1.features);
+ TargetHistogram distribution = model->PredictDistribution(example_1.features);
EXPECT_EQ(distribution.size(), 2u);
EXPECT_EQ(distribution[example_1.target_value], 1u);
EXPECT_EQ(distribution[example_2.target_value], 1u);
@@ -143,7 +139,7 @@ TEST_F(LookupTableTrainerTest, UnknownFeatureValueHandling) {
training_data.push_back(example_2);
std::unique_ptr<Model> model = Train(task_, training_data);
- TargetDistribution distribution =
+ TargetHistogram distribution =
model->PredictDistribution(FeatureVector({FeatureValue(789)}));
// OOV data should return an empty distribution (nominal).
EXPECT_EQ(distribution.size(), 0u);
@@ -160,7 +156,7 @@ TEST_F(LookupTableTrainerTest, RegressionWithWeightedExamplesWorks) {
training_data.push_back(example_2);
std::unique_ptr<Model> model = Train(task_, training_data);
- TargetDistribution distribution =
+ TargetHistogram distribution =
model->PredictDistribution(FeatureVector({FeatureValue(123)}));
double avg = distribution.Average();
const double expected =
diff --git a/chromium/media/learning/impl/model.h b/chromium/media/learning/impl/model.h
index 0950b6c9713..42366875b65 100644
--- a/chromium/media/learning/impl/model.h
+++ b/chromium/media/learning/impl/model.h
@@ -8,7 +8,7 @@
#include "base/component_export.h"
#include "media/learning/common/labelled_example.h"
#include "media/learning/impl/model.h"
-#include "media/learning/impl/target_distribution.h"
+#include "media/learning/impl/target_histogram.h"
namespace media {
namespace learning {
@@ -19,11 +19,11 @@ namespace learning {
class COMPONENT_EXPORT(LEARNING_IMPL) Model {
public:
// Callback for asynchronous predictions.
- using PredictionCB = base::OnceCallback<void(TargetDistribution predicted)>;
+ using PredictionCB = base::OnceCallback<void(TargetHistogram predicted)>;
virtual ~Model() = default;
- virtual TargetDistribution PredictDistribution(
+ virtual TargetHistogram PredictDistribution(
const FeatureVector& instance) = 0;
// TODO(liberato): Consider adding an async prediction helper.
diff --git a/chromium/media/learning/impl/one_hot.cc b/chromium/media/learning/impl/one_hot.cc
index b8dab81e142..c3e82855eda 100644
--- a/chromium/media/learning/impl/one_hot.cc
+++ b/chromium/media/learning/impl/one_hot.cc
@@ -111,7 +111,7 @@ ConvertingModel::ConvertingModel(std::unique_ptr<OneHotConverter> converter,
: converter_(std::move(converter)), model_(std::move(model)) {}
ConvertingModel::~ConvertingModel() = default;
-TargetDistribution ConvertingModel::PredictDistribution(
+TargetHistogram ConvertingModel::PredictDistribution(
const FeatureVector& instance) {
FeatureVector converted_instance = converter_->Convert(instance);
return model_->PredictDistribution(converted_instance);
diff --git a/chromium/media/learning/impl/one_hot.h b/chromium/media/learning/impl/one_hot.h
index 0a3f479b721..b48b66ec924 100644
--- a/chromium/media/learning/impl/one_hot.h
+++ b/chromium/media/learning/impl/one_hot.h
@@ -67,8 +67,7 @@ class COMPONENT_EXPORT(LEARNING_IMPL) ConvertingModel : public Model {
~ConvertingModel() override;
// Model
- TargetDistribution PredictDistribution(
- const FeatureVector& instance) override;
+ TargetHistogram PredictDistribution(const FeatureVector& instance) override;
private:
std::unique_ptr<OneHotConverter> converter_;
diff --git a/chromium/media/learning/impl/random_tree_trainer.cc b/chromium/media/learning/impl/random_tree_trainer.cc
index 88e1b97dd19..c45b757716e 100644
--- a/chromium/media/learning/impl/random_tree_trainer.cc
+++ b/chromium/media/learning/impl/random_tree_trainer.cc
@@ -40,8 +40,7 @@ struct InteriorNode : public Model {
split_point_(split_point) {}
// Model
- TargetDistribution PredictDistribution(
- const FeatureVector& features) override {
+ TargetHistogram PredictDistribution(const FeatureVector& features) override {
// Figure out what feature value we should use for the split.
FeatureValue f;
switch (ordering_) {
@@ -59,16 +58,16 @@ struct InteriorNode : public Model {
// If we've never seen this feature value, then return nothing.
if (iter == children_.end())
- return TargetDistribution();
+ return TargetHistogram();
return iter->second->PredictDistribution(features);
}
- TargetDistribution PredictDistributionWithMissingValues(
+ TargetHistogram PredictDistributionWithMissingValues(
const FeatureVector& features) {
- TargetDistribution total;
+ TargetHistogram total;
for (auto& child_pair : children_) {
- TargetDistribution predicted =
+ TargetHistogram predicted =
child_pair.second->PredictDistribution(features);
// TODO(liberato): Normalize? Weight?
total += predicted;
@@ -102,19 +101,26 @@ struct LeafNode : public Model {
for (size_t idx : training_idx)
distribution_ += training_data[idx];
- // Note that we don't treat numeric targets any differently. We want to
- // weight the leaf by the number of examples, so replacing it with an
- // average would just introduce rounding errors. One might as well take the
- // average of the final distribution.
+ // Each leaf gets one vote.
+ // See https://en.wikipedia.org/wiki/Bootstrap_aggregating . TL;DR: the
+ // individual trees should average (regression) or vote (classification).
+ //
+ // TODO(liberato): It's unclear that a leaf should get to vote with an
+ // entire distribution; we might want to take the max for kUnordered here.
+ // If so, then we might also want to Average() for kNumeric targets, though
+ // in that case, the results would be the same anyway. That's not, of
+ // course, guaranteed for all methods of converting |distribution_| into a
+ // numeric prediction. In general, we should provide a single estimate.
+ distribution_.Normalize();
}
// TreeNode
- TargetDistribution PredictDistribution(const FeatureVector&) override {
+ TargetHistogram PredictDistribution(const FeatureVector&) override {
return distribution_;
}
private:
- TargetDistribution distribution_;
+ TargetHistogram distribution_;
};
RandomTreeTrainer::RandomTreeTrainer(RandomNumberGenerator* rng)
@@ -297,7 +303,8 @@ RandomTreeTrainer::Split RandomTreeTrainer::ConstructSplit(
// Find the split's feature values and construct the training set for each.
// I think we want to iterate on the underlying vector, and look up the int in
- // the training data directly.
+ // the training data directly. |total_weight| will hold the total weight of
+ // all examples that come into this node.
double total_weight = 0.;
for (size_t idx : training_idx) {
const LabelledExample& example = training_data[idx];
@@ -324,7 +331,7 @@ RandomTreeTrainer::Split RandomTreeTrainer::ConstructSplit(
Split::BranchInfo& branch_info = iter->second;
branch_info.training_idx.push_back(idx);
- branch_info.target_distribution += example;
+ branch_info.target_histogram += example;
}
// Figure out how good / bad this split is.
@@ -340,18 +347,22 @@ RandomTreeTrainer::Split RandomTreeTrainer::ConstructSplit(
return split;
}
-void RandomTreeTrainer::ComputeSplitScore_Nominal(Split* split,
- double total_weight) {
+void RandomTreeTrainer::ComputeSplitScore_Nominal(
+ Split* split,
+ double total_incoming_weight) {
// Compute the nats given that we're at this node.
split->nats_remaining = 0;
for (auto& info_iter : split->branch_infos) {
Split::BranchInfo& branch_info = info_iter.second;
- const double total_counts = branch_info.target_distribution.total_counts();
+ // |weight_along_branch| is the total weight of examples that would follow
+ // this branch in the tree; per-target probabilities are relative to it.
+ const double weight_along_branch =
+ branch_info.target_histogram.total_counts();
// |p_branch| is the probability of following this branch.
- const double p_branch = total_counts / total_weight;
- for (auto& iter : branch_info.target_distribution) {
- double p = iter.second / total_counts;
+ const double p_branch = weight_along_branch / total_incoming_weight;
+ for (auto& iter : branch_info.target_histogram) {
+ double p = iter.second / weight_along_branch;
// p*log(p) is the expected nats if the answer is |iter|. We multiply
// that by the probability of being in this bucket at all.
split->nats_remaining -= (p * log(p)) * p_branch;
@@ -359,25 +370,29 @@ void RandomTreeTrainer::ComputeSplitScore_Nominal(Split* split,
}
}
-void RandomTreeTrainer::ComputeSplitScore_Numeric(Split* split,
- double total_weight) {
+void RandomTreeTrainer::ComputeSplitScore_Numeric(
+ Split* split,
+ double total_incoming_weight) {
// Compute the nats given that we're at this node.
split->nats_remaining = 0;
for (auto& info_iter : split->branch_infos) {
Split::BranchInfo& branch_info = info_iter.second;
- const double total_counts = branch_info.target_distribution.total_counts();
+ // |weight_along_branch| is the total weight of examples that would follow
+ // this branch in the tree.
+ const double weight_along_branch =
+ branch_info.target_histogram.total_counts();
// |p_branch| is the probability of following this branch.
- const double p_branch = total_counts / total_weight;
+ const double p_branch = weight_along_branch / total_incoming_weight;
// Compute the average at this node. Note that we have no idea if the leaf
// node would actually use an average, but really it should match. It would
- // be really nice if we could compute the value (or TargetDistribution) as
+ // be really nice if we could compute the value (or TargetHistogram) as
// part of computing the split, and have somebody just hand that target
// distribution to the leaf if it ends up as one.
- double average = branch_info.target_distribution.Average();
+ double average = branch_info.target_histogram.Average();
- for (auto& iter : branch_info.target_distribution) {
+ for (auto& iter : branch_info.target_histogram) {
// Compute the squared error for all |iter.second| counts that each have a
// value of |iter.first|, when this leaf approximates them as |average|.
double sq_err = (iter.first.value() - average) *
diff --git a/chromium/media/learning/impl/random_tree_trainer.h b/chromium/media/learning/impl/random_tree_trainer.h
index 3383e55fa11..816469832ea 100644
--- a/chromium/media/learning/impl/random_tree_trainer.h
+++ b/chromium/media/learning/impl/random_tree_trainer.h
@@ -20,7 +20,7 @@
namespace media {
namespace learning {
-// Trains RandomTree decision tree classifier (doesn't handle regression).
+// Trains RandomTree decision tree classifier / regressor.
//
// Decision trees, including RandomTree, classify instances as follows. Each
// non-leaf node is marked with a feature number |i|. The value of the |i|-th
@@ -71,9 +71,9 @@ namespace learning {
// See https://en.wikipedia.org/wiki/Random_forest for information. Note that
// this is just a single tree, not the whole forest.
//
-// TODO(liberato): Right now, it not-so-randomly selects from the entire set.
-// TODO(liberato): consider PRF or other simplified approximations.
-// TODO(liberato): separate Model and TrainingAlgorithm. This is the latter.
+// Note that this variant chooses split points randomly, as described by the
+// ExtraTrees algorithm. This is slightly different than RandomForest, which
+// chooses split points to improve the split's score.
class COMPONENT_EXPORT(LEARNING_IMPL) RandomTreeTrainer
: public TrainingAlgorithm,
public HasRandomNumberGenerator {
@@ -135,7 +135,7 @@ class COMPONENT_EXPORT(LEARNING_IMPL) RandomTreeTrainer
// branch of the split.
// This is a flat_map since we're likely to have a very small (e.g.,
// "true / "false") number of targets.
- TargetDistribution target_distribution;
+ TargetHistogram target_histogram;
};
// [feature value at this split] = info about which examples take this
@@ -158,12 +158,13 @@ class COMPONENT_EXPORT(LEARNING_IMPL) RandomTreeTrainer
const std::vector<size_t>& training_idx,
int index);
- // Fill in |nats_remaining| for |split| for a nominal target. |total_weight|
- // is the total weight of all instances coming into this split.
- void ComputeSplitScore_Nominal(Split* split, double total_weight);
+ // Fill in |nats_remaining| for |split| for a nominal target.
+ // |total_incoming_weight| is the total weight of all instances coming into
+ // the node that we're splitting.
+ void ComputeSplitScore_Nominal(Split* split, double total_incoming_weight);
// Fill in |nats_remaining| for |split| for a numeric target.
- void ComputeSplitScore_Numeric(Split* split, double total_weight);
+ void ComputeSplitScore_Numeric(Split* split, double total_incoming_weight);
// Compute the split point for |training_data| for a nominal feature.
FeatureValue FindSplitPoint_Nominal(size_t index,
diff --git a/chromium/media/learning/impl/random_tree_trainer_unittest.cc b/chromium/media/learning/impl/random_tree_trainer_unittest.cc
index 6af9e94b49c..f9face03115 100644
--- a/chromium/media/learning/impl/random_tree_trainer_unittest.cc
+++ b/chromium/media/learning/impl/random_tree_trainer_unittest.cc
@@ -55,7 +55,7 @@ TEST_P(RandomTreeTest, EmptyTrainingDataWorks) {
TrainingData empty;
std::unique_ptr<Model> model = Train(task_, empty);
EXPECT_NE(model.get(), nullptr);
- EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetDistribution());
+ EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetHistogram());
}
TEST_P(RandomTreeTest, UniformTrainingDataWorks) {
@@ -69,11 +69,10 @@ TEST_P(RandomTreeTest, UniformTrainingDataWorks) {
std::unique_ptr<Model> model = Train(task_, training_data);
// The tree should produce a distribution for one value (our target), which
- // has |n_examples| counts.
- TargetDistribution distribution =
- model->PredictDistribution(example.features);
+ // has one count.
+ TargetHistogram distribution = model->PredictDistribution(example.features);
EXPECT_EQ(distribution.size(), 1u);
- EXPECT_EQ(distribution[example.target_value], n_examples);
+ EXPECT_EQ(distribution[example.target_value], 1.0);
}
TEST_P(RandomTreeTest, SimpleSeparableTrainingData) {
@@ -86,8 +85,7 @@ TEST_P(RandomTreeTest, SimpleSeparableTrainingData) {
std::unique_ptr<Model> model = Train(task_, training_data);
// Each value should have a distribution with one target value with one count.
- TargetDistribution distribution =
- model->PredictDistribution(example_1.features);
+ TargetHistogram distribution = model->PredictDistribution(example_1.features);
EXPECT_NE(model.get(), nullptr);
EXPECT_EQ(distribution.size(), 1u);
EXPECT_EQ(distribution[example_1.target_value], 1u);
@@ -129,8 +127,7 @@ TEST_P(RandomTreeTest, ComplexSeparableTrainingData) {
// Each example should have a distribution that selects the right value.
for (const LabelledExample& example : training_data) {
- TargetDistribution distribution =
- model->PredictDistribution(example.features);
+ TargetHistogram distribution = model->PredictDistribution(example.features);
TargetValue singular_max;
EXPECT_TRUE(distribution.FindSingularMax(&singular_max));
EXPECT_EQ(singular_max, example.target_value);
@@ -147,17 +144,16 @@ TEST_P(RandomTreeTest, UnseparableTrainingData) {
std::unique_ptr<Model> model = Train(task_, training_data);
EXPECT_NE(model.get(), nullptr);
- // Each value should have a distribution with two targets with one count each.
- TargetDistribution distribution =
- model->PredictDistribution(example_1.features);
+ // Each value should have a distribution with two targets with equal counts.
+ TargetHistogram distribution = model->PredictDistribution(example_1.features);
EXPECT_EQ(distribution.size(), 2u);
- EXPECT_EQ(distribution[example_1.target_value], 1u);
- EXPECT_EQ(distribution[example_2.target_value], 1u);
+ EXPECT_EQ(distribution[example_1.target_value], 0.5);
+ EXPECT_EQ(distribution[example_2.target_value], 0.5);
distribution = model->PredictDistribution(example_2.features);
EXPECT_EQ(distribution.size(), 2u);
- EXPECT_EQ(distribution[example_1.target_value], 1u);
- EXPECT_EQ(distribution[example_2.target_value], 1u);
+ EXPECT_EQ(distribution[example_1.target_value], 0.5);
+ EXPECT_EQ(distribution[example_2.target_value], 0.5);
}
TEST_P(RandomTreeTest, UnknownFeatureValueHandling) {
@@ -202,7 +198,7 @@ TEST_P(RandomTreeTest, NumericFeaturesSplitMultipleTimes) {
std::unique_ptr<Model> model = Train(task_, training_data);
for (size_t i = 0; i < 4; i++) {
// Get a prediction for the |i|-th feature value.
- TargetDistribution distribution = model->PredictDistribution(
+ TargetHistogram distribution = model->PredictDistribution(
FeatureVector({FeatureValue(i * feature_mult)}));
// The distribution should have one count that should be correct. If
// the feature isn't split four times, then some feature value will have too
diff --git a/chromium/media/learning/impl/target_distribution.h b/chromium/media/learning/impl/target_distribution.h
deleted file mode 100644
index 3f4a9cd7274..00000000000
--- a/chromium/media/learning/impl/target_distribution.h
+++ /dev/null
@@ -1,94 +0,0 @@
-// Copyright 2018 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef MEDIA_LEARNING_IMPL_TARGET_DISTRIBUTION_H_
-#define MEDIA_LEARNING_IMPL_TARGET_DISTRIBUTION_H_
-
-#include <ostream>
-#include <string>
-
-#include "base/component_export.h"
-#include "base/containers/flat_map.h"
-#include "base/macros.h"
-#include "media/learning/common/labelled_example.h"
-#include "media/learning/common/value.h"
-
-namespace media {
-namespace learning {
-
-// TargetDistribution of target values.
-class COMPONENT_EXPORT(LEARNING_IMPL) TargetDistribution {
- private:
- // We use a flat_map since this will often have only one or two TargetValues,
- // such as "true" or "false".
- using DistributionMap = base::flat_map<TargetValue, size_t>;
-
- public:
- TargetDistribution();
- TargetDistribution(const TargetDistribution& rhs);
- TargetDistribution(TargetDistribution&& rhs);
- ~TargetDistribution();
-
- TargetDistribution& operator=(const TargetDistribution& rhs);
- TargetDistribution& operator=(TargetDistribution&& rhs);
-
- bool operator==(const TargetDistribution& rhs) const;
-
- // Add |rhs| to our counts.
- TargetDistribution& operator+=(const TargetDistribution& rhs);
-
- // Increment |rhs| by one.
- TargetDistribution& operator+=(const TargetValue& rhs);
-
- // Increment the distribution by |example|'s target value and weight.
- TargetDistribution& operator+=(const LabelledExample& example);
-
- // Return the number of counts for |value|.
- size_t operator[](const TargetValue& value) const;
- size_t& operator[](const TargetValue& value);
-
- // Return the total counts in the map.
- size_t total_counts() const {
- size_t total = 0.;
- for (auto& entry : counts_)
- total += entry.second;
- return total;
- }
-
- DistributionMap::const_iterator begin() const { return counts_.begin(); }
-
- DistributionMap::const_iterator end() const { return counts_.end(); }
-
- // Return the number of buckets in the distribution.
- // TODO(liberato): Do we want this?
- size_t size() const { return counts_.size(); }
-
- // Find the singular value with the highest counts, and copy it into
- // |value_out| and (optionally) |counts_out|. Returns true if there is a
- // singular maximum, else returns false with the out params undefined.
- bool FindSingularMax(TargetValue* value_out,
- size_t* counts_out = nullptr) const;
-
- // Return the average value of the entries in this distribution. Of course,
- // this only makes sense if the TargetValues can be interpreted as numeric.
- double Average() const;
-
- std::string ToString() const;
-
- private:
- const DistributionMap& counts() const { return counts_; }
-
- // [value] == counts
- DistributionMap counts_;
-
- // Allow copy and assign.
-};
-
-COMPONENT_EXPORT(LEARNING_IMPL)
-std::ostream& operator<<(std::ostream& out, const TargetDistribution& dist);
-
-} // namespace learning
-} // namespace media
-
-#endif // MEDIA_LEARNING_IMPL_TARGET_DISTRIBUTION_H_
diff --git a/chromium/media/learning/impl/target_distribution_unittest.cc b/chromium/media/learning/impl/target_distribution_unittest.cc
deleted file mode 100644
index 1c7564aa5e3..00000000000
--- a/chromium/media/learning/impl/target_distribution_unittest.cc
+++ /dev/null
@@ -1,164 +0,0 @@
-// Copyright 2018 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "media/learning/impl/target_distribution.h"
-
-#include "testing/gtest/include/gtest/gtest.h"
-
-namespace media {
-namespace learning {
-
-class TargetDistributionTest : public testing::Test {
- public:
- TargetDistributionTest() : value_1(123), value_2(456), value_3(789) {}
-
- TargetDistribution distribution_;
-
- TargetValue value_1;
- const size_t counts_1 = 100;
-
- TargetValue value_2;
- const size_t counts_2 = 10;
-
- TargetValue value_3;
-};
-
-TEST_F(TargetDistributionTest, EmptyTargetDistributionHasZeroCounts) {
- EXPECT_EQ(distribution_.total_counts(), 0u);
-}
-
-TEST_F(TargetDistributionTest, AddingCountsWorks) {
- distribution_[value_1] = counts_1;
- EXPECT_EQ(distribution_.total_counts(), counts_1);
- EXPECT_EQ(distribution_[value_1], counts_1);
- distribution_[value_1] += counts_1;
- EXPECT_EQ(distribution_.total_counts(), counts_1 * 2u);
- EXPECT_EQ(distribution_[value_1], counts_1 * 2u);
-}
-
-TEST_F(TargetDistributionTest, MultipleValuesAreSeparate) {
- distribution_[value_1] = counts_1;
- distribution_[value_2] = counts_2;
- EXPECT_EQ(distribution_.total_counts(), counts_1 + counts_2);
- EXPECT_EQ(distribution_[value_1], counts_1);
- EXPECT_EQ(distribution_[value_2], counts_2);
-}
-
-TEST_F(TargetDistributionTest, AddingTargetValues) {
- distribution_ += value_1;
- EXPECT_EQ(distribution_.total_counts(), 1u);
- EXPECT_EQ(distribution_[value_1], 1u);
- EXPECT_EQ(distribution_[value_2], 0u);
-
- distribution_ += value_1;
- EXPECT_EQ(distribution_.total_counts(), 2u);
- EXPECT_EQ(distribution_[value_1], 2u);
- EXPECT_EQ(distribution_[value_2], 0u);
-
- distribution_ += value_2;
- EXPECT_EQ(distribution_.total_counts(), 3u);
- EXPECT_EQ(distribution_[value_1], 2u);
- EXPECT_EQ(distribution_[value_2], 1u);
-}
-
-TEST_F(TargetDistributionTest, AddingTargetDistributions) {
- distribution_[value_1] = counts_1;
-
- TargetDistribution rhs;
- rhs[value_2] = counts_2;
-
- distribution_ += rhs;
-
- EXPECT_EQ(distribution_.total_counts(), counts_1 + counts_2);
- EXPECT_EQ(distribution_[value_1], counts_1);
- EXPECT_EQ(distribution_[value_2], counts_2);
-}
-
-TEST_F(TargetDistributionTest, FindSingularMaxFindsTheSingularMax) {
- distribution_[value_1] = counts_1;
- distribution_[value_2] = counts_2;
- ASSERT_TRUE(counts_1 > counts_2);
-
- TargetValue max_value(0);
- size_t max_counts = 0;
- EXPECT_TRUE(distribution_.FindSingularMax(&max_value, &max_counts));
- EXPECT_EQ(max_value, value_1);
- EXPECT_EQ(max_counts, counts_1);
-}
-
-TEST_F(TargetDistributionTest,
- FindSingularMaxFindsTheSingularMaxAlternateOrder) {
- // Switch the order, to handle sorting in different directions.
- distribution_[value_1] = counts_2;
- distribution_[value_2] = counts_1;
- ASSERT_TRUE(counts_1 > counts_2);
-
- TargetValue max_value(0);
- size_t max_counts = 0;
- EXPECT_TRUE(distribution_.FindSingularMax(&max_value, &max_counts));
- EXPECT_EQ(max_value, value_2);
- EXPECT_EQ(max_counts, counts_1);
-}
-
-TEST_F(TargetDistributionTest, FindSingularMaxReturnsFalsForNonSingularMax) {
- distribution_[value_1] = counts_1;
- distribution_[value_2] = counts_1;
-
- TargetValue max_value(0);
- size_t max_counts = 0;
- EXPECT_FALSE(distribution_.FindSingularMax(&max_value, &max_counts));
-}
-
-TEST_F(TargetDistributionTest, FindSingularMaxIgnoresNonSingularNonMax) {
- distribution_[value_1] = counts_1;
- // |value_2| and |value_3| are tied, but not the max.
- distribution_[value_2] = counts_2;
- distribution_[value_3] = counts_2;
- ASSERT_TRUE(counts_1 > counts_2);
-
- TargetValue max_value(0);
- size_t max_counts = 0;
- EXPECT_TRUE(distribution_.FindSingularMax(&max_value, &max_counts));
- EXPECT_EQ(max_value, value_1);
- EXPECT_EQ(max_counts, counts_1);
-}
-
-TEST_F(TargetDistributionTest, FindSingularMaxDoesntRequireCounts) {
- distribution_[value_1] = counts_1;
-
- TargetValue max_value(0);
- EXPECT_TRUE(distribution_.FindSingularMax(&max_value));
- EXPECT_EQ(max_value, value_1);
-}
-
-TEST_F(TargetDistributionTest, EqualDistributionsCompareAsEqual) {
- distribution_[value_1] = counts_1;
- TargetDistribution distribution_2;
- distribution_2[value_1] = counts_1;
-
- EXPECT_TRUE(distribution_ == distribution_2);
-}
-
-TEST_F(TargetDistributionTest, UnequalDistributionsCompareAsNotEqual) {
- distribution_[value_1] = counts_1;
- TargetDistribution distribution_2;
- distribution_2[value_2] = counts_2;
-
- EXPECT_FALSE(distribution_ == distribution_2);
-}
-
-TEST_F(TargetDistributionTest, WeightedLabelledExamplesCountCorrectly) {
- LabelledExample example = {{}, value_1};
- example.weight = counts_1;
- distribution_ += example;
-
- TargetDistribution distribution_2;
- for (size_t i = 0; i < counts_1; i++)
- distribution_2 += value_1;
-
- EXPECT_EQ(distribution_, distribution_2);
-}
-
-} // namespace learning
-} // namespace media
diff --git a/chromium/media/learning/impl/target_distribution.cc b/chromium/media/learning/impl/target_histogram.cc
index 2fe271733bc..ad1a1f2112a 100644
--- a/chromium/media/learning/impl/target_distribution.cc
+++ b/chromium/media/learning/impl/target_histogram.cc
@@ -2,51 +2,48 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#include "media/learning/impl/target_distribution.h"
+#include "media/learning/impl/target_histogram.h"
#include <sstream>
namespace media {
namespace learning {
-TargetDistribution::TargetDistribution() = default;
+TargetHistogram::TargetHistogram() = default;
-TargetDistribution::TargetDistribution(const TargetDistribution& rhs) = default;
+TargetHistogram::TargetHistogram(const TargetHistogram& rhs) = default;
-TargetDistribution::TargetDistribution(TargetDistribution&& rhs) = default;
+TargetHistogram::TargetHistogram(TargetHistogram&& rhs) = default;
-TargetDistribution::~TargetDistribution() = default;
+TargetHistogram::~TargetHistogram() = default;
-TargetDistribution& TargetDistribution::operator=(
- const TargetDistribution& rhs) = default;
-
-TargetDistribution& TargetDistribution::operator=(TargetDistribution&& rhs) =
+TargetHistogram& TargetHistogram::operator=(const TargetHistogram& rhs) =
default;
-bool TargetDistribution::operator==(const TargetDistribution& rhs) const {
+TargetHistogram& TargetHistogram::operator=(TargetHistogram&& rhs) = default;
+
+bool TargetHistogram::operator==(const TargetHistogram& rhs) const {
return rhs.total_counts() == total_counts() && rhs.counts_ == counts_;
}
-TargetDistribution& TargetDistribution::operator+=(
- const TargetDistribution& rhs) {
+TargetHistogram& TargetHistogram::operator+=(const TargetHistogram& rhs) {
for (auto& rhs_pair : rhs.counts())
counts_[rhs_pair.first] += rhs_pair.second;
return *this;
}
-TargetDistribution& TargetDistribution::operator+=(const TargetValue& rhs) {
+TargetHistogram& TargetHistogram::operator+=(const TargetValue& rhs) {
counts_[rhs]++;
return *this;
}
-TargetDistribution& TargetDistribution::operator+=(
- const LabelledExample& example) {
+TargetHistogram& TargetHistogram::operator+=(const LabelledExample& example) {
counts_[example.target_value] += example.weight;
return *this;
}
-size_t TargetDistribution::operator[](const TargetValue& value) const {
+double TargetHistogram::operator[](const TargetValue& value) const {
auto iter = counts_.find(value);
if (iter == counts_.end())
return 0;
@@ -54,16 +51,16 @@ size_t TargetDistribution::operator[](const TargetValue& value) const {
return iter->second;
}
-size_t& TargetDistribution::operator[](const TargetValue& value) {
+double& TargetHistogram::operator[](const TargetValue& value) {
return counts_[value];
}
-bool TargetDistribution::FindSingularMax(TargetValue* value_out,
- size_t* counts_out) const {
+bool TargetHistogram::FindSingularMax(TargetValue* value_out,
+ double* counts_out) const {
if (!counts_.size())
return false;
- size_t unused_counts;
+ double unused_counts;
if (!counts_out)
counts_out = &unused_counts;
@@ -85,9 +82,9 @@ bool TargetDistribution::FindSingularMax(TargetValue* value_out,
return singular_max;
}
-double TargetDistribution::Average() const {
+double TargetHistogram::Average() const {
double total_value = 0.;
- size_t total_counts = 0;
+ double total_counts = 0;
for (auto& iter : counts_) {
total_value += iter.first.value() * iter.second;
total_counts += iter.second;
@@ -99,7 +96,13 @@ double TargetDistribution::Average() const {
return total_value / total_counts;
}
-std::string TargetDistribution::ToString() const {
+// Scales every bucket by the total so that the counts sum to one. A histogram
+// whose total is zero (no buckets, or only zero-count buckets such as those
+// created by the non-const operator[]) is left untouched, since dividing by
+// that total would turn the buckets into NaN.
+void TargetHistogram::Normalize() {
+ const double total = total_counts();
+ if (total == 0.)
+ return;
+
+ for (auto& iter : counts_)
+ iter.second /= total;
+}
+
+std::string TargetHistogram::ToString() const {
std::ostringstream ss;
ss << "[";
for (auto& entry : counts_)
@@ -110,7 +113,7 @@ std::string TargetDistribution::ToString() const {
}
std::ostream& operator<<(std::ostream& out,
- const media::learning::TargetDistribution& dist) {
+ const media::learning::TargetHistogram& dist) {
return out << dist.ToString();
}
diff --git a/chromium/media/learning/impl/target_histogram.h b/chromium/media/learning/impl/target_histogram.h
new file mode 100644
index 00000000000..cb8de2b625e
--- /dev/null
+++ b/chromium/media/learning/impl/target_histogram.h
@@ -0,0 +1,98 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef MEDIA_LEARNING_IMPL_TARGET_HISTOGRAM_H_
+#define MEDIA_LEARNING_IMPL_TARGET_HISTOGRAM_H_
+
+#include <ostream>
+#include <string>
+
+#include "base/component_export.h"
+#include "base/containers/flat_map.h"
+#include "base/macros.h"
+#include "media/learning/common/labelled_example.h"
+#include "media/learning/common/value.h"
+
+namespace media {
+namespace learning {
+
+// Histogram of target values that allows fractional counts.
+class COMPONENT_EXPORT(LEARNING_IMPL) TargetHistogram {
+ private:
+ // We use a flat_map since this will often have only one or two TargetValues,
+ // such as "true" or "false".
+ using CountMap = base::flat_map<TargetValue, double>;
+
+ public:
+ TargetHistogram();
+ TargetHistogram(const TargetHistogram& rhs);
+ TargetHistogram(TargetHistogram&& rhs);
+ ~TargetHistogram();
+
+ TargetHistogram& operator=(const TargetHistogram& rhs);
+ TargetHistogram& operator=(TargetHistogram&& rhs);
+
+ bool operator==(const TargetHistogram& rhs) const;
+
+ // Add |rhs| to our counts.
+ TargetHistogram& operator+=(const TargetHistogram& rhs);
+
+ // Increment the count for |rhs| by one.
+ TargetHistogram& operator+=(const TargetValue& rhs);
+
+ // Add |example|'s weight to the count for its target value.
+ TargetHistogram& operator+=(const LabelledExample& example);
+
+ // Return the counts for |value|; the const overload returns 0 if absent.
+ double operator[](const TargetValue& value) const;
+ double& operator[](const TargetValue& value);
+
+ // Return the total counts in the map.
+ double total_counts() const {
+ double total = 0.;
+ for (auto& entry : counts_)
+ total += entry.second;
+ return total;
+ }
+
+ CountMap::const_iterator begin() const { return counts_.begin(); }
+
+ CountMap::const_iterator end() const { return counts_.end(); }
+
+ // Return the number of buckets in the histogram.
+ // TODO(liberato): Do we want this?
+ size_t size() const { return counts_.size(); }
+
+ // Find the singular value with the highest counts, and copy it into
+ // |value_out| and (optionally) |counts_out|. Returns true if there is a
+ // singular maximum, else returns false with the out params undefined.
+ bool FindSingularMax(TargetValue* value_out,
+ double* counts_out = nullptr) const;
+
+ // Return the average value of the entries in this histogram. Of course,
+ // this only makes sense if the TargetValues can be interpreted as numeric.
+ double Average() const;
+
+ // Scale all counts so that they sum to one. An empty histogram is left
+ // unchanged, and its total stays at zero.
+ void Normalize();
+
+ std::string ToString() const;
+
+ private:
+ const CountMap& counts() const { return counts_; }
+
+ // [value] == counts
+ CountMap counts_;
+
+ // Allow copy and assign.
+};
+
+COMPONENT_EXPORT(LEARNING_IMPL)
+std::ostream& operator<<(std::ostream& out, const TargetHistogram& dist);
+
+} // namespace learning
+} // namespace media
+
+#endif // MEDIA_LEARNING_IMPL_TARGET_HISTOGRAM_H_
diff --git a/chromium/media/learning/impl/target_histogram_unittest.cc b/chromium/media/learning/impl/target_histogram_unittest.cc
new file mode 100644
index 00000000000..5ba36e0ed78
--- /dev/null
+++ b/chromium/media/learning/impl/target_histogram_unittest.cc
@@ -0,0 +1,179 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "media/learning/impl/target_histogram.h"
+
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace media {
+namespace learning {
+
+class TargetHistogramTest : public testing::Test {
+ public:
+ TargetHistogramTest() : value_1(123), value_2(456), value_3(789) {}
+
+ TargetHistogram histogram_;  // Histogram under test, default-constructed.
+
+ TargetValue value_1;
+ const size_t counts_1 = 100;  // The larger count; several tests assert counts_1 > counts_2.
+
+ TargetValue value_2;
+ const size_t counts_2 = 10;  // The smaller count.
+
+ TargetValue value_3;  // Has no fixed count; tests that use it assign one explicitly.
+};
+
+TEST_F(TargetHistogramTest, EmptyTargetHistogramHasZeroCounts) {  // A default-constructed histogram sums to zero.
+ EXPECT_EQ(histogram_.total_counts(), 0u);
+}
+
+TEST_F(TargetHistogramTest, AddingCountsWorks) {  // Assigning and then += through operator[] both show up in the total.
+ histogram_[value_1] = counts_1;
+ EXPECT_EQ(histogram_.total_counts(), counts_1);
+ EXPECT_EQ(histogram_[value_1], counts_1);
+ histogram_[value_1] += counts_1;
+ EXPECT_EQ(histogram_.total_counts(), counts_1 * 2u);
+ EXPECT_EQ(histogram_[value_1], counts_1 * 2u);
+}
+
+TEST_F(TargetHistogramTest, MultipleValuesAreSeparate) {  // Distinct target values keep independent counts.
+ histogram_[value_1] = counts_1;
+ histogram_[value_2] = counts_2;
+ EXPECT_EQ(histogram_.total_counts(), counts_1 + counts_2);
+ EXPECT_EQ(histogram_[value_1], counts_1);
+ EXPECT_EQ(histogram_[value_2], counts_2);
+}
+
+TEST_F(TargetHistogramTest, AddingTargetValues) {  // Each += of a TargetValue adds one count to that value only.
+ histogram_ += value_1;
+ EXPECT_EQ(histogram_.total_counts(), 1u);
+ EXPECT_EQ(histogram_[value_1], 1u);
+ EXPECT_EQ(histogram_[value_2], 0u);
+
+ histogram_ += value_1;
+ EXPECT_EQ(histogram_.total_counts(), 2u);
+ EXPECT_EQ(histogram_[value_1], 2u);
+ EXPECT_EQ(histogram_[value_2], 0u);
+
+ histogram_ += value_2;
+ EXPECT_EQ(histogram_.total_counts(), 3u);
+ EXPECT_EQ(histogram_[value_1], 2u);
+ EXPECT_EQ(histogram_[value_2], 1u);
+}
+
+TEST_F(TargetHistogramTest, AddingTargetHistograms) {  // += of another histogram merges |rhs|'s buckets into ours.
+ histogram_[value_1] = counts_1;
+
+ TargetHistogram rhs;
+ rhs[value_2] = counts_2;
+
+ histogram_ += rhs;
+
+ EXPECT_EQ(histogram_.total_counts(), counts_1 + counts_2);
+ EXPECT_EQ(histogram_[value_1], counts_1);
+ EXPECT_EQ(histogram_[value_2], counts_2);
+}
+
+TEST_F(TargetHistogramTest, FindSingularMaxFindsTheSingularMax) {  // The unique largest bucket is reported with its count.
+ histogram_[value_1] = counts_1;
+ histogram_[value_2] = counts_2;
+ ASSERT_TRUE(counts_1 > counts_2);
+
+ TargetValue max_value(0);
+ double max_counts = 0;
+ EXPECT_TRUE(histogram_.FindSingularMax(&max_value, &max_counts));
+ EXPECT_EQ(max_value, value_1);
+ EXPECT_EQ(max_counts, counts_1);
+}
+
+TEST_F(TargetHistogramTest, FindSingularMaxFindsTheSingularMaxAlternateOrder) {
+ // Give |value_2| the larger count this time, to cover the other ordering.
+ histogram_[value_1] = counts_2;
+ histogram_[value_2] = counts_1;
+ ASSERT_TRUE(counts_1 > counts_2);
+
+ TargetValue max_value(0);
+ double max_counts = 0;
+ EXPECT_TRUE(histogram_.FindSingularMax(&max_value, &max_counts));
+ EXPECT_EQ(max_value, value_2);
+ EXPECT_EQ(max_counts, counts_1);
+}
+
+// Two values tied for the highest count means there is no singular maximum,
+// so FindSingularMax() must report failure.
+TEST_F(TargetHistogramTest, FindSingularMaxReturnsFalseForNonSingularMax) {
+ histogram_[value_1] = counts_1;
+ histogram_[value_2] = counts_1;
+
+ TargetValue max_value(0);
+ double max_counts = 0;
+ EXPECT_FALSE(histogram_.FindSingularMax(&max_value, &max_counts));
+}
+
+TEST_F(TargetHistogramTest, FindSingularMaxIgnoresNonSingularNonMax) {  // A tie below the maximum must not hide the maximum.
+ histogram_[value_1] = counts_1;
+ // |value_2| and |value_3| are tied, but not the max.
+ histogram_[value_2] = counts_2;
+ histogram_[value_3] = counts_2;
+ ASSERT_TRUE(counts_1 > counts_2);
+
+ TargetValue max_value(0);
+ double max_counts = 0;
+ EXPECT_TRUE(histogram_.FindSingularMax(&max_value, &max_counts));
+ EXPECT_EQ(max_value, value_1);
+ EXPECT_EQ(max_counts, counts_1);
+}
+
+TEST_F(TargetHistogramTest, FindSingularMaxDoesntRequireCounts) {  // |counts_out| may be omitted.
+ histogram_[value_1] = counts_1;
+
+ TargetValue max_value(0);
+ EXPECT_TRUE(histogram_.FindSingularMax(&max_value));
+ EXPECT_EQ(max_value, value_1);
+}
+
+TEST_F(TargetHistogramTest, EqualDistributionsCompareAsEqual) {  // Same value with the same count compares equal.
+ histogram_[value_1] = counts_1;
+ TargetHistogram histogram_2;
+ histogram_2[value_1] = counts_1;
+
+ EXPECT_TRUE(histogram_ == histogram_2);
+}
+
+TEST_F(TargetHistogramTest, UnequalDistributionsCompareAsNotEqual) {  // Different values and counts compare unequal.
+ histogram_[value_1] = counts_1;
+ TargetHistogram histogram_2;
+ histogram_2[value_2] = counts_2;
+
+ EXPECT_FALSE(histogram_ == histogram_2);
+}
+
+TEST_F(TargetHistogramTest, WeightedLabelledExamplesCountCorrectly) {  // One example of weight N equals N unit increments.
+ LabelledExample example = {{}, value_1};
+ example.weight = counts_1;
+ histogram_ += example;
+
+ TargetHistogram histogram_2;
+ for (size_t i = 0; i < counts_1; i++)
+ histogram_2 += value_1;
+
+ EXPECT_EQ(histogram_, histogram_2);
+}
+
+// After normalizing, each bucket holds its share of the original total and the
+// shares sum to one. The values are quotients of doubles, so compare them with
+// the floating-point assertion rather than exact equality.
+TEST_F(TargetHistogramTest, Normalize) {
+ histogram_[value_1] = counts_1;
+ histogram_[value_2] = counts_2;
+ histogram_.Normalize();
+
+ const double total = static_cast<double>(counts_1 + counts_2);
+ EXPECT_DOUBLE_EQ(histogram_[value_1], counts_1 / total);
+ EXPECT_DOUBLE_EQ(histogram_[value_2], counts_2 / total);
+ EXPECT_DOUBLE_EQ(histogram_.total_counts(), 1.);
+}
+
+TEST_F(TargetHistogramTest, NormalizeEmptyDistribution) {
+ // Normalizing an empty histogram must leave it empty: no buckets appear and
+ // the total stays at zero.
+ histogram_.Normalize();
+ EXPECT_EQ(histogram_.size(), 0u);
+ EXPECT_EQ(histogram_.total_counts(), 0);
+}
+
+} // namespace learning
+} // namespace media
diff --git a/chromium/media/learning/impl/voting_ensemble.cc b/chromium/media/learning/impl/voting_ensemble.cc
index 667e739975d..ef07b1cb731 100644
--- a/chromium/media/learning/impl/voting_ensemble.cc
+++ b/chromium/media/learning/impl/voting_ensemble.cc
@@ -12,9 +12,9 @@ VotingEnsemble::VotingEnsemble(std::vector<std::unique_ptr<Model>> models)
VotingEnsemble::~VotingEnsemble() = default;
-TargetDistribution VotingEnsemble::PredictDistribution(
+TargetHistogram VotingEnsemble::PredictDistribution(
const FeatureVector& instance) {
- TargetDistribution distribution;
+ TargetHistogram distribution;
for (auto iter = models_.begin(); iter != models_.end(); iter++)
distribution += (*iter)->PredictDistribution(instance);
diff --git a/chromium/media/learning/impl/voting_ensemble.h b/chromium/media/learning/impl/voting_ensemble.h
index 2b4bf11a3fa..7e0cba7b59a 100644
--- a/chromium/media/learning/impl/voting_ensemble.h
+++ b/chromium/media/learning/impl/voting_ensemble.h
@@ -23,8 +23,7 @@ class COMPONENT_EXPORT(LEARNING_IMPL) VotingEnsemble : public Model {
~VotingEnsemble() override;
// Model
- TargetDistribution PredictDistribution(
- const FeatureVector& instance) override;
+ TargetHistogram PredictDistribution(const FeatureVector& instance) override;
private:
std::vector<std::unique_ptr<Model>> models_;
diff --git a/chromium/media/learning/mojo/BUILD.gn b/chromium/media/learning/mojo/BUILD.gn
index bfe45d11359..d8c92955fdc 100644
--- a/chromium/media/learning/mojo/BUILD.gn
+++ b/chromium/media/learning/mojo/BUILD.gn
@@ -8,8 +8,8 @@ import("//testing/test.gni")
component("impl") {
output_name = "media_learning_mojo_impl"
sources = [
- "mojo_learning_session_impl.cc",
- "mojo_learning_session_impl.h",
+ "mojo_learning_task_controller_service.cc",
+ "mojo_learning_task_controller_service.h",
]
defines = [ "IS_MEDIA_LEARNING_MOJO_IMPL" ]
@@ -35,7 +35,7 @@ source_set("unit_tests") {
testonly = true
sources = [
- "mojo_learning_session_impl_unittest.cc",
+ "mojo_learning_task_controller_service_unittest.cc",
]
deps = [
diff --git a/chromium/media/learning/mojo/mojo_learning_session_impl.cc b/chromium/media/learning/mojo/mojo_learning_session_impl.cc
deleted file mode 100644
index 3cfb9789ff6..00000000000
--- a/chromium/media/learning/mojo/mojo_learning_session_impl.cc
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright 2018 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "media/learning/mojo/mojo_learning_session_impl.h"
-
-#include "media/learning/common/learning_session.h"
-
-namespace media {
-namespace learning {
-
-MojoLearningSessionImpl::MojoLearningSessionImpl(
- std::unique_ptr<::media::learning::LearningSession> impl)
- : impl_(std::move(impl)) {}
-
-MojoLearningSessionImpl::~MojoLearningSessionImpl() = default;
-
-void MojoLearningSessionImpl::Bind(mojom::LearningSessionRequest request) {
- self_bindings_.AddBinding(this, std::move(request));
-}
-
-void MojoLearningSessionImpl::AddExample(mojom::LearningTaskType task_type,
- const LabelledExample& example) {
- // TODO(liberato): Convert |task_type| into a task name.
- std::string task_name("no_task");
-
- impl_->AddExample(task_name, example);
-}
-
-} // namespace learning
-} // namespace media
diff --git a/chromium/media/learning/mojo/mojo_learning_session_impl.h b/chromium/media/learning/mojo/mojo_learning_session_impl.h
deleted file mode 100644
index 83d0eb72018..00000000000
--- a/chromium/media/learning/mojo/mojo_learning_session_impl.h
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright 2018 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef MEDIA_LEARNING_MOJO_MOJO_LEARNING_SESSION_IMPL_H_
-#define MEDIA_LEARNING_MOJO_MOJO_LEARNING_SESSION_IMPL_H_
-
-#include <memory>
-
-#include "base/component_export.h"
-#include "base/macros.h"
-#include "media/learning/mojo/public/mojom/learning_session.mojom.h"
-#include "mojo/public/cpp/bindings/binding_set.h"
-
-namespace media {
-namespace learning {
-
-class LearningSession;
-class MojoLearningSessionImplTest;
-
-// Mojo service that talks to a local LearningSession.
-class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningSessionImpl
- : public mojom::LearningSession {
- public:
- ~MojoLearningSessionImpl() override;
-
- // Bind |request| to this instance.
- void Bind(mojom::LearningSessionRequest request);
-
- // mojom::LearningSession
- void AddExample(mojom::LearningTaskType task_type,
- const LabelledExample& example) override;
-
- protected:
- explicit MojoLearningSessionImpl(
- std::unique_ptr<::media::learning::LearningSession> impl);
-
- // Underlying session to which we proxy calls.
- std::unique_ptr<::media::learning::LearningSession> impl_;
-
- // We own our own bindings. If we're ever not a singleton, then this should
- // move to our owner.
- mojo::BindingSet<mojom::LearningSession> self_bindings_;
-
- friend class MojoLearningSessionImplTest;
-
- DISALLOW_COPY_AND_ASSIGN(MojoLearningSessionImpl);
-};
-
-} // namespace learning
-} // namespace media
-
-#endif // MEDIA_LEARNING_MOJO_MOJO_LEARNING_SESSION_IMPL_H_
diff --git a/chromium/media/learning/mojo/mojo_learning_session_impl_unittest.cc b/chromium/media/learning/mojo/mojo_learning_session_impl_unittest.cc
deleted file mode 100644
index 5399d20ca16..00000000000
--- a/chromium/media/learning/mojo/mojo_learning_session_impl_unittest.cc
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright 2018 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include <memory>
-
-#include "base/macros.h"
-#include "base/memory/ptr_util.h"
-#include "media/learning/common/learning_session.h"
-#include "media/learning/mojo/mojo_learning_session_impl.h"
-#include "media/learning/mojo/public/mojom/learning_types.mojom.h"
-#include "testing/gtest/include/gtest/gtest.h"
-
-namespace media {
-namespace learning {
-
-class MojoLearningSessionImplTest : public ::testing::Test {
- public:
- class FakeLearningSession : public ::media::learning::LearningSession {
- public:
- void AddExample(const std::string& task_name,
- const LabelledExample& example) override {
- most_recent_task_name_ = task_name;
- most_recent_example_ = example;
- }
-
- std::string most_recent_task_name_;
- LabelledExample most_recent_example_;
- };
-
- public:
- MojoLearningSessionImplTest() = default;
- ~MojoLearningSessionImplTest() override = default;
-
- void SetUp() override {
- // Create a mojo learner that forwards to a fake learner.
- std::unique_ptr<FakeLearningSession> provider =
- std::make_unique<FakeLearningSession>();
- fake_learning_session_ = provider.get();
- learning_session_impl_ = base::WrapUnique<MojoLearningSessionImpl>(
- new MojoLearningSessionImpl(std::move(provider)));
- }
-
- FakeLearningSession* fake_learning_session_ = nullptr;
-
- const mojom::LearningTaskType task_type_ =
- mojom::LearningTaskType::kPlaceHolderTask;
-
- // The learner provider under test.
- std::unique_ptr<MojoLearningSessionImpl> learning_session_impl_;
-};
-
-TEST_F(MojoLearningSessionImplTest, FeaturesAndTargetValueAreCopied) {
- mojom::LabelledExamplePtr example_ptr = mojom::LabelledExample::New();
- const LabelledExample example = {{Value(123), Value(456), Value(890)},
- TargetValue(1234)};
-
- learning_session_impl_->AddExample(task_type_, example);
- EXPECT_EQ(example, fake_learning_session_->most_recent_example_);
-}
-
-} // namespace learning
-} // namespace media
diff --git a/chromium/media/learning/mojo/mojo_learning_task_controller_service.cc b/chromium/media/learning/mojo/mojo_learning_task_controller_service.cc
new file mode 100644
index 00000000000..acd3f319447
--- /dev/null
+++ b/chromium/media/learning/mojo/mojo_learning_task_controller_service.cc
@@ -0,0 +1,65 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "media/learning/mojo/mojo_learning_task_controller_service.h"
+
+#include <utility>
+
+#include "media/learning/common/learning_task_controller.h"
+
+namespace media {
+namespace learning {
+
+// Somewhat arbitrary upper limit on the number of in-flight observations that
+// we'll allow a client to have.
+static const size_t kMaxInFlightObservations = 16;
+
+MojoLearningTaskControllerService::MojoLearningTaskControllerService(
+ const LearningTask& task,
+ std::unique_ptr<::media::learning::LearningTaskController> impl)
+ : task_(task), impl_(std::move(impl)) {}
+
+MojoLearningTaskControllerService::~MojoLearningTaskControllerService() =
+ default;
+
+void MojoLearningTaskControllerService::BeginObservation(
+ const base::UnguessableToken& id,
+ const FeatureVector& features) {
+ // Drop the observation if it doesn't match the feature description size.
+ if (features.size() != task_.feature_descriptions.size())
+ return;
+
+ // Don't allow the client to send too many in-flight observations.
+ if (in_flight_observations_.size() >= kMaxInFlightObservations)
+ return;
+ in_flight_observations_.insert(id);  // NOTE(review): a repeated |id| is a no-op here but is still forwarded below.
+
+ // We record |id| only to bound and validate client requests. Since we own
+ // |impl_|, no cleanup is needed: releasing it on destruction cancels them.
+ impl_->BeginObservation(id, features);
+}
+
+void MojoLearningTaskControllerService::CompleteObservation(
+ const base::UnguessableToken& id,
+ const ObservationCompletion& completion) {
+ // Forward only completions for observations that were admitted earlier.
+ // erase() returns how many entries it removed, so zero means |id| is unknown
+ // and the request is dropped.
+ if (in_flight_observations_.erase(id) == 0)
+ return;
+
+ impl_->CompleteObservation(id, completion);
+}
+
+void MojoLearningTaskControllerService::CancelObservation(
+ const base::UnguessableToken& id) {
+ // Forward only cancellations for observations that were admitted earlier.
+ // erase() returns how many entries it removed, so zero means |id| is unknown
+ // and the request is dropped.
+ if (in_flight_observations_.erase(id) == 0)
+ return;
+
+ impl_->CancelObservation(id);
+}
+
+} // namespace learning
+} // namespace media
diff --git a/chromium/media/learning/mojo/mojo_learning_task_controller_service.h b/chromium/media/learning/mojo/mojo_learning_task_controller_service.h
new file mode 100644
index 00000000000..cdf8721adca
--- /dev/null
+++ b/chromium/media/learning/mojo/mojo_learning_task_controller_service.h
@@ -0,0 +1,51 @@
+// Copyright 2018 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef MEDIA_LEARNING_MOJO_MOJO_LEARNING_TASK_CONTROLLER_SERVICE_H_
+#define MEDIA_LEARNING_MOJO_MOJO_LEARNING_TASK_CONTROLLER_SERVICE_H_
+
+#include <memory>
+#include <set>
+
+#include "base/component_export.h"
+#include "base/macros.h"
+#include "media/learning/mojo/public/mojom/learning_task_controller.mojom.h"
+
+namespace media {
+namespace learning {
+
+class LearningTaskController;
+
+// Mojo service that talks to a local LearningTaskController.
+class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskControllerService
+ : public mojom::LearningTaskController {
+ public:
+ // |impl| is the underlying controller that we'll send requests to. |task| is copied and used to validate incoming observations.
+ explicit MojoLearningTaskControllerService(
+ const LearningTask& task,
+ std::unique_ptr<::media::learning::LearningTaskController> impl);
+ ~MojoLearningTaskControllerService() override;
+
+ // mojom::LearningTaskController
+ void BeginObservation(const base::UnguessableToken& id,
+ const FeatureVector& features) override;
+ void CompleteObservation(const base::UnguessableToken& id,
+ const ObservationCompletion& completion) override;
+ void CancelObservation(const base::UnguessableToken& id) override;
+
+ protected:
+ const LearningTask task_;  // Its feature description count must match each incoming FeatureVector.
+
+ // Underlying controller to which we proxy calls.
+ std::unique_ptr<::media::learning::LearningTaskController> impl_;
+
+ std::set<base::UnguessableToken> in_flight_observations_;  // Ids begun but not yet completed or cancelled; size is capped.
+
+ DISALLOW_COPY_AND_ASSIGN(MojoLearningTaskControllerService);
+};
+
+} // namespace learning
+} // namespace media
+
+#endif // MEDIA_LEARNING_MOJO_MOJO_LEARNING_TASK_CONTROLLER_SERVICE_H_
diff --git a/chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc b/chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc
new file mode 100644
index 00000000000..99e8a391269
--- /dev/null
+++ b/chromium/media/learning/mojo/mojo_learning_task_controller_service_unittest.cc
@@ -0,0 +1,143 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <memory>
+#include <utility>
+
+#include "base/bind.h"
+#include "base/macros.h"
+#include "base/memory/ptr_util.h"
+#include "base/test/scoped_task_environment.h"
+#include "base/threading/thread.h"
+#include "media/learning/mojo/mojo_learning_task_controller_service.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace media {
+namespace learning {
+
+class MojoLearningTaskControllerServiceTest : public ::testing::Test {
+ public:
+ class FakeLearningTaskController : public LearningTaskController {
+ public:
+ void BeginObservation(base::UnguessableToken id,
+ const FeatureVector& features) override {
+ begin_args_.id_ = id;
+ begin_args_.features_ = features;
+ }
+
+ void CompleteObservation(base::UnguessableToken id,
+ const ObservationCompletion& completion) override {
+ complete_args_.id_ = id;
+ complete_args_.completion_ = completion;
+ }
+
+ void CancelObservation(base::UnguessableToken id) override {
+ cancel_args_.id_ = id;
+ }
+
+ struct {
+ base::UnguessableToken id_;
+ FeatureVector features_;
+ } begin_args_;
+
+ struct {
+ base::UnguessableToken id_;
+ ObservationCompletion completion_;
+ } complete_args_;
+
+ struct {
+ base::UnguessableToken id_;
+ } cancel_args_;
+ };
+
+ public:
+ MojoLearningTaskControllerServiceTest() = default;
+ ~MojoLearningTaskControllerServiceTest() override = default;
+
+ void SetUp() override {
+ std::unique_ptr<FakeLearningTaskController> controller =
+ std::make_unique<FakeLearningTaskController>();
+ controller_raw_ = controller.get();
+
+ // Add two features.
+ task_.feature_descriptions.push_back({});
+ task_.feature_descriptions.push_back({});
+
+ // Create |service_|, which forwards to the fake controller impl.
+ service_ = std::make_unique<MojoLearningTaskControllerService>(
+ task_, std::move(controller));
+ }
+
+ LearningTask task_;
+
+ // Mojo stuff.
+ base::test::ScopedTaskEnvironment scoped_task_environment_;
+
+ FakeLearningTaskController* controller_raw_ = nullptr;
+
+ // The service under test.
+ std::unique_ptr<MojoLearningTaskControllerService> service_;
+};
+
+TEST_F(MojoLearningTaskControllerServiceTest, BeginComplete) {
+ base::UnguessableToken id = base::UnguessableToken::Create();
+ FeatureVector features = {FeatureValue(123), FeatureValue(456)};
+ service_->BeginObservation(id, features);
+ EXPECT_EQ(id, controller_raw_->begin_args_.id_);
+ EXPECT_EQ(features, controller_raw_->begin_args_.features_);
+
+ ObservationCompletion completion(TargetValue(1234));
+ service_->CompleteObservation(id, completion);
+
+ EXPECT_EQ(id, controller_raw_->complete_args_.id_);
+ EXPECT_EQ(completion.target_value,
+ controller_raw_->complete_args_.completion_.target_value);
+}
+
+TEST_F(MojoLearningTaskControllerServiceTest, BeginCancel) {
+ base::UnguessableToken id = base::UnguessableToken::Create();
+ FeatureVector features = {FeatureValue(123), FeatureValue(456)};
+ service_->BeginObservation(id, features);
+ EXPECT_EQ(id, controller_raw_->begin_args_.id_);
+ EXPECT_EQ(features, controller_raw_->begin_args_.features_);
+
+ service_->CancelObservation(id);
+
+ EXPECT_EQ(id, controller_raw_->cancel_args_.id_);
+}
+
+TEST_F(MojoLearningTaskControllerServiceTest, TooFewFeaturesIsIgnored) {
+ // A FeatureVector with too few elements should be ignored.
+ base::UnguessableToken id = base::UnguessableToken::Create();
+ FeatureVector short_features = {FeatureValue(123)};
+ service_->BeginObservation(id, short_features);
+ EXPECT_NE(id, controller_raw_->begin_args_.id_);
+ EXPECT_EQ(controller_raw_->begin_args_.features_.size(), 0u);
+}
+
+TEST_F(MojoLearningTaskControllerServiceTest, TooManyFeaturesIsIgnored) {
+ // A FeatureVector with too many elements should be ignored.
+ base::UnguessableToken id = base::UnguessableToken::Create();
+ FeatureVector long_features = {FeatureValue(123), FeatureValue(456),
+ FeatureValue(789)};
+ service_->BeginObservation(id, long_features);
+ EXPECT_NE(id, controller_raw_->begin_args_.id_);
+ EXPECT_EQ(controller_raw_->begin_args_.features_.size(), 0u);
+}
+
+TEST_F(MojoLearningTaskControllerServiceTest, CompleteWithoutBeginFails) {
+ base::UnguessableToken id = base::UnguessableToken::Create();
+ ObservationCompletion completion(TargetValue(1234));
+ service_->CompleteObservation(id, completion);
+ EXPECT_NE(id, controller_raw_->complete_args_.id_);
+}
+
+TEST_F(MojoLearningTaskControllerServiceTest, CancelWithoutBeginFails) {
+ base::UnguessableToken id = base::UnguessableToken::Create();
+ service_->CancelObservation(id);
+ EXPECT_NE(id, controller_raw_->cancel_args_.id_);
+}
+
+} // namespace learning
+} // namespace media
diff --git a/chromium/media/learning/mojo/public/cpp/BUILD.gn b/chromium/media/learning/mojo/public/cpp/BUILD.gn
index 2157282389e..ce6b135639f 100644
--- a/chromium/media/learning/mojo/public/cpp/BUILD.gn
+++ b/chromium/media/learning/mojo/public/cpp/BUILD.gn
@@ -9,8 +9,8 @@ source_set("cpp") {
]
sources = [
- "mojo_learning_session.cc",
- "mojo_learning_session.h",
+ "mojo_learning_task_controller.cc",
+ "mojo_learning_task_controller.h",
]
defines = [ "IS_MEDIA_LEARNING_MOJO_IMPL" ]
@@ -26,7 +26,7 @@ source_set("unit_tests") {
testonly = true
sources = [
- "mojo_learning_session_unittest.cc",
+ "mojo_learning_task_controller_unittest.cc",
]
deps = [
diff --git a/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.cc b/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.cc
index aa308d9de9f..12841f9df32 100644
--- a/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.cc
+++ b/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.cc
@@ -39,4 +39,14 @@ bool StructTraits<media::learning::mojom::TargetValueDataView,
return true;
}
+// static
+bool StructTraits<media::learning::mojom::ObservationCompletionDataView,
+ media::learning::ObservationCompletion>::
+ Read(media::learning::mojom::ObservationCompletionDataView data,
+ media::learning::ObservationCompletion* out_observation_completion) {
+ if (!data.ReadTargetValue(&out_observation_completion->target_value))
+ return false;
+ out_observation_completion->weight = data.weight();
+ return true;
+}
} // namespace mojo
diff --git a/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.h b/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.h
index 932a5cb7d4a..52f8ed5a86c 100644
--- a/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.h
+++ b/chromium/media/learning/mojo/public/cpp/learning_mojom_traits.h
@@ -7,6 +7,7 @@
#include <vector>
+#include "media/learning/common/learning_task_controller.h"
#include "media/learning/common/value.h"
#include "media/learning/mojo/public/mojom/learning_types.mojom.h"
#include "mojo/public/cpp/bindings/struct_traits.h"
@@ -52,6 +53,23 @@ class StructTraits<media::learning::mojom::TargetValueDataView,
media::learning::TargetValue* out_target_value);
};
+template <>
+class StructTraits<media::learning::mojom::ObservationCompletionDataView,
+ media::learning::ObservationCompletion> {
+ public:
+ static media::learning::TargetValue target_value(
+ const media::learning::ObservationCompletion& e) {
+ return e.target_value;
+ }
+ static media::learning::WeightType weight(
+ const media::learning::ObservationCompletion& e) {
+ return e.weight;
+ }
+ static bool Read(
+ media::learning::mojom::ObservationCompletionDataView data,
+ media::learning::ObservationCompletion* out_observation_completion);
+};
+
} // namespace mojo
#endif // MEDIA_LEARNING_MOJO_PUBLIC_CPP_LEARNING_MOJOM_TRAITS_H_
diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_session.cc b/chromium/media/learning/mojo/public/cpp/mojo_learning_session.cc
deleted file mode 100644
index c67f642842a..00000000000
--- a/chromium/media/learning/mojo/public/cpp/mojo_learning_session.cc
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright 2018 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "media/learning/mojo/public/cpp/mojo_learning_session.h"
-
-#include "mojo/public/cpp/bindings/binding.h"
-
-namespace media {
-namespace learning {
-
-MojoLearningSession::MojoLearningSession(mojom::LearningSessionPtr session_ptr)
- : session_ptr_(std::move(session_ptr)) {}
-
-MojoLearningSession::~MojoLearningSession() = default;
-
-void MojoLearningSession::AddExample(const std::string& task_name,
- const LabelledExample& example) {
- // TODO(liberato): Convert from |task_name| to a task type.
- session_ptr_->AddExample(mojom::LearningTaskType::kPlaceHolderTask, example);
-}
-
-} // namespace learning
-} // namespace media
diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_session.h b/chromium/media/learning/mojo/public/cpp/mojo_learning_session.h
deleted file mode 100644
index 0e8af2b6aca..00000000000
--- a/chromium/media/learning/mojo/public/cpp/mojo_learning_session.h
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright 2018 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef MEDIA_LEARNING_MOJO_PUBLIC_CPP_MOJO_LEARNING_SESSION_H_
-#define MEDIA_LEARNING_MOJO_PUBLIC_CPP_MOJO_LEARNING_SESSION_H_
-
-#include "base/component_export.h"
-#include "base/macros.h"
-#include "media/learning/common/learning_session.h"
-#include "media/learning/mojo/public/mojom/learning_session.mojom.h"
-
-namespace media {
-namespace learning {
-
-// LearningSession implementation to forward to a remote impl via mojo.
-class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningSession
- : public LearningSession {
- public:
- explicit MojoLearningSession(mojom::LearningSessionPtr session_ptr);
- ~MojoLearningSession() override;
-
- // LearningSession
- void AddExample(const std::string& task_name,
- const LabelledExample& example) override;
-
- private:
- mojom::LearningSessionPtr session_ptr_;
-
- DISALLOW_COPY_AND_ASSIGN(MojoLearningSession);
-};
-
-} // namespace learning
-} // namespace media
-
-#endif // MEDIA_LEARNING_MOJO_PUBLIC_CPP_MOJO_LEARNING_SESSION_H_
diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_session_unittest.cc b/chromium/media/learning/mojo/public/cpp/mojo_learning_session_unittest.cc
deleted file mode 100644
index 37cecb4db21..00000000000
--- a/chromium/media/learning/mojo/public/cpp/mojo_learning_session_unittest.cc
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright 2018 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include <memory>
-
-#include "base/bind.h"
-#include "base/macros.h"
-#include "base/memory/ptr_util.h"
-#include "base/test/scoped_task_environment.h"
-#include "base/threading/thread.h"
-#include "media/learning/mojo/public/cpp/mojo_learning_session.h"
-#include "mojo/public/cpp/bindings/binding.h"
-#include "testing/gtest/include/gtest/gtest.h"
-
-namespace media {
-namespace learning {
-
-class MojoLearningSessionTest : public ::testing::Test {
- public:
- // Impl of a mojom::LearningSession that remembers call arguments.
- class FakeMojoLearningSessionImpl : public mojom::LearningSession {
- public:
- void AddExample(mojom::LearningTaskType task_type,
- const LabelledExample& example) override {
- task_type_ = std::move(task_type);
- example_ = example;
- }
-
- mojom::LearningTaskType task_type_;
- LabelledExample example_;
- };
-
- public:
- MojoLearningSessionTest()
- : learning_session_binding_(&fake_learning_session_impl_) {}
- ~MojoLearningSessionTest() override = default;
-
- void SetUp() override {
- // Create a fake learner provider mojo impl.
- mojom::LearningSessionPtr learning_session_ptr;
- learning_session_binding_.Bind(mojo::MakeRequest(&learning_session_ptr));
-
- // Tell |learning_session_| to forward to the fake learner impl.
- learning_session_ =
- std::make_unique<MojoLearningSession>(std::move(learning_session_ptr));
- }
-
- // Mojo stuff.
- base::test::ScopedTaskEnvironment scoped_task_environment_;
-
- FakeMojoLearningSessionImpl fake_learning_session_impl_;
- mojo::Binding<mojom::LearningSession> learning_session_binding_;
-
- // The learner under test.
- std::unique_ptr<MojoLearningSession> learning_session_;
-};
-
-TEST_F(MojoLearningSessionTest, ExampleIsCopied) {
- LabelledExample example({FeatureValue(123), FeatureValue(456)},
- TargetValue(1234));
- learning_session_->AddExample("unused task id", example);
- learning_session_binding_.FlushForTesting();
- EXPECT_EQ(fake_learning_session_impl_.example_, example);
-}
-
-} // namespace learning
-} // namespace media
diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc
new file mode 100644
index 00000000000..f4648b38bcb
--- /dev/null
+++ b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.cc
@@ -0,0 +1,39 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "media/learning/mojo/public/cpp/mojo_learning_task_controller.h"
+
+#include <utility>
+
+#include "mojo/public/cpp/bindings/binding.h"
+
+namespace media {
+namespace learning {
+
+MojoLearningTaskController::MojoLearningTaskController(
+ mojom::LearningTaskControllerPtr controller_ptr)
+ : controller_ptr_(std::move(controller_ptr)) {}
+
+MojoLearningTaskController::~MojoLearningTaskController() = default;
+
+void MojoLearningTaskController::BeginObservation(
+ base::UnguessableToken id,
+ const FeatureVector& features) {
+ // We don't need to keep track of in-flight observations, since the service
+ // side handles it for us.
+ controller_ptr_->BeginObservation(id, features);
+}
+
+void MojoLearningTaskController::CompleteObservation(
+ base::UnguessableToken id,
+ const ObservationCompletion& completion) {
+ controller_ptr_->CompleteObservation(id, completion);
+}
+
+void MojoLearningTaskController::CancelObservation(base::UnguessableToken id) {
+ controller_ptr_->CancelObservation(id);
+}
+
+} // namespace learning
+} // namespace media
diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h
new file mode 100644
index 00000000000..893d1f0890e
--- /dev/null
+++ b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller.h
@@ -0,0 +1,42 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef MEDIA_LEARNING_MOJO_PUBLIC_CPP_MOJO_LEARNING_TASK_CONTROLLER_H_
+#define MEDIA_LEARNING_MOJO_PUBLIC_CPP_MOJO_LEARNING_TASK_CONTROLLER_H_
+
+#include <utility>
+
+#include "base/component_export.h"
+#include "base/macros.h"
+#include "media/learning/common/learning_task_controller.h"
+#include "media/learning/mojo/public/mojom/learning_task_controller.mojom.h"
+
+namespace media {
+namespace learning {
+
+// LearningTaskController implementation to forward to a remote impl via mojo.
+class COMPONENT_EXPORT(MEDIA_LEARNING_MOJO) MojoLearningTaskController
+ : public LearningTaskController {
+ public:
+ explicit MojoLearningTaskController(
+ mojom::LearningTaskControllerPtr controller_ptr);
+ ~MojoLearningTaskController() override;
+
+ // LearningTaskController
+ void BeginObservation(base::UnguessableToken id,
+ const FeatureVector& features) override;
+ void CompleteObservation(base::UnguessableToken id,
+ const ObservationCompletion& completion) override;
+ void CancelObservation(base::UnguessableToken id) override;
+
+ private:
+ mojom::LearningTaskControllerPtr controller_ptr_;
+
+ DISALLOW_COPY_AND_ASSIGN(MojoLearningTaskController);
+};
+
+} // namespace learning
+} // namespace media
+
+#endif // MEDIA_LEARNING_MOJO_PUBLIC_CPP_MOJO_LEARNING_TASK_CONTROLLER_H_
diff --git a/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc
new file mode 100644
index 00000000000..546aae171b4
--- /dev/null
+++ b/chromium/media/learning/mojo/public/cpp/mojo_learning_task_controller_unittest.cc
@@ -0,0 +1,109 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <memory>
+#include <utility>
+
+#include "base/bind.h"
+#include "base/macros.h"
+#include "base/memory/ptr_util.h"
+#include "base/test/scoped_task_environment.h"
+#include "base/threading/thread.h"
+#include "media/learning/mojo/public/cpp/mojo_learning_task_controller.h"
+#include "mojo/public/cpp/bindings/binding.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace media {
+namespace learning {
+
+class MojoLearningTaskControllerTest : public ::testing::Test {
+ public:
+ // Impl of a mojom::LearningTaskController that remembers call arguments.
+ class FakeMojoLearningTaskController : public mojom::LearningTaskController {
+ public:
+ void BeginObservation(const base::UnguessableToken& id,
+ const FeatureVector& features) override {
+ begin_args_.id_ = id;
+ begin_args_.features_ = features;
+ }
+
+ void CompleteObservation(const base::UnguessableToken& id,
+ const ObservationCompletion& completion) override {
+ complete_args_.id_ = id;
+ complete_args_.completion_ = completion;
+ }
+
+ void CancelObservation(const base::UnguessableToken& id) override {
+ cancel_args_.id_ = id;
+ }
+
+ struct {
+ base::UnguessableToken id_;
+ FeatureVector features_;
+ } begin_args_;
+
+ struct {
+ base::UnguessableToken id_;
+ ObservationCompletion completion_;
+ } complete_args_;
+
+ struct {
+ base::UnguessableToken id_;
+ } cancel_args_;
+ };
+
+ public:
+ MojoLearningTaskControllerTest()
+ : learning_controller_binding_(&fake_learning_controller_) {}
+ ~MojoLearningTaskControllerTest() override = default;
+
+ void SetUp() override {
+ // Bind the fake mojo controller impl to |learning_controller_ptr|.
+ mojom::LearningTaskControllerPtr learning_controller_ptr;
+ learning_controller_binding_.Bind(
+ mojo::MakeRequest(&learning_controller_ptr));
+
+ // Tell |learning_controller_| to forward to the fake controller impl.
+ learning_controller_ = std::make_unique<MojoLearningTaskController>(
+ std::move(learning_controller_ptr));
+ }
+
+ // Mojo stuff.
+ base::test::ScopedTaskEnvironment scoped_task_environment_;
+
+ FakeMojoLearningTaskController fake_learning_controller_;
+ mojo::Binding<mojom::LearningTaskController> learning_controller_binding_;
+
+ // The controller under test.
+ std::unique_ptr<MojoLearningTaskController> learning_controller_;
+};
+
+TEST_F(MojoLearningTaskControllerTest, Begin) {
+ base::UnguessableToken id = base::UnguessableToken::Create();
+ FeatureVector features = {FeatureValue(123), FeatureValue(456)};
+ learning_controller_->BeginObservation(id, features);
+ scoped_task_environment_.RunUntilIdle();
+ EXPECT_EQ(id, fake_learning_controller_.begin_args_.id_);
+ EXPECT_EQ(features, fake_learning_controller_.begin_args_.features_);
+}
+
+TEST_F(MojoLearningTaskControllerTest, Complete) {
+ base::UnguessableToken id = base::UnguessableToken::Create();
+ ObservationCompletion completion(TargetValue(1234));
+ learning_controller_->CompleteObservation(id, completion);
+ scoped_task_environment_.RunUntilIdle();
+ EXPECT_EQ(id, fake_learning_controller_.complete_args_.id_);
+ EXPECT_EQ(completion.target_value,
+ fake_learning_controller_.complete_args_.completion_.target_value);
+}
+
+TEST_F(MojoLearningTaskControllerTest, Cancel) {
+ base::UnguessableToken id = base::UnguessableToken::Create();
+ learning_controller_->CancelObservation(id);
+ scoped_task_environment_.RunUntilIdle();
+ EXPECT_EQ(id, fake_learning_controller_.cancel_args_.id_);
+}
+
+} // namespace learning
+} // namespace media
diff --git a/chromium/media/learning/mojo/public/mojom/BUILD.gn b/chromium/media/learning/mojo/public/mojom/BUILD.gn
index e0292358725..3d0a1f565b1 100644
--- a/chromium/media/learning/mojo/public/mojom/BUILD.gn
+++ b/chromium/media/learning/mojo/public/mojom/BUILD.gn
@@ -7,7 +7,7 @@ import("//mojo/public/tools/bindings/mojom.gni")
mojom("mojom") {
sources = [
- "learning_session.mojom",
+ "learning_task_controller.mojom",
"learning_types.mojom",
]
diff --git a/chromium/media/learning/mojo/public/mojom/learning_session.mojom b/chromium/media/learning/mojo/public/mojom/learning_session.mojom
deleted file mode 100644
index f7a2b1d7b3f..00000000000
--- a/chromium/media/learning/mojo/public/mojom/learning_session.mojom
+++ /dev/null
@@ -1,19 +0,0 @@
-// Copyright 2018 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-module media.learning.mojom;
-
-import "media/learning/mojo/public/mojom/learning_types.mojom";
-
-// Learning tasks, to prevent sending the task name string in AcquireLearner.
-enum LearningTaskType {
- // There are no tasks yet.
- kPlaceHolderTask,
-};
-
-// media/learning/public/learning_session.h
-interface LearningSession {
- // Add |example| to |task_type|.
- AddExample(LearningTaskType task_type, LabelledExample example);
-};
diff --git a/chromium/media/learning/mojo/public/mojom/learning_task_controller.mojom b/chromium/media/learning/mojo/public/mojom/learning_task_controller.mojom
new file mode 100644
index 00000000000..8fb3135d438
--- /dev/null
+++ b/chromium/media/learning/mojo/public/mojom/learning_task_controller.mojom
@@ -0,0 +1,36 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+module media.learning.mojom;
+
+import "mojo/public/mojom/base/unguessable_token.mojom";
+import "media/learning/mojo/public/mojom/learning_types.mojom";
+
+// Client for a single learning task. Intended to be the primary API for client
+// code that generates FeatureVectors / requests predictions for a single task.
+// The API supports sending in an observed FeatureVector without a target value,
+// so that framework-provided features (FeatureProvider) can be snapshotted at
+// the right time. One doesn't generally want to wait until the TargetValue is
+// observed to do that.
+//
+// Typically, this interface will allow non-browser processes to communicate
+// with the learning framework in the browser.
+interface LearningTaskController {
+ // Start a new observation. Call this at the time one would try to predict
+ // the TargetValue. This lets the framework snapshot any framework-provided
+ // feature values at prediction time. Later, if you want to turn these
+ // features into an example for training a model, then call
+ // CompleteObservation with the same id and an ObservationCompletion.
+ // Otherwise, call CancelObservation with |id|. It's also okay to destroy the
+ // controller with outstanding observations; these will be cancelled.
+ BeginObservation(mojo_base.mojom.UnguessableToken id,
+ array<FeatureValue> features);
+
+ // Complete observation |id| by providing |completion|.
+ CompleteObservation(mojo_base.mojom.UnguessableToken id,
+ ObservationCompletion completion);
+
+ // Cancel observation |id|. Deleting |this| will do the same.
+ CancelObservation(mojo_base.mojom.UnguessableToken id);
+};
diff --git a/chromium/media/learning/mojo/public/mojom/learning_types.mojom b/chromium/media/learning/mojo/public/mojom/learning_types.mojom
index 9a51bd970c5..cc469185203 100644
--- a/chromium/media/learning/mojo/public/mojom/learning_types.mojom
+++ b/chromium/media/learning/mojo/public/mojom/learning_types.mojom
@@ -19,3 +19,9 @@ struct LabelledExample {
array<FeatureValue> features;
TargetValue target_value;
};
+
+// learning::ObservationCompletion (common/learning_task_controller.h)
+struct ObservationCompletion {
+ TargetValue target_value;
+ uint64 weight = 1;
+};
diff --git a/chromium/media/learning/mojo/public/mojom/learning_types.typemap b/chromium/media/learning/mojo/public/mojom/learning_types.typemap
index 4e6a27b67dd..beaf1467335 100644
--- a/chromium/media/learning/mojo/public/mojom/learning_types.typemap
+++ b/chromium/media/learning/mojo/public/mojom/learning_types.typemap
@@ -1,6 +1,7 @@
mojom = "//media/learning/mojo/public/mojom/learning_types.mojom"
public_headers = [
"//media/learning/common/labelled_example.h",
+ "//media/learning/common/learning_task_controller.h",
"//media/learning/common/value.h",
]
traits_headers = [ "//media/learning/mojo/public/cpp/learning_mojom_traits.h" ]
@@ -15,4 +16,5 @@ type_mappings = [
"media.learning.mojom.LabelledExample=media::learning::LabelledExample",
"media.learning.mojom.FeatureValue=media::learning::FeatureValue",
"media.learning.mojom.TargetValue=media::learning::TargetValue",
+ "media.learning.mojom.ObservationCompletion=media::learning::ObservationCompletion",
]