diff options
author | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2019-02-13 15:05:36 +0100 |
---|---|---|
committer | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2019-02-14 10:33:47 +0000 |
commit | e684a3455bcc29a6e3e66a004e352dea4e1141e7 (patch) | |
tree | d55b4003bde34d7d05f558f02cfd82b2a66a7aac /chromium/components/assist_ranker | |
parent | 2b94bfe47ccb6c08047959d1c26e392919550e86 (diff) | |
download | qtwebengine-chromium-e684a3455bcc29a6e3e66a004e352dea4e1141e7.tar.gz |
BASELINE: Update Chromium to 72.0.3626.110 and Ninja to 1.9.0
Change-Id: Ic57220b00ecc929a893c91f5cc552f5d3e99e922
Reviewed-by: Michael BrĂ¼ning <michael.bruning@qt.io>
Diffstat (limited to 'chromium/components/assist_ranker')
8 files changed, 106 insertions, 26 deletions
diff --git a/chromium/components/assist_ranker/base_predictor.cc b/chromium/components/assist_ranker/base_predictor.cc index 9060cd8de64..07e585aaa25 100644 --- a/chromium/components/assist_ranker/base_predictor.cc +++ b/chromium/components/assist_ranker/base_predictor.cc @@ -124,6 +124,10 @@ GURL BasePredictor::GetModelUrl() const { return GURL(config_.field_trial_url_param->Get()); } +float BasePredictor::GetPredictThresholdReplacement() const { + return config_.field_trial_threshold_replacement_param; +} + RankerExample BasePredictor::PreprocessExample(const RankerExample& example) { if (ranker_model_->proto().has_metadata() && ranker_model_->proto().metadata().input_features_names_are_hex_hashes()) { diff --git a/chromium/components/assist_ranker/base_predictor.h b/chromium/components/assist_ranker/base_predictor.h index 6904ad2e561..c89a4ae69c7 100644 --- a/chromium/components/assist_ranker/base_predictor.h +++ b/chromium/components/assist_ranker/base_predictor.h @@ -22,6 +22,10 @@ class UkmEntryBuilder; namespace assist_ranker { +// Value to use for when no prediction threshold replacement should be applied. +// See |GetPredictThresholdReplacement| method. +const float kNoPredictThresholdReplacement = 0.0; + class Feature; class RankerExample; class RankerModel; @@ -29,7 +33,7 @@ class RankerModel; // Predictors are objects that provide an interface for prediction, as well as // encapsulate the logic for loading the model and logging. Sub-classes of // BasePredictor implement an interface that depends on the nature of the -// suported model. Subclasses of BasePredictor will also need to implement an +// supported model. Subclasses of BasePredictor will also need to implement an // Initialize method that will be called once the model is available, and a // static validation function with the following signature: // @@ -49,6 +53,9 @@ class BasePredictor : public base::SupportsWeakPtr<BasePredictor> { // Returns the model URL. GURL GetModelUrl() const; + // Returns the threshold to use for prediction, or + // kNoPredictThresholdReplacement to leave it unchanged. + float GetPredictThresholdReplacement() const; // Returns the model name. std::string GetModelName() const; diff --git a/chromium/components/assist_ranker/base_predictor_unittest.cc b/chromium/components/assist_ranker/base_predictor_unittest.cc index 5b770768e98..6b330bae386 100644 --- a/chromium/components/assist_ranker/base_predictor_unittest.cc +++ b/chromium/components/assist_ranker/base_predictor_unittest.cc @@ -54,37 +54,41 @@ const base::Feature kTestRankerQuery{"TestRankerQuery", const base::FeatureParam<std::string> kTestRankerUrl{ &kTestRankerQuery, kTestUrlParamName, kTestDefaultModelUrl}; -const PredictorConfig kTestPredictorConfig = PredictorConfig{ - kTestModelName, kTestLoggingName, kTestUmaPrefixName, LOG_UKM, - &kFeatureWhitelist, &kTestRankerQuery, &kTestRankerUrl}; +const PredictorConfig kTestPredictorConfig = + PredictorConfig{kTestModelName, kTestLoggingName, + kTestUmaPrefixName, LOG_UKM, + &kFeatureWhitelist, &kTestRankerQuery, + &kTestRankerUrl, kNoPredictThresholdReplacement}; // Class that implements virtual functions of the base class. class FakePredictor : public BasePredictor { public: - static std::unique_ptr<FakePredictor> Create(); + // Creates a |FakePredictor| using the default config (from this file). + static std::unique_ptr<FakePredictor> Create() { + return Create(kTestPredictorConfig); + } + // Creates a |FakePredictor| using the |PredictorConfig| passed in + // |predictor_config|. + static std::unique_ptr<FakePredictor> Create( + PredictorConfig predictor_config); ~FakePredictor() override{}; // Validation will always succeed. - static RankerModelStatus ValidateModel(const RankerModel& model); + static RankerModelStatus ValidateModel(const RankerModel& model) { + return RankerModelStatus::OK; + } protected: // Not implementing any inference logic. bool Initialize() override { return true; }; private: - FakePredictor(const PredictorConfig& config); + FakePredictor(const PredictorConfig& config) : BasePredictor(config) {} DISALLOW_COPY_AND_ASSIGN(FakePredictor); }; -FakePredictor::FakePredictor(const PredictorConfig& config) - : BasePredictor(config) {} - -RankerModelStatus FakePredictor::ValidateModel(const RankerModel& model) { - return RankerModelStatus::OK; -} - -std::unique_ptr<FakePredictor> FakePredictor::Create() { - std::unique_ptr<FakePredictor> predictor( - new FakePredictor(kTestPredictorConfig)); +std::unique_ptr<FakePredictor> FakePredictor::Create( + PredictorConfig predictor_config) { + std::unique_ptr<FakePredictor> predictor(new FakePredictor(predictor_config)); auto ranker_model = std::make_unique<RankerModel>(); auto fake_model_loader = std::make_unique<FakeRankerModelLoader>( base::BindRepeating(&FakePredictor::ValidateModel), @@ -184,4 +188,14 @@ TEST_F(BasePredictorTest, LogExampleToUkm) { GetTestUkmRecorder()->EntryHasMetric(entries[0], kFeatureNotWhitelisted)); } +TEST_F(BasePredictorTest, GetPredictThresholdReplacement) { + float altered_threshold = 0.78f; // Arbitrary value. + const PredictorConfig altered_threshold_config{ + kTestModelName, kTestLoggingName, kTestUmaPrefixName, + LOG_UKM, &kFeatureWhitelist, &kTestRankerQuery, + &kTestRankerUrl, altered_threshold}; + auto predictor = FakePredictor::Create(altered_threshold_config); + EXPECT_EQ(altered_threshold, predictor->GetPredictThresholdReplacement()); +} + } // namespace assist_ranker diff --git a/chromium/components/assist_ranker/binary_classifier_predictor.cc b/chromium/components/assist_ranker/binary_classifier_predictor.cc index 651eaab8b40..402aa5931c6 100644 --- a/chromium/components/assist_ranker/binary_classifier_predictor.cc +++ b/chromium/components/assist_ranker/binary_classifier_predictor.cc @@ -36,6 +36,8 @@ std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create( const GURL& model_url = predictor->GetModelUrl(); DVLOG(1) << "Creating predictor instance for " << predictor->GetModelName(); DVLOG(1) << "Model URL: " << model_url; + DVLOG(1) << "Using predict threshold replacement: " + << predictor->GetPredictThresholdReplacement(); auto model_loader = std::make_unique<RankerModelLoaderImpl>( base::BindRepeating(&BinaryClassifierPredictor::ValidateModel), base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable, @@ -52,7 +54,13 @@ bool BinaryClassifierPredictor::Predict(const RankerExample& example, return false; } - *prediction = inference_module_->Predict(PreprocessExample(example)); + float predict_threshold_replacement = GetPredictThresholdReplacement(); + if (predict_threshold_replacement != kNoPredictThresholdReplacement) { + *prediction = inference_module_->PredictScore(PreprocessExample(example)) >= + predict_threshold_replacement; + } else { + *prediction = inference_module_->Predict(PreprocessExample(example)); + } DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << *prediction; return true; } diff --git a/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc b/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc index 672c80fadf0..03dfa537750 100644 --- a/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc +++ b/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc @@ -31,6 +31,7 @@ class BinaryClassifierPredictorTest : public ::testing::Test { GenericLogisticRegressionModel GetSimpleLogisticRegressionModel(); PredictorConfig GetConfig(); + PredictorConfig GetConfig(float predictor_threshold_replacement); protected: const std::string feature_ = "feature"; @@ -66,9 +67,14 @@ const base::FeatureParam<std::string> kTestRankerUrl{ &kTestRankerQuery, "url-param-name", "https://default.model.url"}; PredictorConfig BinaryClassifierPredictorTest::GetConfig() { + return GetConfig(kNoPredictThresholdReplacement); +} + +PredictorConfig BinaryClassifierPredictorTest::GetConfig( + float predictor_threshold_replacement) { PredictorConfig config("model_name", "logging_name", "uma_prefix", LOG_NONE, GetEmptyWhitelist(), &kTestRankerQuery, - &kTestRankerUrl); + &kTestRankerUrl, predictor_threshold_replacement); return config; } @@ -171,4 +177,30 @@ TEST_F(BinaryClassifierPredictorTest, EXPECT_LT(float_response, threshold_); } +TEST_F(BinaryClassifierPredictorTest, + GenericLogisticRegressionPreprocessedModelReplacedThreshold) { + auto ranker_model = std::make_unique<RankerModel>(); + auto& glr = *ranker_model->mutable_proto()->mutable_logistic_regression(); + glr = GetSimpleLogisticRegressionModel(); + glr.clear_weights(); + glr.set_is_preprocessed_model(true); + (*glr.mutable_fullname_weights())[feature_] = weight_; + + float high_threshold = 0.9; // Some high threshold. + auto predictor = + InitPredictor(std::move(ranker_model), GetConfig(high_threshold)); + EXPECT_TRUE(predictor->IsReady()); + + RankerExample ranker_example; + auto& features = *ranker_example.mutable_features(); + features[feature_].set_bool_value(true); + bool bool_response; + EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response)); + EXPECT_FALSE(bool_response); + float float_response; + EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response)); + EXPECT_GT(float_response, threshold_); + EXPECT_LT(float_response, high_threshold); +} + } // namespace assist_ranker diff --git a/chromium/components/assist_ranker/generic_logistic_regression_inference.h b/chromium/components/assist_ranker/generic_logistic_regression_inference.h index ad4b1dbf5ad..eb34804226b 100644 --- a/chromium/components/assist_ranker/generic_logistic_regression_inference.h +++ b/chromium/components/assist_ranker/generic_logistic_regression_inference.h @@ -23,7 +23,7 @@ class GenericLogisticRegressionInference { // Returns a boolean decision given a RankerExample. Uses the same logic as // PredictScore, and then applies the model decision threshold. bool Predict(const RankerExample& example); - // Returns a score between 0 and 1 give a RankerExample. + // Returns a score between 0 and 1 given a RankerExample. float PredictScore(const RankerExample& example); private: diff --git a/chromium/components/assist_ranker/predictor_config.h b/chromium/components/assist_ranker/predictor_config.h index 6a545889f4f..944164b4aad 100644 --- a/chromium/components/assist_ranker/predictor_config.h +++ b/chromium/components/assist_ranker/predictor_config.h @@ -30,21 +30,25 @@ struct PredictorConfig { const LogType log_type, const base::flat_set<std::string>* feature_whitelist, const base::Feature* field_trial, - const base::FeatureParam<std::string>* field_trial_url_param) + const base::FeatureParam<std::string>* field_trial_url_param, + float field_trial_threshold_replacement_param) : model_name(model_name), logging_name(logging_name), uma_prefix(uma_prefix), log_type(log_type), feature_whitelist(feature_whitelist), field_trial(field_trial), - field_trial_url_param(field_trial_url_param) {} - const char* model_name; - const char* logging_name; - const char* uma_prefix; + field_trial_url_param(field_trial_url_param), + field_trial_threshold_replacement_param( + field_trial_threshold_replacement_param) {} + const char* const model_name; + const char* const logging_name; + const char* const uma_prefix; const LogType log_type; const base::flat_set<std::string>* feature_whitelist; const base::Feature* field_trial; const base::FeatureParam<std::string>* field_trial_url_param; + const float field_trial_threshold_replacement_param; }; } // namespace assist_ranker diff --git a/chromium/components/assist_ranker/predictor_config_definitions.cc b/chromium/components/assist_ranker/predictor_config_definitions.cc index e967f682f36..7bcfc641a95 100644 --- a/chromium/components/assist_ranker/predictor_config_definitions.cc +++ b/chromium/components/assist_ranker/predictor_config_definitions.cc @@ -3,6 +3,7 @@ // found in the LICENSE file. #include "components/assist_ranker/predictor_config_definitions.h" +#include "components/assist_ranker/base_predictor.h" namespace assist_ranker { @@ -28,6 +29,15 @@ GetContextualSearchRankerUrlFeatureParam() { return kContextualSearchRankerUrl; } +float GetContextualSearchRankerThresholdFeatureParam() { + static auto* kContextualSearchRankerThreshold = + new base::FeatureParam<double>( + &kContextualSearchRankerQuery, + "contextual-search-ranker-predict-threshold", + kNoPredictThresholdReplacement); + return static_cast<float>(kContextualSearchRankerThreshold->Get()); +} + // NOTE: This list needs to be kept in sync with tools/metrics/ukm/ukm.xml! // Only features within this list will be logged to UKM. // TODO(chrome-ranker-team) Deprecate the whitelist once it is available through @@ -77,7 +87,8 @@ const PredictorConfig GetContextualSearchPredictorConfig() { kContextualSearchModelName, kContextualSearchLoggingName, kContextualSearchUmaPrefixName, LOG_UKM, GetContextualSearchFeatureWhitelist(), &kContextualSearchRankerQuery, - GetContextualSearchRankerUrlFeatureParam())); + GetContextualSearchRankerUrlFeatureParam(), + GetContextualSearchRankerThresholdFeatureParam())); return kContextualSearchPredictorConfig; } #endif // OS_ANDROID |