summaryrefslogtreecommitdiff
path: root/chromium/components/assist_ranker
diff options
context:
space:
mode:
authorAllan Sandfeld Jensen <allan.jensen@qt.io>2019-02-13 15:05:36 +0100
committerAllan Sandfeld Jensen <allan.jensen@qt.io>2019-02-14 10:33:47 +0000
commite684a3455bcc29a6e3e66a004e352dea4e1141e7 (patch)
treed55b4003bde34d7d05f558f02cfd82b2a66a7aac /chromium/components/assist_ranker
parent2b94bfe47ccb6c08047959d1c26e392919550e86 (diff)
downloadqtwebengine-chromium-e684a3455bcc29a6e3e66a004e352dea4e1141e7.tar.gz
BASELINE: Update Chromium to 72.0.3626.110 and Ninja to 1.9.0
Change-Id: Ic57220b00ecc929a893c91f5cc552f5d3e99e922 Reviewed-by: Michael BrĂ¼ning <michael.bruning@qt.io>
Diffstat (limited to 'chromium/components/assist_ranker')
-rw-r--r--chromium/components/assist_ranker/base_predictor.cc4
-rw-r--r--chromium/components/assist_ranker/base_predictor.h9
-rw-r--r--chromium/components/assist_ranker/base_predictor_unittest.cc46
-rw-r--r--chromium/components/assist_ranker/binary_classifier_predictor.cc10
-rw-r--r--chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc34
-rw-r--r--chromium/components/assist_ranker/generic_logistic_regression_inference.h2
-rw-r--r--chromium/components/assist_ranker/predictor_config.h14
-rw-r--r--chromium/components/assist_ranker/predictor_config_definitions.cc13
8 files changed, 106 insertions, 26 deletions
diff --git a/chromium/components/assist_ranker/base_predictor.cc b/chromium/components/assist_ranker/base_predictor.cc
index 9060cd8de64..07e585aaa25 100644
--- a/chromium/components/assist_ranker/base_predictor.cc
+++ b/chromium/components/assist_ranker/base_predictor.cc
@@ -124,6 +124,10 @@ GURL BasePredictor::GetModelUrl() const {
return GURL(config_.field_trial_url_param->Get());
}
+float BasePredictor::GetPredictThresholdReplacement() const {
+ return config_.field_trial_threshold_replacement_param;
+}
+
RankerExample BasePredictor::PreprocessExample(const RankerExample& example) {
if (ranker_model_->proto().has_metadata() &&
ranker_model_->proto().metadata().input_features_names_are_hex_hashes()) {
diff --git a/chromium/components/assist_ranker/base_predictor.h b/chromium/components/assist_ranker/base_predictor.h
index 6904ad2e561..c89a4ae69c7 100644
--- a/chromium/components/assist_ranker/base_predictor.h
+++ b/chromium/components/assist_ranker/base_predictor.h
@@ -22,6 +22,10 @@ class UkmEntryBuilder;
namespace assist_ranker {
+// Value to use for when no prediction threshold replacement should be applied.
+// See |GetPredictThresholdReplacement| method.
+const float kNoPredictThresholdReplacement = 0.0;
+
class Feature;
class RankerExample;
class RankerModel;
@@ -29,7 +33,7 @@ class RankerModel;
// Predictors are objects that provide an interface for prediction, as well as
// encapsulate the logic for loading the model and logging. Sub-classes of
// BasePredictor implement an interface that depends on the nature of the
-// suported model. Subclasses of BasePredictor will also need to implement an
+// supported model. Subclasses of BasePredictor will also need to implement an
// Initialize method that will be called once the model is available, and a
// static validation function with the following signature:
//
@@ -49,6 +53,9 @@ class BasePredictor : public base::SupportsWeakPtr<BasePredictor> {
// Returns the model URL.
GURL GetModelUrl() const;
+ // Returns the threshold to use for prediction, or
+ // kNoPredictThresholdReplacement to leave it unchanged.
+ float GetPredictThresholdReplacement() const;
// Returns the model name.
std::string GetModelName() const;
diff --git a/chromium/components/assist_ranker/base_predictor_unittest.cc b/chromium/components/assist_ranker/base_predictor_unittest.cc
index 5b770768e98..6b330bae386 100644
--- a/chromium/components/assist_ranker/base_predictor_unittest.cc
+++ b/chromium/components/assist_ranker/base_predictor_unittest.cc
@@ -54,37 +54,41 @@ const base::Feature kTestRankerQuery{"TestRankerQuery",
const base::FeatureParam<std::string> kTestRankerUrl{
&kTestRankerQuery, kTestUrlParamName, kTestDefaultModelUrl};
-const PredictorConfig kTestPredictorConfig = PredictorConfig{
- kTestModelName, kTestLoggingName, kTestUmaPrefixName, LOG_UKM,
- &kFeatureWhitelist, &kTestRankerQuery, &kTestRankerUrl};
+const PredictorConfig kTestPredictorConfig =
+ PredictorConfig{kTestModelName, kTestLoggingName,
+ kTestUmaPrefixName, LOG_UKM,
+ &kFeatureWhitelist, &kTestRankerQuery,
+ &kTestRankerUrl, kNoPredictThresholdReplacement};
// Class that implements virtual functions of the base class.
class FakePredictor : public BasePredictor {
public:
- static std::unique_ptr<FakePredictor> Create();
+ // Creates a |FakePredictor| using the default config (from this file).
+ static std::unique_ptr<FakePredictor> Create() {
+ return Create(kTestPredictorConfig);
+ }
+ // Creates a |FakePredictor| using the |PredictorConfig| passed in
+ // |predictor_config|.
+ static std::unique_ptr<FakePredictor> Create(
+ PredictorConfig predictor_config);
~FakePredictor() override{};
// Validation will always succeed.
- static RankerModelStatus ValidateModel(const RankerModel& model);
+ static RankerModelStatus ValidateModel(const RankerModel& model) {
+ return RankerModelStatus::OK;
+ }
protected:
// Not implementing any inference logic.
bool Initialize() override { return true; };
private:
- FakePredictor(const PredictorConfig& config);
+ FakePredictor(const PredictorConfig& config) : BasePredictor(config) {}
DISALLOW_COPY_AND_ASSIGN(FakePredictor);
};
-FakePredictor::FakePredictor(const PredictorConfig& config)
- : BasePredictor(config) {}
-
-RankerModelStatus FakePredictor::ValidateModel(const RankerModel& model) {
- return RankerModelStatus::OK;
-}
-
-std::unique_ptr<FakePredictor> FakePredictor::Create() {
- std::unique_ptr<FakePredictor> predictor(
- new FakePredictor(kTestPredictorConfig));
+std::unique_ptr<FakePredictor> FakePredictor::Create(
+ PredictorConfig predictor_config) {
+ std::unique_ptr<FakePredictor> predictor(new FakePredictor(predictor_config));
auto ranker_model = std::make_unique<RankerModel>();
auto fake_model_loader = std::make_unique<FakeRankerModelLoader>(
base::BindRepeating(&FakePredictor::ValidateModel),
@@ -184,4 +188,14 @@ TEST_F(BasePredictorTest, LogExampleToUkm) {
GetTestUkmRecorder()->EntryHasMetric(entries[0], kFeatureNotWhitelisted));
}
+TEST_F(BasePredictorTest, GetPredictThresholdReplacement) {
+ float altered_threshold = 0.78f; // Arbitrary value.
+ const PredictorConfig altered_threshold_config{
+ kTestModelName, kTestLoggingName, kTestUmaPrefixName,
+ LOG_UKM, &kFeatureWhitelist, &kTestRankerQuery,
+ &kTestRankerUrl, altered_threshold};
+ auto predictor = FakePredictor::Create(altered_threshold_config);
+ EXPECT_EQ(altered_threshold, predictor->GetPredictThresholdReplacement());
+}
+
} // namespace assist_ranker
diff --git a/chromium/components/assist_ranker/binary_classifier_predictor.cc b/chromium/components/assist_ranker/binary_classifier_predictor.cc
index 651eaab8b40..402aa5931c6 100644
--- a/chromium/components/assist_ranker/binary_classifier_predictor.cc
+++ b/chromium/components/assist_ranker/binary_classifier_predictor.cc
@@ -36,6 +36,8 @@ std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create(
const GURL& model_url = predictor->GetModelUrl();
DVLOG(1) << "Creating predictor instance for " << predictor->GetModelName();
DVLOG(1) << "Model URL: " << model_url;
+ DVLOG(1) << "Using predict threshold replacement: "
+ << predictor->GetPredictThresholdReplacement();
auto model_loader = std::make_unique<RankerModelLoaderImpl>(
base::BindRepeating(&BinaryClassifierPredictor::ValidateModel),
base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable,
@@ -52,7 +54,13 @@ bool BinaryClassifierPredictor::Predict(const RankerExample& example,
return false;
}
- *prediction = inference_module_->Predict(PreprocessExample(example));
+ float predict_threshold_replacement = GetPredictThresholdReplacement();
+ if (predict_threshold_replacement != kNoPredictThresholdReplacement) {
+ *prediction = inference_module_->PredictScore(PreprocessExample(example)) >=
+ predict_threshold_replacement;
+ } else {
+ *prediction = inference_module_->Predict(PreprocessExample(example));
+ }
DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << *prediction;
return true;
}
diff --git a/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc b/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc
index 672c80fadf0..03dfa537750 100644
--- a/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc
+++ b/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc
@@ -31,6 +31,7 @@ class BinaryClassifierPredictorTest : public ::testing::Test {
GenericLogisticRegressionModel GetSimpleLogisticRegressionModel();
PredictorConfig GetConfig();
+ PredictorConfig GetConfig(float predictor_threshold_replacement);
protected:
const std::string feature_ = "feature";
@@ -66,9 +67,14 @@ const base::FeatureParam<std::string> kTestRankerUrl{
&kTestRankerQuery, "url-param-name", "https://default.model.url"};
PredictorConfig BinaryClassifierPredictorTest::GetConfig() {
+ return GetConfig(kNoPredictThresholdReplacement);
+}
+
+PredictorConfig BinaryClassifierPredictorTest::GetConfig(
+ float predictor_threshold_replacement) {
PredictorConfig config("model_name", "logging_name", "uma_prefix", LOG_NONE,
GetEmptyWhitelist(), &kTestRankerQuery,
- &kTestRankerUrl);
+ &kTestRankerUrl, predictor_threshold_replacement);
return config;
}
@@ -171,4 +177,30 @@ TEST_F(BinaryClassifierPredictorTest,
EXPECT_LT(float_response, threshold_);
}
+TEST_F(BinaryClassifierPredictorTest,
+ GenericLogisticRegressionPreprocessedModelReplacedThreshold) {
+ auto ranker_model = std::make_unique<RankerModel>();
+ auto& glr = *ranker_model->mutable_proto()->mutable_logistic_regression();
+ glr = GetSimpleLogisticRegressionModel();
+ glr.clear_weights();
+ glr.set_is_preprocessed_model(true);
+ (*glr.mutable_fullname_weights())[feature_] = weight_;
+
+ float high_threshold = 0.9; // Some high threshold.
+ auto predictor =
+ InitPredictor(std::move(ranker_model), GetConfig(high_threshold));
+ EXPECT_TRUE(predictor->IsReady());
+
+ RankerExample ranker_example;
+ auto& features = *ranker_example.mutable_features();
+ features[feature_].set_bool_value(true);
+ bool bool_response;
+ EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response));
+ EXPECT_FALSE(bool_response);
+ float float_response;
+ EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response));
+ EXPECT_GT(float_response, threshold_);
+ EXPECT_LT(float_response, high_threshold);
+}
+
} // namespace assist_ranker
diff --git a/chromium/components/assist_ranker/generic_logistic_regression_inference.h b/chromium/components/assist_ranker/generic_logistic_regression_inference.h
index ad4b1dbf5ad..eb34804226b 100644
--- a/chromium/components/assist_ranker/generic_logistic_regression_inference.h
+++ b/chromium/components/assist_ranker/generic_logistic_regression_inference.h
@@ -23,7 +23,7 @@ class GenericLogisticRegressionInference {
// Returns a boolean decision given a RankerExample. Uses the same logic as
// PredictScore, and then applies the model decision threshold.
bool Predict(const RankerExample& example);
- // Returns a score between 0 and 1 give a RankerExample.
+ // Returns a score between 0 and 1 given a RankerExample.
float PredictScore(const RankerExample& example);
private:
diff --git a/chromium/components/assist_ranker/predictor_config.h b/chromium/components/assist_ranker/predictor_config.h
index 6a545889f4f..944164b4aad 100644
--- a/chromium/components/assist_ranker/predictor_config.h
+++ b/chromium/components/assist_ranker/predictor_config.h
@@ -30,21 +30,25 @@ struct PredictorConfig {
const LogType log_type,
const base::flat_set<std::string>* feature_whitelist,
const base::Feature* field_trial,
- const base::FeatureParam<std::string>* field_trial_url_param)
+ const base::FeatureParam<std::string>* field_trial_url_param,
+ float field_trial_threshold_replacement_param)
: model_name(model_name),
logging_name(logging_name),
uma_prefix(uma_prefix),
log_type(log_type),
feature_whitelist(feature_whitelist),
field_trial(field_trial),
- field_trial_url_param(field_trial_url_param) {}
- const char* model_name;
- const char* logging_name;
- const char* uma_prefix;
+ field_trial_url_param(field_trial_url_param),
+ field_trial_threshold_replacement_param(
+ field_trial_threshold_replacement_param) {}
+ const char* const model_name;
+ const char* const logging_name;
+ const char* const uma_prefix;
const LogType log_type;
const base::flat_set<std::string>* feature_whitelist;
const base::Feature* field_trial;
const base::FeatureParam<std::string>* field_trial_url_param;
+ const float field_trial_threshold_replacement_param;
};
} // namespace assist_ranker
diff --git a/chromium/components/assist_ranker/predictor_config_definitions.cc b/chromium/components/assist_ranker/predictor_config_definitions.cc
index e967f682f36..7bcfc641a95 100644
--- a/chromium/components/assist_ranker/predictor_config_definitions.cc
+++ b/chromium/components/assist_ranker/predictor_config_definitions.cc
@@ -3,6 +3,7 @@
// found in the LICENSE file.
#include "components/assist_ranker/predictor_config_definitions.h"
+#include "components/assist_ranker/base_predictor.h"
namespace assist_ranker {
@@ -28,6 +29,15 @@ GetContextualSearchRankerUrlFeatureParam() {
return kContextualSearchRankerUrl;
}
+float GetContextualSearchRankerThresholdFeatureParam() {
+ static auto* kContextualSearchRankerThreshold =
+ new base::FeatureParam<double>(
+ &kContextualSearchRankerQuery,
+ "contextual-search-ranker-predict-threshold",
+ kNoPredictThresholdReplacement);
+ return static_cast<float>(kContextualSearchRankerThreshold->Get());
+}
+
// NOTE: This list needs to be kept in sync with tools/metrics/ukm/ukm.xml!
// Only features within this list will be logged to UKM.
// TODO(chrome-ranker-team) Deprecate the whitelist once it is available through
@@ -77,7 +87,8 @@ const PredictorConfig GetContextualSearchPredictorConfig() {
kContextualSearchModelName, kContextualSearchLoggingName,
kContextualSearchUmaPrefixName, LOG_UKM,
GetContextualSearchFeatureWhitelist(), &kContextualSearchRankerQuery,
- GetContextualSearchRankerUrlFeatureParam()));
+ GetContextualSearchRankerUrlFeatureParam(),
+ GetContextualSearchRankerThresholdFeatureParam()));
return kContextualSearchPredictorConfig;
}
#endif // OS_ANDROID