summaryrefslogtreecommitdiff
path: root/chromium/components/assist_ranker
diff options
context:
space:
mode:
authorAllan Sandfeld Jensen <allan.jensen@qt.io>2018-01-31 16:33:43 +0100
committerAllan Sandfeld Jensen <allan.jensen@qt.io>2018-02-06 16:33:22 +0000
commitda51f56cc21233c2d30f0fe0d171727c3102b2e0 (patch)
tree4e579ab70ce4b19bee7984237f3ce05a96d59d83 /chromium/components/assist_ranker
parentc8c2d1901aec01e934adf561a9fdf0cc776cdef8 (diff)
downloadqtwebengine-chromium-da51f56cc21233c2d30f0fe0d171727c3102b2e0.tar.gz
BASELINE: Update Chromium to 65.0.3525.40
Also imports missing submodules Change-Id: I36901b7c6a325cda3d2c10cedb2186c25af3b79b Reviewed-by: Alexandru Croitor <alexandru.croitor@qt.io>
Diffstat (limited to 'chromium/components/assist_ranker')
-rw-r--r--chromium/components/assist_ranker/BUILD.gn7
-rw-r--r--chromium/components/assist_ranker/DEPS2
-rw-r--r--chromium/components/assist_ranker/assist_ranker_service.h23
-rw-r--r--chromium/components/assist_ranker/assist_ranker_service_impl.cc28
-rw-r--r--chromium/components/assist_ranker/assist_ranker_service_impl.h14
-rw-r--r--chromium/components/assist_ranker/base_predictor.cc110
-rw-r--r--chromium/components/assist_ranker/base_predictor.h49
-rw-r--r--chromium/components/assist_ranker/base_predictor_unittest.cc183
-rw-r--r--chromium/components/assist_ranker/binary_classifier_predictor.cc51
-rw-r--r--chromium/components/assist_ranker/binary_classifier_predictor.h12
-rw-r--r--chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc33
-rw-r--r--chromium/components/assist_ranker/generic_logistic_regression_inference.cc11
-rw-r--r--chromium/components/assist_ranker/predictor_config.cc14
-rw-r--r--chromium/components/assist_ranker/predictor_config.h52
-rw-r--r--chromium/components/assist_ranker/predictor_config_definitions.cc75
-rw-r--r--chromium/components/assist_ranker/predictor_config_definitions.h26
-rw-r--r--chromium/components/assist_ranker/proto/ranker_example.proto12
-rw-r--r--chromium/components/assist_ranker/ranker_example_util.cc97
-rw-r--r--chromium/components/assist_ranker/ranker_example_util.h23
-rw-r--r--chromium/components/assist_ranker/ranker_example_util_unittest.cc77
20 files changed, 823 insertions, 76 deletions
diff --git a/chromium/components/assist_ranker/BUILD.gn b/chromium/components/assist_ranker/BUILD.gn
index 992fe8b8117..1c07991b5dd 100644
--- a/chromium/components/assist_ranker/BUILD.gn
+++ b/chromium/components/assist_ranker/BUILD.gn
@@ -15,6 +15,10 @@ static_library("assist_ranker") {
"fake_ranker_model_loader.h",
"generic_logistic_regression_inference.cc",
"generic_logistic_regression_inference.h",
+ "predictor_config.cc",
+ "predictor_config.h",
+ "predictor_config_definitions.cc",
+ "predictor_config_definitions.h",
"ranker_example_util.cc",
"ranker_example_util.h",
"ranker_model.cc",
@@ -32,6 +36,7 @@ static_library("assist_ranker") {
"//components/data_use_measurement/core",
"//components/keyed_service/core",
"//net",
+ "//services/metrics/public/cpp:metrics_cpp",
"//url",
]
}
@@ -40,6 +45,7 @@ source_set("unit_tests") {
testonly = true
sources = [
+ "base_predictor_unittest.cc",
"binary_classifier_predictor_unittest.cc",
"generic_logistic_regression_inference_unittest.cc",
"ranker_example_util_unittest.cc",
@@ -51,6 +57,7 @@ source_set("unit_tests") {
":assist_ranker",
"//base",
"//components/assist_ranker/proto",
+ "//components/ukm:test_support",
"//net:test_support",
"//testing/gtest",
]
diff --git a/chromium/components/assist_ranker/DEPS b/chromium/components/assist_ranker/DEPS
index f66492a2bcd..d8355c686f2 100644
--- a/chromium/components/assist_ranker/DEPS
+++ b/chromium/components/assist_ranker/DEPS
@@ -2,5 +2,7 @@ include_rules = [
"+components/data_use_measurement/core",
"+components/keyed_service/core",
"+components/metrics",
+ "+components/ukm",
"+net",
+ "+services/metrics/public",
] \ No newline at end of file
diff --git a/chromium/components/assist_ranker/assist_ranker_service.h b/chromium/components/assist_ranker/assist_ranker_service.h
index d2015f07668..bb11a4789b2 100644
--- a/chromium/components/assist_ranker/assist_ranker_service.h
+++ b/chromium/components/assist_ranker/assist_ranker_service.h
@@ -9,32 +9,25 @@
#include <string>
#include "base/macros.h"
+#include "base/memory/weak_ptr.h"
#include "components/keyed_service/core/keyed_service.h"
-class GURL;
-
namespace assist_ranker {
class BinaryClassifierPredictor;
+struct PredictorConfig;
-// TODO(crbug.com/778468) : Refactor this so that the service owns the predictor
-// objects and enforce model uniqueness through internal registration in order
-// to avoid cache collisions.
-//
// Service that provides Predictor objects.
class AssistRankerService : public KeyedService {
public:
AssistRankerService() = default;
- // Returns a binary classification model. |model_filename| is the filename of
- // the cached model. It should be unique to this predictor to avoid cache
- // collision. |model_url| represents a unique ID for the desired model (see
- // ranker_model_loader.h for more details). |uma_prefix| is used to log
- // histograms related to the loading of the model.
- virtual std::unique_ptr<BinaryClassifierPredictor>
- FetchBinaryClassifierPredictor(GURL model_url,
- const std::string& model_filename,
- const std::string& uma_prefix) = 0;
+ // Returns a binary classification model given a PredictorConfig.
+ // The predictor is instantiated the first time a predictor is fetched. The
+ // next calls to fetch will return a pointer to the already instantiated
+ // predictor.
+ virtual base::WeakPtr<BinaryClassifierPredictor>
+ FetchBinaryClassifierPredictor(const PredictorConfig& config) = 0;
private:
DISALLOW_COPY_AND_ASSIGN(AssistRankerService);
diff --git a/chromium/components/assist_ranker/assist_ranker_service_impl.cc b/chromium/components/assist_ranker/assist_ranker_service_impl.cc
index 255183a6e8c..5b870b8f578 100644
--- a/chromium/components/assist_ranker/assist_ranker_service_impl.cc
+++ b/chromium/components/assist_ranker/assist_ranker_service_impl.cc
@@ -3,7 +3,7 @@
// found in the LICENSE file.
#include "components/assist_ranker/assist_ranker_service_impl.h"
-
+#include "base/memory/weak_ptr.h"
#include "components/assist_ranker/binary_classifier_predictor.h"
#include "components/assist_ranker/ranker_model_loader_impl.h"
#include "net/url_request/url_request_context_getter.h"
@@ -19,15 +19,27 @@ AssistRankerServiceImpl::AssistRankerServiceImpl(
AssistRankerServiceImpl::~AssistRankerServiceImpl() {}
-std::unique_ptr<BinaryClassifierPredictor>
+base::WeakPtr<BinaryClassifierPredictor>
AssistRankerServiceImpl::FetchBinaryClassifierPredictor(
- GURL model_url,
- const std::string& model_filename,
- const std::string& uma_prefix) {
+ const PredictorConfig& config) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- return BinaryClassifierPredictor::Create(url_request_context_getter_.get(),
- GetModelPath(model_filename),
- model_url, uma_prefix);
+ const std::string& model_name = config.model_name;
+ auto predictor_it = predictor_map_.find(model_name);
+ if (predictor_it != predictor_map_.end()) {
+ DVLOG(1) << "Predictor " << model_name << " already initialized.";
+ return base::AsWeakPtr(
+ static_cast<BinaryClassifierPredictor*>(predictor_it->second.get()));
+ }
+
+ // The predictor does not exist yet, so we create one.
+ DVLOG(1) << "Initializing predictor: " << model_name;
+ std::unique_ptr<BinaryClassifierPredictor> predictor =
+ BinaryClassifierPredictor::Create(config, GetModelPath(model_name),
+ url_request_context_getter_.get());
+ base::WeakPtr<BinaryClassifierPredictor> weak_ptr =
+ base::AsWeakPtr(predictor.get());
+ predictor_map_[model_name] = std::move(predictor);
+ return weak_ptr;
}
base::FilePath AssistRankerServiceImpl::GetModelPath(
diff --git a/chromium/components/assist_ranker/assist_ranker_service_impl.h b/chromium/components/assist_ranker/assist_ranker_service_impl.h
index e4e967e5981..1b75ff4017f 100644
--- a/chromium/components/assist_ranker/assist_ranker_service_impl.h
+++ b/chromium/components/assist_ranker/assist_ranker_service_impl.h
@@ -7,13 +7,13 @@
#include <memory>
#include <string>
+#include <unordered_map>
#include "base/files/file_path.h"
#include "base/memory/ref_counted.h"
#include "base/sequence_checker.h"
#include "components/assist_ranker/assist_ranker_service.h"
-
-class GURL;
+#include "components/assist_ranker/predictor_config.h"
namespace net {
class URLRequestContextGetter;
@@ -21,6 +21,7 @@ class URLRequestContextGetter;
namespace assist_ranker {
+class BasePredictor;
class BinaryClassifierPredictor;
class AssistRankerServiceImpl : public AssistRankerService {
@@ -31,10 +32,8 @@ class AssistRankerServiceImpl : public AssistRankerService {
~AssistRankerServiceImpl() override;
// AssistRankerService...
- std::unique_ptr<BinaryClassifierPredictor> FetchBinaryClassifierPredictor(
- GURL model_url,
- const std::string& model_filename,
- const std::string& uma_prefix) override;
+ base::WeakPtr<BinaryClassifierPredictor> FetchBinaryClassifierPredictor(
+ const PredictorConfig& config) override;
private:
// Returns the full path to the model cache.
@@ -46,6 +45,9 @@ class AssistRankerServiceImpl : public AssistRankerService {
// Base path where models are stored.
const base::FilePath base_path_;
+ std::unordered_map<std::string, std::unique_ptr<BasePredictor>>
+ predictor_map_;
+
SEQUENCE_CHECKER(sequence_checker_);
DISALLOW_COPY_AND_ASSIGN(AssistRankerServiceImpl);
diff --git a/chromium/components/assist_ranker/base_predictor.cc b/chromium/components/assist_ranker/base_predictor.cc
index b890f998ec0..13bfe0c5e15 100644
--- a/chromium/components/assist_ranker/base_predictor.cc
+++ b/chromium/components/assist_ranker/base_predictor.cc
@@ -4,23 +4,40 @@
#include "components/assist_ranker/base_predictor.h"
+#include "base/feature_list.h"
#include "base/memory/ptr_util.h"
+#include "components/assist_ranker/proto/ranker_example.pb.h"
#include "components/assist_ranker/proto/ranker_model.pb.h"
+#include "components/assist_ranker/ranker_example_util.h"
#include "components/assist_ranker/ranker_model.h"
+#include "services/metrics/public/cpp/ukm_entry_builder.h"
+#include "services/metrics/public/cpp/ukm_recorder.h"
+#include "url/gurl.h"
namespace assist_ranker {
-BasePredictor::BasePredictor() {}
+BasePredictor::BasePredictor(const PredictorConfig& config) : config_(config) {
+ // TODO(chrome-ranker-team): validate config.
+ if (config_.field_trial) {
+ is_query_enabled_ = base::FeatureList::IsEnabled(*config_.field_trial);
+ } else {
+ DVLOG(0) << "No field trial specified";
+ }
+}
+
BasePredictor::~BasePredictor() {}
void BasePredictor::LoadModel(std::unique_ptr<RankerModelLoader> model_loader) {
+ if (!is_query_enabled_)
+ return;
+
if (model_loader_) {
- DLOG(ERROR) << "This predictor already has a model loader.";
+ DVLOG(0) << "This predictor already has a model loader.";
return;
}
// Take ownership of the model loader.
model_loader_ = std::move(model_loader);
- // Kick off the initial load from cache.
+ // Kick off the initial model load.
model_loader_->NotifyOfRankerActivity();
}
@@ -31,10 +48,95 @@ void BasePredictor::OnModelAvailable(
}
bool BasePredictor::IsReady() {
- if (!is_ready_)
+ if (!is_ready_ && model_loader_)
model_loader_->NotifyOfRankerActivity();
return is_ready_;
}
+void BasePredictor::LogFeatureToUkm(const std::string& feature_name,
+ const Feature& feature,
+ ukm::UkmEntryBuilder* ukm_builder) {
+ if (!ukm_builder)
+ return;
+
+ if (!base::ContainsKey(*config_.feature_whitelist, feature_name)) {
+ DVLOG(1) << "Feature not whitelisted: " << feature_name;
+ return;
+ }
+
+ switch (feature.feature_type_case()) {
+ case Feature::kBoolValue:
+ case Feature::kFloatValue:
+ case Feature::kInt32Value:
+ case Feature::kStringValue: {
+ int64_t feature_int64_value = -1;
+ FeatureToInt64(feature, &feature_int64_value);
+ DVLOG(3) << "Logging: " << feature_name << ": " << feature_int64_value;
+ ukm_builder->AddMetric(feature_name.c_str(), feature_int64_value);
+ break;
+ }
+ case Feature::kStringList: {
+ for (int i = 0; i < feature.string_list().string_value_size(); ++i) {
+ int64_t feature_int64_value = -1;
+ FeatureToInt64(feature, &feature_int64_value, i);
+ DVLOG(3) << "Logging: " << feature_name << ": " << feature_int64_value;
+ ukm_builder->AddMetric(feature_name.c_str(), feature_int64_value);
+ }
+ break;
+ }
+ default:
+ DVLOG(0) << "Could not convert feature to int: " << feature_name;
+ }
+}
+
+void BasePredictor::LogExampleToUkm(const RankerExample& example,
+ ukm::SourceId source_id) {
+ if (config_.log_type != LOG_UKM) {
+ DVLOG(0) << "Wrong log type in predictor config: " << config_.log_type;
+ return;
+ }
+
+ if (!config_.feature_whitelist) {
+ DVLOG(0) << "No whitelist specified.";
+ return;
+ }
+ if (config_.feature_whitelist->empty()) {
+ DVLOG(0) << "Empty whitelist, examples will not be logged.";
+ return;
+ }
+
+ // Releasing the builder will trigger logging.
+ std::unique_ptr<ukm::UkmEntryBuilder> builder =
+ ukm::UkmRecorder::Get()->GetEntryBuilder(source_id, config_.logging_name);
+ if (builder) {
+ for (const auto& feature_kv : example.features()) {
+ LogFeatureToUkm(feature_kv.first, feature_kv.second, builder.get());
+ }
+ } else {
+ DVLOG(0) << "Could not get UKM Entry Builder.";
+ }
+}
+
+std::string BasePredictor::GetModelName() const {
+ return config_.model_name;
+}
+
+GURL BasePredictor::GetModelUrl() const {
+ if (!config_.field_trial_url_param) {
+ DVLOG(1) << "No URL specified.";
+ return GURL();
+ }
+
+ return GURL(config_.field_trial_url_param->Get());
+}
+
+RankerExample BasePredictor::PreprocessExample(const RankerExample& example) {
+ if (ranker_model_->proto().has_metadata() &&
+ ranker_model_->proto().metadata().input_features_names_are_hex_hashes()) {
+ return HashExampleFeatureNames(example);
+ }
+ return example;
+}
+
} // namespace assist_ranker
diff --git a/chromium/components/assist_ranker/base_predictor.h b/chromium/components/assist_ranker/base_predictor.h
index 5b69e8584e8..6904ad2e561 100644
--- a/chromium/components/assist_ranker/base_predictor.h
+++ b/chromium/components/assist_ranker/base_predictor.h
@@ -9,30 +9,53 @@
#include <string>
#include "base/files/file_path.h"
+#include "base/memory/weak_ptr.h"
+#include "components/assist_ranker/predictor_config.h"
#include "components/assist_ranker/ranker_model_loader.h"
+#include "services/metrics/public/cpp/ukm_source_id.h"
+
+class GURL;
+
+namespace ukm {
+class UkmEntryBuilder;
+}
namespace assist_ranker {
+class Feature;
+class RankerExample;
class RankerModel;
// Predictors are objects that provide an interface for prediction, as well as
-// encapsulate the logic for loading the model. Sub-classes of BasePredictor
-// implement an interface that depends on the nature of the suported model.
-// Subclasses of BasePredictor will also need to implement an Initialize method
-// that will be called once the model is available, and a static validation
-// function with the following signature:
+// encapsulate the logic for loading the model and logging. Sub-classes of
+// BasePredictor implement an interface that depends on the nature of the
+// suported model. Subclasses of BasePredictor will also need to implement an
+// Initialize method that will be called once the model is available, and a
+// static validation function with the following signature:
//
// static RankerModelStatus ValidateModel(const RankerModel& model);
-class BasePredictor {
+class BasePredictor : public base::SupportsWeakPtr<BasePredictor> {
public:
- BasePredictor();
+ BasePredictor(const PredictorConfig& config);
virtual ~BasePredictor();
+ // Returns true if the predictor is ready to make predictions.
bool IsReady();
+ // Returns true if the base::Feature associated with this model is enabled.
+ bool is_query_enabled() const { return is_query_enabled_; }
+
+ // Logs the features of |example| to UKM using the given source_id.
+ void LogExampleToUkm(const RankerExample& example, ukm::SourceId source_id);
+
+ // Returns the model URL.
+ GURL GetModelUrl() const;
+ // Returns the model name.
+ std::string GetModelName() const;
protected:
- // The model used for prediction.
- std::unique_ptr<RankerModel> ranker_model_;
+ // Preprocessing applied to an example before prediction. The original
+ // RankerExample is not modified, so it is safe to use it later for logging.
+ RankerExample PreprocessExample(const RankerExample& example);
// Called when the RankerModelLoader has finished loading the model. Returns
// true only if the model was succesfully loaded and is ready to predict.
@@ -43,9 +66,17 @@ class BasePredictor {
// Called once the model loader as succesfully loaded the model.
void OnModelAvailable(std::unique_ptr<RankerModel> model);
std::unique_ptr<RankerModelLoader> model_loader_;
+ // The model used for prediction.
+ std::unique_ptr<RankerModel> ranker_model_;
private:
+ void LogFeatureToUkm(const std::string& feature_name,
+ const Feature& feature,
+ ukm::UkmEntryBuilder* ukm_builder);
+
bool is_ready_ = false;
+ bool is_query_enabled_ = false;
+ PredictorConfig config_;
DISALLOW_COPY_AND_ASSIGN(BasePredictor);
};
diff --git a/chromium/components/assist_ranker/base_predictor_unittest.cc b/chromium/components/assist_ranker/base_predictor_unittest.cc
new file mode 100644
index 00000000000..263f27cdb75
--- /dev/null
+++ b/chromium/components/assist_ranker/base_predictor_unittest.cc
@@ -0,0 +1,183 @@
+// Copyright 2017 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/assist_ranker/base_predictor.h"
+
+#include <memory>
+
+#include "base/bind.h"
+#include "base/bind_helpers.h"
+#include "base/memory/ptr_util.h"
+#include "base/test/scoped_feature_list.h"
+#include "base/test/scoped_task_environment.h"
+#include "components/assist_ranker/fake_ranker_model_loader.h"
+#include "components/assist_ranker/predictor_config.h"
+#include "components/assist_ranker/proto/ranker_example.pb.h"
+#include "components/assist_ranker/ranker_model.h"
+#include "components/ukm/test_ukm_recorder.h"
+#include "testing/gtest/include/gtest/gtest.h"
+#include "url/gurl.h"
+
+namespace assist_ranker {
+
+using ::assist_ranker::testing::FakeRankerModelLoader;
+
+namespace {
+
+// Predictor config for testing.
+const char kTestModelName[] = "test_model";
+const char kTestLoggingName[] = "TestLoggingName";
+const char kTestUmaPrefixName[] = "Test.Ranker";
+const char kTestUrlParamName[] = "ranker-model-url";
+const char kTestDefaultModelUrl[] = "https://foo.bar/model.bin";
+
+const char kBoolFeature[] = "bool_feature";
+const char kIntFeature[] = "int_feature";
+const char kFloatFeature[] = "float_feature";
+const char kStringFeature[] = "string_feature";
+const char kStringListFeature[] = "string_list_feature";
+const char kFeatureNotWhitelisted[] = "not_whitelisted";
+
+const char kTestNavigationUrl[] = "https://foo.com";
+
+const base::flat_set<std::string> kFeatureWhitelist({kBoolFeature, kIntFeature,
+ kFloatFeature,
+ kStringFeature,
+ kStringListFeature});
+
+const base::Feature kTestRankerQuery{"TestRankerQuery",
+ base::FEATURE_ENABLED_BY_DEFAULT};
+
+const base::FeatureParam<std::string> kTestRankerUrl{
+ &kTestRankerQuery, kTestUrlParamName, kTestDefaultModelUrl};
+
+const PredictorConfig kTestPredictorConfig = PredictorConfig{
+ kTestModelName, kTestLoggingName, kTestUmaPrefixName, LOG_UKM,
+ &kFeatureWhitelist, &kTestRankerQuery, &kTestRankerUrl};
+
+// Class that implements virtual functions of the base class.
+class FakePredictor : public BasePredictor {
+ public:
+ static std::unique_ptr<FakePredictor> Create();
+ ~FakePredictor() override{};
+ // Validation will always succeed.
+ static RankerModelStatus ValidateModel(const RankerModel& model);
+
+ protected:
+ // Not implementing any inference logic.
+ bool Initialize() override { return true; };
+
+ private:
+ FakePredictor(const PredictorConfig& config);
+ DISALLOW_COPY_AND_ASSIGN(FakePredictor);
+};
+
+FakePredictor::FakePredictor(const PredictorConfig& config)
+ : BasePredictor(config) {}
+
+RankerModelStatus FakePredictor::ValidateModel(const RankerModel& model) {
+ return RankerModelStatus::OK;
+}
+
+std::unique_ptr<FakePredictor> FakePredictor::Create() {
+ std::unique_ptr<FakePredictor> predictor(
+ new FakePredictor(kTestPredictorConfig));
+ auto ranker_model = base::MakeUnique<RankerModel>();
+ auto fake_model_loader = base::MakeUnique<FakeRankerModelLoader>(
+ base::BindRepeating(&FakePredictor::ValidateModel),
+ base::BindRepeating(&FakePredictor::OnModelAvailable,
+ base::Unretained(predictor.get())),
+ std::move(ranker_model));
+ predictor->LoadModel(std::move(fake_model_loader));
+ return predictor;
+}
+
+} // namespace
+
+class BasePredictorTest : public ::testing::Test {
+ protected:
+ BasePredictorTest() = default;
+ // Disables Query for the test predictor.
+ void DisableQuery();
+
+ ukm::SourceId GetSourceId();
+
+ ukm::TestUkmRecorder* GetTestUkmRecorder() { return &test_ukm_recorder_; }
+
+ private:
+ // Sets up the task scheduling/task-runner environment for each test.
+ base::test::ScopedTaskEnvironment scoped_task_environment_;
+
+ // Sets itself as the global UkmRecorder on construction.
+ ukm::TestAutoSetUkmRecorder test_ukm_recorder_;
+
+ // Manages the enabling/disabling of features within the scope of a test.
+ base::test::ScopedFeatureList scoped_feature_list_;
+
+ DISALLOW_COPY_AND_ASSIGN(BasePredictorTest);
+};
+
+ukm::SourceId BasePredictorTest::GetSourceId() {
+ ukm::SourceId source_id = ukm::UkmRecorder::GetNewSourceID();
+ test_ukm_recorder_.UpdateSourceURL(source_id, GURL(kTestNavigationUrl));
+ return source_id;
+}
+
+void BasePredictorTest::DisableQuery() {
+ scoped_feature_list_.InitWithFeatures({}, {kTestRankerQuery});
+}
+
+TEST_F(BasePredictorTest, BaseTest) {
+ auto predictor = FakePredictor::Create();
+ EXPECT_EQ(kTestModelName, predictor->GetModelName());
+ EXPECT_EQ(kTestDefaultModelUrl, predictor->GetModelUrl());
+ EXPECT_TRUE(predictor->is_query_enabled());
+ EXPECT_TRUE(predictor->IsReady());
+}
+
+TEST_F(BasePredictorTest, QueryDisabled) {
+ DisableQuery();
+ auto predictor = FakePredictor::Create();
+ EXPECT_EQ(kTestModelName, predictor->GetModelName());
+ EXPECT_EQ(kTestDefaultModelUrl, predictor->GetModelUrl());
+ EXPECT_FALSE(predictor->is_query_enabled());
+ EXPECT_FALSE(predictor->IsReady());
+}
+
+TEST_F(BasePredictorTest, LogExampleToUkm) {
+ auto predictor = FakePredictor::Create();
+ RankerExample example;
+ auto& features = *example.mutable_features();
+ features[kBoolFeature].set_bool_value(true);
+ features[kIntFeature].set_int32_value(42);
+ features[kFloatFeature].set_float_value(42.0f);
+ features[kStringFeature].set_string_value("42");
+ features[kStringListFeature].mutable_string_list()->add_string_value("42");
+
+ // This feature will not be logged.
+ features[kFeatureNotWhitelisted].set_bool_value(false);
+
+ predictor->LogExampleToUkm(example, GetSourceId());
+
+ EXPECT_EQ(1U, GetTestUkmRecorder()->sources_count());
+ EXPECT_EQ(1U, GetTestUkmRecorder()->entries_count());
+ std::vector<const ukm::mojom::UkmEntry*> entries =
+ GetTestUkmRecorder()->GetEntriesByName(kTestLoggingName);
+ EXPECT_EQ(1U, entries.size());
+ GetTestUkmRecorder()->ExpectEntryMetric(entries[0], kBoolFeature,
+ 72057594037927937);
+ GetTestUkmRecorder()->ExpectEntryMetric(entries[0], kIntFeature,
+ 216172782113783850);
+ GetTestUkmRecorder()->ExpectEntryMetric(entries[0], kFloatFeature,
+ 144115189185773568);
+ GetTestUkmRecorder()->ExpectEntryMetric(entries[0], kStringFeature,
+ 288230377208836903);
+ GetTestUkmRecorder()->ExpectEntryMetric(entries[0], kStringListFeature,
+ 360287971246764839);
+
+ EXPECT_FALSE(
+ GetTestUkmRecorder()->EntryHasMetric(entries[0], kFeatureNotWhitelisted));
+}
+
+} // namespace assist_ranker
diff --git a/chromium/components/assist_ranker/binary_classifier_predictor.cc b/chromium/components/assist_ranker/binary_classifier_predictor.cc
index b37c616fa18..cc595bf0a62 100644
--- a/chromium/components/assist_ranker/binary_classifier_predictor.cc
+++ b/chromium/components/assist_ranker/binary_classifier_predictor.cc
@@ -15,26 +15,33 @@
#include "components/assist_ranker/ranker_model.h"
#include "components/assist_ranker/ranker_model_loader_impl.h"
#include "net/url_request/url_request_context_getter.h"
-#include "url/gurl.h"
namespace assist_ranker {
-BinaryClassifierPredictor::BinaryClassifierPredictor(){};
+BinaryClassifierPredictor::BinaryClassifierPredictor(
+ const PredictorConfig& config)
+ : BasePredictor(config){};
BinaryClassifierPredictor::~BinaryClassifierPredictor(){};
// static
std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create(
- net::URLRequestContextGetter* request_context_getter,
+ const PredictorConfig& config,
const base::FilePath& model_path,
- GURL model_url,
- const std::string& uma_prefix) {
+ net::URLRequestContextGetter* request_context_getter) {
std::unique_ptr<BinaryClassifierPredictor> predictor(
- new BinaryClassifierPredictor());
+ new BinaryClassifierPredictor(config));
+ if (!predictor->is_query_enabled()) {
+ DVLOG(1) << "Query disabled, bypassing model loading.";
+ return predictor;
+ }
+ const GURL& model_url = predictor->GetModelUrl();
+ DVLOG(1) << "Creating predictor instance for " << predictor->GetModelName();
+ DVLOG(1) << "Model URL: " << model_url;
auto model_loader = base::MakeUnique<RankerModelLoaderImpl>(
- base::Bind(&BinaryClassifierPredictor::ValidateModel),
- base::Bind(&BinaryClassifierPredictor::OnModelAvailable,
- base::Unretained(predictor.get())),
- request_context_getter, model_path, model_url, uma_prefix);
+ base::BindRepeating(&BinaryClassifierPredictor::ValidateModel),
+ base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable,
+ base::Unretained(predictor.get())),
+ request_context_getter, model_path, model_url, config.uma_prefix);
predictor->LoadModel(std::move(model_loader));
return predictor;
}
@@ -42,18 +49,23 @@ std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create(
bool BinaryClassifierPredictor::Predict(const RankerExample& example,
bool* prediction) {
if (!IsReady()) {
+ DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction.";
return false;
}
- *prediction = inference_module_->Predict(example);
+
+ *prediction = inference_module_->Predict(PreprocessExample(example));
+ DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << *prediction;
return true;
}
bool BinaryClassifierPredictor::PredictScore(const RankerExample& example,
float* prediction) {
if (!IsReady()) {
+ DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction.";
return false;
}
- *prediction = inference_module_->PredictScore(example);
+ *prediction = inference_module_->PredictScore(PreprocessExample(example));
+ DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << prediction;
return true;
}
@@ -61,17 +73,22 @@ bool BinaryClassifierPredictor::PredictScore(const RankerExample& example,
RankerModelStatus BinaryClassifierPredictor::ValidateModel(
const RankerModel& model) {
if (model.proto().model_case() != RankerModelProto::kLogisticRegression) {
+ DVLOG(0) << "Model is incompatible.";
return RankerModelStatus::INCOMPATIBLE;
}
return RankerModelStatus::OK;
}
bool BinaryClassifierPredictor::Initialize() {
- // TODO(hamelphi): move the GLRM proto up one layer in the proto in order to
- // be independent of the client feature.
- inference_module_.reset(new GenericLogisticRegressionInference(
- ranker_model_->proto().logistic_regression()));
- return true;
+ if (ranker_model_->proto().model_case() ==
+ RankerModelProto::kLogisticRegression) {
+ inference_module_ = std::make_unique<GenericLogisticRegressionInference>(
+ ranker_model_->proto().logistic_regression());
+ return true;
+ }
+
+ DVLOG(0) << "Could not initialize inference module.";
+ return false;
}
} // namespace assist_ranker
diff --git a/chromium/components/assist_ranker/binary_classifier_predictor.h b/chromium/components/assist_ranker/binary_classifier_predictor.h
index 1f960b01f4e..e932c91388f 100644
--- a/chromium/components/assist_ranker/binary_classifier_predictor.h
+++ b/chromium/components/assist_ranker/binary_classifier_predictor.h
@@ -9,8 +9,6 @@
#include "components/assist_ranker/base_predictor.h"
#include "components/assist_ranker/proto/ranker_example.pb.h"
-class GURL;
-
namespace base {
class FilePath;
}
@@ -28,11 +26,13 @@ class BinaryClassifierPredictor : public BasePredictor {
public:
~BinaryClassifierPredictor() override;
+ // Returns an new predictor instance with the given |config| and initialize
+ // its model loader. The |request_context getter| is passed to the
+ // predictor's model_loader which holds it as scoped_refptr.
static std::unique_ptr<BinaryClassifierPredictor> Create(
- net::URLRequestContextGetter* request_context_getter,
+ const PredictorConfig& config,
const base::FilePath& model_path,
- GURL model_url,
- const std::string& uma_prefix);
+ net::URLRequestContextGetter* request_context_getter) WARN_UNUSED_RESULT;
// Fills in a boolean decision given a RankerExample. Returns false if a
// prediction could not be made (e.g. the model is not loaded yet).
@@ -53,7 +53,7 @@ class BinaryClassifierPredictor : public BasePredictor {
private:
friend class BinaryClassifierPredictorTest;
- BinaryClassifierPredictor();
+ BinaryClassifierPredictor(const PredictorConfig& config);
// TODO(hamelphi): Use an abstract BinaryClassifierInferenceModule in order to
// generalize to other models.
diff --git a/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc b/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc
index fbb2c622d59..9249301d472 100644
--- a/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc
+++ b/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc
@@ -8,6 +8,7 @@
#include "base/bind.h"
#include "base/bind_helpers.h"
+#include "base/feature_list.h"
#include "base/memory/ptr_util.h"
#include "components/assist_ranker/fake_ranker_model_loader.h"
#include "components/assist_ranker/proto/ranker_model.pb.h"
@@ -21,11 +22,14 @@ using ::assist_ranker::testing::FakeRankerModelLoader;
class BinaryClassifierPredictorTest : public ::testing::Test {
public:
std::unique_ptr<BinaryClassifierPredictor> InitPredictor(
- std::unique_ptr<RankerModel> ranker_model);
+ std::unique_ptr<RankerModel> ranker_model,
+ const PredictorConfig& config);
// This model will return the value of |feature| as a prediction.
GenericLogisticRegressionModel GetSimpleLogisticRegressionModel();
+ PredictorConfig GetConfig();
+
protected:
const std::string feature_ = "feature";
const float threshold_ = 0.5;
@@ -33,10 +37,11 @@ class BinaryClassifierPredictorTest : public ::testing::Test {
std::unique_ptr<BinaryClassifierPredictor>
BinaryClassifierPredictorTest::InitPredictor(
- std::unique_ptr<RankerModel> ranker_model) {
+ std::unique_ptr<RankerModel> ranker_model,
+ const PredictorConfig& config) {
std::unique_ptr<BinaryClassifierPredictor> predictor(
- new BinaryClassifierPredictor());
- auto fake_model_loader = base::MakeUnique<FakeRankerModelLoader>(
+ new BinaryClassifierPredictor(config));
+ auto fake_model_loader = std::make_unique<FakeRankerModelLoader>(
base::Bind(&BinaryClassifierPredictor::ValidateModel),
base::Bind(&BinaryClassifierPredictor::OnModelAvailable,
base::Unretained(predictor.get())),
@@ -45,6 +50,20 @@ BinaryClassifierPredictorTest::InitPredictor(
return predictor;
}
+const base::Feature kTestRankerQuery{"TestRankerQuery",
+ base::FEATURE_ENABLED_BY_DEFAULT};
+
+const base::FeatureParam<std::string> kTestRankerUrl{
+ &kTestRankerQuery, "url-param-name", "https://default.model.url"};
+
+PredictorConfig BinaryClassifierPredictorTest::GetConfig() {
+ PredictorConfig config("model_name", "logging_name", "uma_prefix", LOG_NONE,
+ GetEmptyWhitelist(), &kTestRankerQuery,
+ &kTestRankerUrl);
+
+ return config;
+}
+
GenericLogisticRegressionModel
BinaryClassifierPredictorTest::GetSimpleLogisticRegressionModel() {
GenericLogisticRegressionModel lr_model;
@@ -58,7 +77,7 @@ BinaryClassifierPredictorTest::GetSimpleLogisticRegressionModel() {
TEST_F(BinaryClassifierPredictorTest, EmptyRankerModel) {
auto ranker_model = base::MakeUnique<RankerModel>();
- auto predictor = InitPredictor(std::move(ranker_model));
+ auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
EXPECT_FALSE(predictor->IsReady());
RankerExample ranker_example;
@@ -78,7 +97,7 @@ TEST_F(BinaryClassifierPredictorTest, NoInferenceModuleForModel) {
->mutable_translate()
->mutable_translate_logistic_regression_model()
->set_bias(1);
- auto predictor = InitPredictor(std::move(ranker_model));
+ auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
EXPECT_FALSE(predictor->IsReady());
RankerExample ranker_example;
@@ -94,7 +113,7 @@ TEST_F(BinaryClassifierPredictorTest, GenericLogisticRegressionModel) {
auto ranker_model = base::MakeUnique<RankerModel>();
*ranker_model->mutable_proto()->mutable_logistic_regression() =
GetSimpleLogisticRegressionModel();
- auto predictor = InitPredictor(std::move(ranker_model));
+ auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
EXPECT_TRUE(predictor->IsReady());
RankerExample ranker_example;
diff --git a/chromium/components/assist_ranker/generic_logistic_regression_inference.cc b/chromium/components/assist_ranker/generic_logistic_regression_inference.cc
index f2da478a433..1248b69772c 100644
--- a/chromium/components/assist_ranker/generic_logistic_regression_inference.cc
+++ b/chromium/components/assist_ranker/generic_logistic_regression_inference.cc
@@ -29,7 +29,7 @@ float GenericLogisticRegressionInference::PredictScore(
const FeatureWeight& feature_weight = weight_it.second;
switch (feature_weight.feature_type_case()) {
case FeatureWeight::FEATURE_TYPE_NOT_SET: {
- DVLOG(1) << "Feature type not set for " << feature_name;
+ DVLOG(0) << "Feature type not set for " << feature_name;
break;
}
case FeatureWeight::kScalar: {
@@ -37,6 +37,8 @@ float GenericLogisticRegressionInference::PredictScore(
if (GetFeatureValueAsFloat(feature_name, example, &value)) {
const float weight = feature_weight.scalar();
activation += value * weight;
+ } else {
+ DVLOG(1) << "Feature not in example: " << feature_name;
}
break;
}
@@ -50,19 +52,22 @@ float GenericLogisticRegressionInference::PredictScore(
} else {
// If the category is not found, use the default weight.
activation += feature_weight.one_hot().default_weight();
+ DVLOG(1) << "Unknown feature value for " << feature_name << ": "
+ << value;
}
} else {
// If the feature is missing, use the default weight.
activation += feature_weight.one_hot().default_weight();
+ DVLOG(1) << "Feature not in example: " << feature_name;
}
break;
}
case FeatureWeight::kSparse: {
- DVLOG(1) << "Sparse features not implemented yet.";
+ DVLOG(0) << "Sparse features not implemented yet.";
break;
}
case FeatureWeight::kBucketized: {
- DVLOG(1) << "Bucketized features not implemented yet.";
+ DVLOG(0) << "Bucketized features not implemented yet.";
break;
}
}
diff --git a/chromium/components/assist_ranker/predictor_config.cc b/chromium/components/assist_ranker/predictor_config.cc
new file mode 100644
index 00000000000..57434e52390
--- /dev/null
+++ b/chromium/components/assist_ranker/predictor_config.cc
@@ -0,0 +1,14 @@
+// Copyright 2017 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/assist_ranker/predictor_config.h"
+
+namespace assist_ranker {
+
+const base::flat_set<std::string>* GetEmptyWhitelist() {
+ static auto* whitelist = new base::flat_set<std::string>();
+ return whitelist;
+}
+
+} // namespace assist_ranker
diff --git a/chromium/components/assist_ranker/predictor_config.h b/chromium/components/assist_ranker/predictor_config.h
new file mode 100644
index 00000000000..6a545889f4f
--- /dev/null
+++ b/chromium/components/assist_ranker/predictor_config.h
@@ -0,0 +1,52 @@
+// Copyright 2017 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_H_
+#define COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_H_
+
+#include <string>
+
+#include "base/containers/flat_set.h"
+#include "base/metrics/field_trial_params.h"
+
+namespace assist_ranker {
+
+// TODO(chrome-ranker-team): Implement other logging types.
+enum LogType {
+ LOG_NONE = 0,
+ LOG_UKM = 1,
+};
+
+// Empty feature whitelist used for testing.
+const base::flat_set<std::string>* GetEmptyWhitelist();
+
+// This struct holds the config options for logging, loading and field trial
+// for a predictor.
+struct PredictorConfig {
+ PredictorConfig(const char* model_name,
+ const char* logging_name,
+ const char* uma_prefix,
+ const LogType log_type,
+ const base::flat_set<std::string>* feature_whitelist,
+ const base::Feature* field_trial,
+ const base::FeatureParam<std::string>* field_trial_url_param)
+ : model_name(model_name),
+ logging_name(logging_name),
+ uma_prefix(uma_prefix),
+ log_type(log_type),
+ feature_whitelist(feature_whitelist),
+ field_trial(field_trial),
+ field_trial_url_param(field_trial_url_param) {}
+ const char* model_name;
+ const char* logging_name;
+ const char* uma_prefix;
+ const LogType log_type;
+ const base::flat_set<std::string>* feature_whitelist;
+ const base::Feature* field_trial;
+ const base::FeatureParam<std::string>* field_trial_url_param;
+};
+
+} // namespace assist_ranker
+
+#endif // COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_H_
diff --git a/chromium/components/assist_ranker/predictor_config_definitions.cc b/chromium/components/assist_ranker/predictor_config_definitions.cc
new file mode 100644
index 00000000000..771593c3d26
--- /dev/null
+++ b/chromium/components/assist_ranker/predictor_config_definitions.cc
@@ -0,0 +1,75 @@
+// Copyright 2017 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/assist_ranker/predictor_config_definitions.h"
+
+namespace assist_ranker {
+
+#if defined(OS_ANDROID)
+const base::Feature kContextualSearchRankerQuery{
+ "ContextualSearchRankerQuery", base::FEATURE_DISABLED_BY_DEFAULT};
+
+namespace {
+
+const char kContextualSearchModelName[] = "contextual_search_model";
+const char kContextualSearchLoggingName[] = "ContextualSearch";
+const char kContextualSearchUmaPrefixName[] = "Search.ContextualSearch.Ranker";
+
+const char kContextualSearchDefaultModelUrl[] =
+ "https://www.gstatic.com/chrome/intelligence/assist/ranker/models/"
+ "contextual_search/test_ranker_model_20171109_short_words_v2.pb.bin";
+
+const base::FeatureParam<std::string>*
+GetContextualSearchRankerUrlFeatureParam() {
+ static auto* kContextualSearchRankerUrl = new base::FeatureParam<std::string>(
+ &kContextualSearchRankerQuery, "contextual-search-ranker-model-url",
+ kContextualSearchDefaultModelUrl);
+ return kContextualSearchRankerUrl;
+}
+
+// NOTE: This list needs to be kept in sync with tools/metrics/ukm/ukm.xml!
+// Only features within this list will be logged to UKM.
+// TODO(chrome-ranker-team) Deprecate the whitelist once it is available through
+// the UKM generated API.
+const base::flat_set<std::string>* GetContextualSearchFeatureWhitelist() {
+ static auto* kContextualSearchFeatureWhitelist =
+ new base::flat_set<std::string>({"DidOptIn",
+ "DurationAfterScrollMs",
+ "IsEntity",
+ "IsEntityEligible",
+ "IsHttp",
+ "IsLanguageMismatch",
+ "IsLongWord",
+ "IsSecondTapOverride",
+ "IsShortWord",
+ "IsWordEdge",
+ "OutcomeRankerDidPredict",
+ "OutcomeRankerPrediction",
+ "OutcomeWasCardsDataShown",
+ "OutcomeWasPanelOpened",
+ "OutcomeWasQuickActionClicked",
+ "OutcomeWasQuickAnswerSeen",
+ "Previous28DayCtrPercent",
+ "Previous28DayImpressionsCount",
+ "PreviousWeekCtrPercent",
+ "PreviousWeekImpressionsCount",
+ "ScreenTopDps",
+ "TapDurationMs",
+ "WasScreenBottom"});
+ return kContextualSearchFeatureWhitelist;
+}
+
+} // namespace
+
+const PredictorConfig GetContextualSearchPredictorConfig() {
+ static auto kContextualSearchPredictorConfig = *(new PredictorConfig(
+ kContextualSearchModelName, kContextualSearchLoggingName,
+ kContextualSearchUmaPrefixName, LOG_UKM,
+ GetContextualSearchFeatureWhitelist(), &kContextualSearchRankerQuery,
+ GetContextualSearchRankerUrlFeatureParam()));
+ return kContextualSearchPredictorConfig;
+}
+#endif // OS_ANDROID
+
+} // namespace assist_ranker
diff --git a/chromium/components/assist_ranker/predictor_config_definitions.h b/chromium/components/assist_ranker/predictor_config_definitions.h
new file mode 100644
index 00000000000..431e0b4e754
--- /dev/null
+++ b/chromium/components/assist_ranker/predictor_config_definitions.h
@@ -0,0 +1,26 @@
+// Copyright 2017 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_DEFINITIONS_H_
+#define COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_DEFINITIONS_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "base/feature_list.h"
+#include "base/metrics/field_trial_params.h"
+#include "build/build_config.h"
+#include "components/assist_ranker/predictor_config.h"
+
+namespace assist_ranker {
+
+#if defined(OS_ANDROID)
+extern const base::Feature kContextualSearchRankerQuery;
+const PredictorConfig GetContextualSearchPredictorConfig();
+#endif // OS_ANDROID
+
+} // namespace assist_ranker
+
+#endif // COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_DEFINITIONS_H_
diff --git a/chromium/components/assist_ranker/proto/ranker_example.proto b/chromium/components/assist_ranker/proto/ranker_example.proto
index 45ba2368c77..36d1a2c6e38 100644
--- a/chromium/components/assist_ranker/proto/ranker_example.proto
+++ b/chromium/components/assist_ranker/proto/ranker_example.proto
@@ -10,6 +10,10 @@ option optimize_for = LITE_RUNTIME;
package assist_ranker;
+message StringList {
+ repeated bytes string_value = 1;
+}
+
// Generic message that can contain a variety of data types.
message Feature {
oneof feature_type {
@@ -20,6 +24,8 @@ message Feature {
int32 int32_value = 3;
// String values are used for one-hot features.
bytes string_value = 4;
+ // String list are used for sparse features.
+ StringList string_list = 5;
}
}
@@ -32,5 +38,9 @@ message RankerExample {
// This field represents the ground truth that the ranker is
// expected to predict, and is typically derived from user feedback. It is
// used for training only and is not required for inference.
+ // NOTE: this field will not be logged. If you want to log an outcome, add it
+ // to the features field before calling LogExample.
+ // TODO(chrome-ranker-team) Add a metadata field to log metrics that are not
+ // used as model input.
optional Feature target = 2;
-} \ No newline at end of file
+}
diff --git a/chromium/components/assist_ranker/ranker_example_util.cc b/chromium/components/assist_ranker/ranker_example_util.cc
index 45c2dc3f163..54d4dbd58f7 100644
--- a/chromium/components/assist_ranker/ranker_example_util.cc
+++ b/chromium/components/assist_ranker/ranker_example_util.cc
@@ -3,9 +3,54 @@
// found in the LICENSE file.
#include "components/assist_ranker/ranker_example_util.h"
+#include "base/bit_cast.h"
+#include "base/format_macros.h"
#include "base/logging.h"
+#include "base/metrics/metrics_hashes.h"
+#include "base/strings/stringprintf.h"
namespace assist_ranker {
+namespace {
+const uint64_t MASK32Bits = (1LL << 32) - 1;
+constexpr int kFloatMainDigits = 23;
+// Returns lower 32 bits of the hash of the input.
+int32_t StringToIntBits(const std::string& str) {
+ return base::HashMetricName(str) & MASK32Bits;
+}
+
+// Converts float to int32
+int32_t FloatToIntBits(float f) {
+ if (std::numeric_limits<float>::is_iec559) {
+ // Directly bit_cast if float follows ieee754 standard.
+ return bit_cast<int32_t>(f);
+ } else {
+ // Otherwise, manually calculate sign, exp and mantissa.
+ // For sign.
+ const uint32_t sign = f < 0;
+
+ // For exponent.
+ int exp;
+ f = std::abs(std::frexp(f, &exp));
+ // Add 126 to get non-negative format of exp.
+ // This should not be 127 because the return of frexp is different from
+ // ieee754 with a multiple of 2.
+ const uint32_t exp_u = exp + 126;
+
+ // Get mantissa.
+ const uint32_t mantissa = std::ldexp(f * 2.0f - 1.0f, kFloatMainDigits);
+ // Set each bits and return.
+ return (sign << 31) | (exp_u << kFloatMainDigits) | mantissa;
+ }
+}
+
+// Pair type, value and index into one int64.
+int64_t PairInt(const uint64_t type,
+ const uint32_t value,
+ const uint64_t index) {
+ return (type << 56) | (index << 32) | static_cast<uint64_t>(value);
+}
+
+} // namespace
bool SafeGetFeature(const std::string& key,
const RankerExample& example,
@@ -42,6 +87,42 @@ bool GetFeatureValueAsFloat(const std::string& key,
return true;
}
+bool FeatureToInt64(const Feature& feature,
+ int64_t* const res,
+ const int index) {
+ int32_t value = -1;
+ int32_t type = feature.feature_type_case();
+ switch (type) {
+ case Feature::kBoolValue:
+ value = static_cast<int32_t>(feature.bool_value());
+ break;
+ case Feature::kFloatValue:
+ value = FloatToIntBits(feature.float_value());
+ break;
+ case Feature::kInt32Value:
+ value = feature.int32_value();
+ break;
+ case Feature::kStringValue:
+ value = StringToIntBits(feature.string_value());
+ break;
+ case Feature::kStringList:
+ if (index >= 0 && index < feature.string_list().string_value_size()) {
+ value = StringToIntBits(feature.string_list().string_value(index));
+ } else {
+ DVLOG(3) << "Invalid index for string list: " << index;
+ NOTREACHED();
+ return false;
+ }
+ break;
+ default:
+ DVLOG(3) << "Feature type is supported for logging: " << type;
+ NOTREACHED();
+ return false;
+ }
+ *res = PairInt(type, value, index);
+ return true;
+ }
+
bool GetOneHotValue(const std::string& key,
const RankerExample& example,
std::string* value) {
@@ -60,4 +141,20 @@ bool GetOneHotValue(const std::string& key,
return true;
}
+// Converts string to a hex hash string.
+std::string HashFeatureName(const std::string& feature_name) {
+ uint64_t feature_key = base::HashMetricName(feature_name);
+ return base::StringPrintf("%016" PRIx64, feature_key);
+}
+
+RankerExample HashExampleFeatureNames(const RankerExample& example) {
+ RankerExample hashed_example;
+ auto& output_features = *hashed_example.mutable_features();
+ for (const auto& feature : example.features()) {
+ output_features[HashFeatureName(feature.first)] = feature.second;
+ }
+ *hashed_example.mutable_target() = example.target();
+ return hashed_example;
+}
+
} // namespace assist_ranker
diff --git a/chromium/components/assist_ranker/ranker_example_util.h b/chromium/components/assist_ranker/ranker_example_util.h
index 4f49a251409..c75bb393522 100644
--- a/chromium/components/assist_ranker/ranker_example_util.h
+++ b/chromium/components/assist_ranker/ranker_example_util.h
@@ -25,6 +25,20 @@ bool GetFeatureValueAsFloat(const std::string& key,
const RankerExample& example,
float* value) WARN_UNUSED_RESULT;
+// Converts a Ranker Feature to an int64. For feature list, this converts the
+// index-th value of the list.
+// A feature is converted to an int64 by:
+// (a) use low32 bits represent the value of the feature.
+// a.1) bool_value, int32_value is directly converted to an int32.
+// a.2) string_value is hashed to an int32.
+// a.3) float_value is directly bit_cast into int32 if it follows ieee754
+// standard; otherwise manually calculate sign, exponent and mantissa.
+// (b) use high32 bits represent the type of the feature.
+// b.1) use high8 bits represent the feature_type_case.
+// b.2) use low24 bits represent the index if the feature is a list.
+// Returns true if the feature is converted successfully; false otherwise.
+bool FeatureToInt64(const Feature& feature, int64_t* res, int index = 0);
+
// Extract category from one-hot feature. Returns true and fills
// in |value| if the feature is found and is of type string_value. Returns false
// otherwise.
@@ -32,6 +46,15 @@ bool GetOneHotValue(const std::string& key,
const RankerExample& example,
std::string* value) WARN_UNUSED_RESULT;
+// Converts a string to a hex ahsh string.
+std::string HashFeatureName(const std::string& feature_name);
+
+// Hashes feature names to an hex string.
+// Features logged through UKM will apply this transformation when logging
+// features, so models trained on UKM data are expected to have hashed input
+// feature names.
+RankerExample HashExampleFeatureNames(const RankerExample& example);
+
} // namespace assist_ranker
#endif // COMPONENTS_ASSIST_RANKER_RANKER_EXAMPLE_UTIL_H_
diff --git a/chromium/components/assist_ranker/ranker_example_util_unittest.cc b/chromium/components/assist_ranker/ranker_example_util_unittest.cc
index ac9f34c5e1a..348dadc3666 100644
--- a/chromium/components/assist_ranker/ranker_example_util_unittest.cc
+++ b/chromium/components/assist_ranker/ranker_example_util_unittest.cc
@@ -103,4 +103,81 @@ TEST_F(RankerExampleUtilTest, GetOneHotValue) {
EXPECT_FALSE(GetOneHotValue("foo", example_, &value));
}
+TEST_F(RankerExampleUtilTest, ScalarFeatureInt64Conversion) {
+ Feature feature;
+ int64_t int64_value;
+
+ feature.set_bool_value(true);
+ EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
+ EXPECT_EQ(int64_value, 72057594037927937LL);
+
+ feature.set_int32_value(std::numeric_limits<int32_t>::max());
+ EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
+ EXPECT_EQ(int64_value, 216172784261267455LL);
+
+ feature.set_int32_value(std::numeric_limits<int32_t>::lowest());
+ EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
+ EXPECT_EQ(int64_value, 216172784261267456LL);
+
+ feature.set_string_value("foo");
+ EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
+ EXPECT_EQ(int64_value, 288230377439557724LL);
+}
+
+TEST_F(RankerExampleUtilTest, FloatFeatureInt64Conversion) {
+ Feature feature;
+ int64_t int64_value;
+
+ feature.set_float_value(std::numeric_limits<float>::epsilon());
+ EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
+ EXPECT_EQ(int64_value, 144115188948271104LL);
+
+ feature.set_float_value(-std::numeric_limits<float>::epsilon());
+ EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
+ EXPECT_EQ(int64_value, 144115191095754752LL);
+
+ feature.set_float_value(std::numeric_limits<float>::max());
+ EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
+ EXPECT_EQ(int64_value, 144115190214950911LL);
+
+ feature.set_float_value(std::numeric_limits<float>::lowest());
+ EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
+ EXPECT_EQ(int64_value, 144115192362434559LL);
+}
+
+TEST_F(RankerExampleUtilTest, StringListInt64Conversion) {
+ Feature feature;
+ int64_t int64_value;
+
+ feature.mutable_string_list()->add_string_value("");
+ feature.mutable_string_list()->add_string_value("TEST");
+ EXPECT_TRUE(FeatureToInt64(feature, &int64_value, 1));
+ EXPECT_EQ(int64_value, 360287974776690660LL);
+}
+
+TEST_F(RankerExampleUtilTest, HashExampleFeatureNames) {
+ auto hashed_example = HashExampleFeatureNames(example_);
+ // Hashed example has the same number of features.
+ EXPECT_EQ(example_.features().size(), hashed_example.features().size());
+
+ // But the feature names have changed.
+ EXPECT_FALSE(SafeGetFeature(bool_name_, hashed_example, nullptr));
+ EXPECT_FALSE(SafeGetFeature(int32_name_, hashed_example, nullptr));
+ EXPECT_FALSE(SafeGetFeature(float_name_, hashed_example, nullptr));
+ EXPECT_FALSE(SafeGetFeature(one_hot_name_, hashed_example, nullptr));
+
+ EXPECT_TRUE(
+ SafeGetFeature(HashFeatureName(bool_name_), hashed_example, nullptr));
+
+ // Values have not changed.
+ float float_value;
+ EXPECT_TRUE(GetFeatureValueAsFloat(HashFeatureName(float_name_),
+ hashed_example, &float_value));
+ EXPECT_EQ(float_value_, float_value);
+ std::string string_value;
+ EXPECT_TRUE(GetOneHotValue(HashFeatureName(one_hot_name_), hashed_example,
+ &string_value));
+ EXPECT_EQ(one_hot_value_, string_value);
+}
+
} // namespace assist_ranker