diff options
author | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2018-01-31 16:33:43 +0100 |
---|---|---|
committer | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2018-02-06 16:33:22 +0000 |
commit | da51f56cc21233c2d30f0fe0d171727c3102b2e0 (patch) | |
tree | 4e579ab70ce4b19bee7984237f3ce05a96d59d83 /chromium/components/assist_ranker | |
parent | c8c2d1901aec01e934adf561a9fdf0cc776cdef8 (diff) | |
download | qtwebengine-chromium-da51f56cc21233c2d30f0fe0d171727c3102b2e0.tar.gz |
BASELINE: Update Chromium to 65.0.3525.40
Also imports missing submodules
Change-Id: I36901b7c6a325cda3d2c10cedb2186c25af3b79b
Reviewed-by: Alexandru Croitor <alexandru.croitor@qt.io>
Diffstat (limited to 'chromium/components/assist_ranker')
20 files changed, 823 insertions, 76 deletions
diff --git a/chromium/components/assist_ranker/BUILD.gn b/chromium/components/assist_ranker/BUILD.gn index 992fe8b8117..1c07991b5dd 100644 --- a/chromium/components/assist_ranker/BUILD.gn +++ b/chromium/components/assist_ranker/BUILD.gn @@ -15,6 +15,10 @@ static_library("assist_ranker") { "fake_ranker_model_loader.h", "generic_logistic_regression_inference.cc", "generic_logistic_regression_inference.h", + "predictor_config.cc", + "predictor_config.h", + "predictor_config_definitions.cc", + "predictor_config_definitions.h", "ranker_example_util.cc", "ranker_example_util.h", "ranker_model.cc", @@ -32,6 +36,7 @@ static_library("assist_ranker") { "//components/data_use_measurement/core", "//components/keyed_service/core", "//net", + "//services/metrics/public/cpp:metrics_cpp", "//url", ] } @@ -40,6 +45,7 @@ source_set("unit_tests") { testonly = true sources = [ + "base_predictor_unittest.cc", "binary_classifier_predictor_unittest.cc", "generic_logistic_regression_inference_unittest.cc", "ranker_example_util_unittest.cc", @@ -51,6 +57,7 @@ source_set("unit_tests") { ":assist_ranker", "//base", "//components/assist_ranker/proto", + "//components/ukm:test_support", "//net:test_support", "//testing/gtest", ] diff --git a/chromium/components/assist_ranker/DEPS b/chromium/components/assist_ranker/DEPS index f66492a2bcd..d8355c686f2 100644 --- a/chromium/components/assist_ranker/DEPS +++ b/chromium/components/assist_ranker/DEPS @@ -2,5 +2,7 @@ include_rules = [ "+components/data_use_measurement/core", "+components/keyed_service/core", "+components/metrics", + "+components/ukm", "+net", + "+services/metrics/public", ]
\ No newline at end of file diff --git a/chromium/components/assist_ranker/assist_ranker_service.h b/chromium/components/assist_ranker/assist_ranker_service.h index d2015f07668..bb11a4789b2 100644 --- a/chromium/components/assist_ranker/assist_ranker_service.h +++ b/chromium/components/assist_ranker/assist_ranker_service.h @@ -9,32 +9,25 @@ #include <string> #include "base/macros.h" +#include "base/memory/weak_ptr.h" #include "components/keyed_service/core/keyed_service.h" -class GURL; - namespace assist_ranker { class BinaryClassifierPredictor; +struct PredictorConfig; -// TODO(crbug.com/778468) : Refactor this so that the service owns the predictor -// objects and enforce model uniqueness through internal registration in order -// to avoid cache collisions. -// // Service that provides Predictor objects. class AssistRankerService : public KeyedService { public: AssistRankerService() = default; - // Returns a binary classification model. |model_filename| is the filename of - // the cached model. It should be unique to this predictor to avoid cache - // collision. |model_url| represents a unique ID for the desired model (see - // ranker_model_loader.h for more details). |uma_prefix| is used to log - // histograms related to the loading of the model. - virtual std::unique_ptr<BinaryClassifierPredictor> - FetchBinaryClassifierPredictor(GURL model_url, - const std::string& model_filename, - const std::string& uma_prefix) = 0; + // Returns a binary classification model given a PredictorConfig. + // The predictor is instantiated the first time a predictor is fetched. The + // next calls to fetch will return a pointer to the already instantiated + // predictor. + virtual base::WeakPtr<BinaryClassifierPredictor> + FetchBinaryClassifierPredictor(const PredictorConfig& config) = 0; private: DISALLOW_COPY_AND_ASSIGN(AssistRankerService); diff --git a/chromium/components/assist_ranker/assist_ranker_service_impl.cc b/chromium/components/assist_ranker/assist_ranker_service_impl.cc index 255183a6e8c..5b870b8f578 100644 --- a/chromium/components/assist_ranker/assist_ranker_service_impl.cc +++ b/chromium/components/assist_ranker/assist_ranker_service_impl.cc @@ -3,7 +3,7 @@ // found in the LICENSE file. #include "components/assist_ranker/assist_ranker_service_impl.h" - +#include "base/memory/weak_ptr.h" #include "components/assist_ranker/binary_classifier_predictor.h" #include "components/assist_ranker/ranker_model_loader_impl.h" #include "net/url_request/url_request_context_getter.h" @@ -19,15 +19,27 @@ AssistRankerServiceImpl::AssistRankerServiceImpl( AssistRankerServiceImpl::~AssistRankerServiceImpl() {} -std::unique_ptr<BinaryClassifierPredictor> +base::WeakPtr<BinaryClassifierPredictor> AssistRankerServiceImpl::FetchBinaryClassifierPredictor( - GURL model_url, - const std::string& model_filename, - const std::string& uma_prefix) { + const PredictorConfig& config) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - return BinaryClassifierPredictor::Create(url_request_context_getter_.get(), - GetModelPath(model_filename), - model_url, uma_prefix); + const std::string& model_name = config.model_name; + auto predictor_it = predictor_map_.find(model_name); + if (predictor_it != predictor_map_.end()) { + DVLOG(1) << "Predictor " << model_name << " already initialized."; + return base::AsWeakPtr( + static_cast<BinaryClassifierPredictor*>(predictor_it->second.get())); + } + + // The predictor does not exist yet, so we create one. + DVLOG(1) << "Initializing predictor: " << model_name; + std::unique_ptr<BinaryClassifierPredictor> predictor = + BinaryClassifierPredictor::Create(config, GetModelPath(model_name), + url_request_context_getter_.get()); + base::WeakPtr<BinaryClassifierPredictor> weak_ptr = + base::AsWeakPtr(predictor.get()); + predictor_map_[model_name] = std::move(predictor); + return weak_ptr; } base::FilePath AssistRankerServiceImpl::GetModelPath( diff --git a/chromium/components/assist_ranker/assist_ranker_service_impl.h b/chromium/components/assist_ranker/assist_ranker_service_impl.h index e4e967e5981..1b75ff4017f 100644 --- a/chromium/components/assist_ranker/assist_ranker_service_impl.h +++ b/chromium/components/assist_ranker/assist_ranker_service_impl.h @@ -7,13 +7,13 @@ #include <memory> #include <string> +#include <unordered_map> #include "base/files/file_path.h" #include "base/memory/ref_counted.h" #include "base/sequence_checker.h" #include "components/assist_ranker/assist_ranker_service.h" - -class GURL; +#include "components/assist_ranker/predictor_config.h" namespace net { class URLRequestContextGetter; @@ -21,6 +21,7 @@ class URLRequestContextGetter; namespace assist_ranker { +class BasePredictor; class BinaryClassifierPredictor; class AssistRankerServiceImpl : public AssistRankerService { @@ -31,10 +32,8 @@ class AssistRankerServiceImpl : public AssistRankerService { ~AssistRankerServiceImpl() override; // AssistRankerService... - std::unique_ptr<BinaryClassifierPredictor> FetchBinaryClassifierPredictor( - GURL model_url, - const std::string& model_filename, - const std::string& uma_prefix) override; + base::WeakPtr<BinaryClassifierPredictor> FetchBinaryClassifierPredictor( + const PredictorConfig& config) override; private: // Returns the full path to the model cache. @@ -46,6 +45,9 @@ class AssistRankerServiceImpl : public AssistRankerService { // Base path where models are stored. const base::FilePath base_path_; + std::unordered_map<std::string, std::unique_ptr<BasePredictor>> + predictor_map_; + SEQUENCE_CHECKER(sequence_checker_); DISALLOW_COPY_AND_ASSIGN(AssistRankerServiceImpl); diff --git a/chromium/components/assist_ranker/base_predictor.cc b/chromium/components/assist_ranker/base_predictor.cc index b890f998ec0..13bfe0c5e15 100644 --- a/chromium/components/assist_ranker/base_predictor.cc +++ b/chromium/components/assist_ranker/base_predictor.cc @@ -4,23 +4,40 @@ #include "components/assist_ranker/base_predictor.h" +#include "base/feature_list.h" #include "base/memory/ptr_util.h" +#include "components/assist_ranker/proto/ranker_example.pb.h" #include "components/assist_ranker/proto/ranker_model.pb.h" +#include "components/assist_ranker/ranker_example_util.h" #include "components/assist_ranker/ranker_model.h" +#include "services/metrics/public/cpp/ukm_entry_builder.h" +#include "services/metrics/public/cpp/ukm_recorder.h" +#include "url/gurl.h" namespace assist_ranker { -BasePredictor::BasePredictor() {} +BasePredictor::BasePredictor(const PredictorConfig& config) : config_(config) { + // TODO(chrome-ranker-team): validate config. + if (config_.field_trial) { + is_query_enabled_ = base::FeatureList::IsEnabled(*config_.field_trial); + } else { + DVLOG(0) << "No field trial specified"; + } +} + BasePredictor::~BasePredictor() {} void BasePredictor::LoadModel(std::unique_ptr<RankerModelLoader> model_loader) { + if (!is_query_enabled_) + return; + if (model_loader_) { - DLOG(ERROR) << "This predictor already has a model loader."; + DVLOG(0) << "This predictor already has a model loader."; return; } // Take ownership of the model loader. model_loader_ = std::move(model_loader); - // Kick off the initial load from cache. + // Kick off the initial model load. model_loader_->NotifyOfRankerActivity(); } @@ -31,10 +48,95 @@ void BasePredictor::OnModelAvailable( } bool BasePredictor::IsReady() { - if (!is_ready_) + if (!is_ready_ && model_loader_) model_loader_->NotifyOfRankerActivity(); return is_ready_; } +void BasePredictor::LogFeatureToUkm(const std::string& feature_name, + const Feature& feature, + ukm::UkmEntryBuilder* ukm_builder) { + if (!ukm_builder) + return; + + if (!base::ContainsKey(*config_.feature_whitelist, feature_name)) { + DVLOG(1) << "Feature not whitelisted: " << feature_name; + return; + } + + switch (feature.feature_type_case()) { + case Feature::kBoolValue: + case Feature::kFloatValue: + case Feature::kInt32Value: + case Feature::kStringValue: { + int64_t feature_int64_value = -1; + FeatureToInt64(feature, &feature_int64_value); + DVLOG(3) << "Logging: " << feature_name << ": " << feature_int64_value; + ukm_builder->AddMetric(feature_name.c_str(), feature_int64_value); + break; + } + case Feature::kStringList: { + for (int i = 0; i < feature.string_list().string_value_size(); ++i) { + int64_t feature_int64_value = -1; + FeatureToInt64(feature, &feature_int64_value, i); + DVLOG(3) << "Logging: " << feature_name << ": " << feature_int64_value; + ukm_builder->AddMetric(feature_name.c_str(), feature_int64_value); + } + break; + } + default: + DVLOG(0) << "Could not convert feature to int: " << feature_name; + } +} + +void BasePredictor::LogExampleToUkm(const RankerExample& example, + ukm::SourceId source_id) { + if (config_.log_type != LOG_UKM) { + DVLOG(0) << "Wrong log type in predictor config: " << config_.log_type; + return; + } + + if (!config_.feature_whitelist) { + DVLOG(0) << "No whitelist specified."; + return; + } + if (config_.feature_whitelist->empty()) { + DVLOG(0) << "Empty whitelist, examples will not be logged."; + return; + } + + // Releasing the builder will trigger logging. + std::unique_ptr<ukm::UkmEntryBuilder> builder = + ukm::UkmRecorder::Get()->GetEntryBuilder(source_id, config_.logging_name); + if (builder) { + for (const auto& feature_kv : example.features()) { + LogFeatureToUkm(feature_kv.first, feature_kv.second, builder.get()); + } + } else { + DVLOG(0) << "Could not get UKM Entry Builder."; + } +} + +std::string BasePredictor::GetModelName() const { + return config_.model_name; +} + +GURL BasePredictor::GetModelUrl() const { + if (!config_.field_trial_url_param) { + DVLOG(1) << "No URL specified."; + return GURL(); + } + + return GURL(config_.field_trial_url_param->Get()); +} + +RankerExample BasePredictor::PreprocessExample(const RankerExample& example) { + if (ranker_model_->proto().has_metadata() && + ranker_model_->proto().metadata().input_features_names_are_hex_hashes()) { + return HashExampleFeatureNames(example); + } + return example; +} + } // namespace assist_ranker diff --git a/chromium/components/assist_ranker/base_predictor.h b/chromium/components/assist_ranker/base_predictor.h index 5b69e8584e8..6904ad2e561 100644 --- a/chromium/components/assist_ranker/base_predictor.h +++ b/chromium/components/assist_ranker/base_predictor.h @@ -9,30 +9,53 @@ #include <string> #include "base/files/file_path.h" +#include "base/memory/weak_ptr.h" +#include "components/assist_ranker/predictor_config.h" #include "components/assist_ranker/ranker_model_loader.h" +#include "services/metrics/public/cpp/ukm_source_id.h" + +class GURL; + +namespace ukm { +class UkmEntryBuilder; +} namespace assist_ranker { +class Feature; +class RankerExample; class RankerModel; // Predictors are objects that provide an interface for prediction, as well as -// encapsulate the logic for loading the model. Sub-classes of BasePredictor -// implement an interface that depends on the nature of the suported model. -// Subclasses of BasePredictor will also need to implement an Initialize method -// that will be called once the model is available, and a static validation -// function with the following signature: +// encapsulate the logic for loading the model and logging. Sub-classes of +// BasePredictor implement an interface that depends on the nature of the +// suported model. Subclasses of BasePredictor will also need to implement an +// Initialize method that will be called once the model is available, and a +// static validation function with the following signature: // // static RankerModelStatus ValidateModel(const RankerModel& model); -class BasePredictor { +class BasePredictor : public base::SupportsWeakPtr<BasePredictor> { public: - BasePredictor(); + BasePredictor(const PredictorConfig& config); virtual ~BasePredictor(); + // Returns true if the predictor is ready to make predictions. bool IsReady(); + // Returns true if the base::Feature associated with this model is enabled. + bool is_query_enabled() const { return is_query_enabled_; } + + // Logs the features of |example| to UKM using the given source_id. + void LogExampleToUkm(const RankerExample& example, ukm::SourceId source_id); + + // Returns the model URL. + GURL GetModelUrl() const; + // Returns the model name. + std::string GetModelName() const; protected: - // The model used for prediction. - std::unique_ptr<RankerModel> ranker_model_; + // Preprocessing applied to an example before prediction. The original + // RankerExample is not modified, so it is safe to use it later for logging. + RankerExample PreprocessExample(const RankerExample& example); // Called when the RankerModelLoader has finished loading the model. Returns // true only if the model was succesfully loaded and is ready to predict. @@ -43,9 +66,17 @@ class BasePredictor { // Called once the model loader as succesfully loaded the model. void OnModelAvailable(std::unique_ptr<RankerModel> model); std::unique_ptr<RankerModelLoader> model_loader_; + // The model used for prediction. + std::unique_ptr<RankerModel> ranker_model_; private: + void LogFeatureToUkm(const std::string& feature_name, + const Feature& feature, + ukm::UkmEntryBuilder* ukm_builder); + bool is_ready_ = false; + bool is_query_enabled_ = false; + PredictorConfig config_; DISALLOW_COPY_AND_ASSIGN(BasePredictor); }; diff --git a/chromium/components/assist_ranker/base_predictor_unittest.cc b/chromium/components/assist_ranker/base_predictor_unittest.cc new file mode 100644 index 00000000000..263f27cdb75 --- /dev/null +++ b/chromium/components/assist_ranker/base_predictor_unittest.cc @@ -0,0 +1,183 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/assist_ranker/base_predictor.h" + +#include <memory> + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/memory/ptr_util.h" +#include "base/test/scoped_feature_list.h" +#include "base/test/scoped_task_environment.h" +#include "components/assist_ranker/fake_ranker_model_loader.h" +#include "components/assist_ranker/predictor_config.h" +#include "components/assist_ranker/proto/ranker_example.pb.h" +#include "components/assist_ranker/ranker_model.h" +#include "components/ukm/test_ukm_recorder.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "url/gurl.h" + +namespace assist_ranker { + +using ::assist_ranker::testing::FakeRankerModelLoader; + +namespace { + +// Predictor config for testing. +const char kTestModelName[] = "test_model"; +const char kTestLoggingName[] = "TestLoggingName"; +const char kTestUmaPrefixName[] = "Test.Ranker"; +const char kTestUrlParamName[] = "ranker-model-url"; +const char kTestDefaultModelUrl[] = "https://foo.bar/model.bin"; + +const char kBoolFeature[] = "bool_feature"; +const char kIntFeature[] = "int_feature"; +const char kFloatFeature[] = "float_feature"; +const char kStringFeature[] = "string_feature"; +const char kStringListFeature[] = "string_list_feature"; +const char kFeatureNotWhitelisted[] = "not_whitelisted"; + +const char kTestNavigationUrl[] = "https://foo.com"; + +const base::flat_set<std::string> kFeatureWhitelist({kBoolFeature, kIntFeature, + kFloatFeature, + kStringFeature, + kStringListFeature}); + +const base::Feature kTestRankerQuery{"TestRankerQuery", + base::FEATURE_ENABLED_BY_DEFAULT}; + +const base::FeatureParam<std::string> kTestRankerUrl{ + &kTestRankerQuery, kTestUrlParamName, kTestDefaultModelUrl}; + +const PredictorConfig kTestPredictorConfig = PredictorConfig{ + kTestModelName, kTestLoggingName, kTestUmaPrefixName, LOG_UKM, + &kFeatureWhitelist, &kTestRankerQuery, &kTestRankerUrl}; + +// Class that implements virtual functions of the base class. +class FakePredictor : public BasePredictor { + public: + static std::unique_ptr<FakePredictor> Create(); + ~FakePredictor() override{}; + // Validation will always succeed. + static RankerModelStatus ValidateModel(const RankerModel& model); + + protected: + // Not implementing any inference logic. + bool Initialize() override { return true; }; + + private: + FakePredictor(const PredictorConfig& config); + DISALLOW_COPY_AND_ASSIGN(FakePredictor); +}; + +FakePredictor::FakePredictor(const PredictorConfig& config) + : BasePredictor(config) {} + +RankerModelStatus FakePredictor::ValidateModel(const RankerModel& model) { + return RankerModelStatus::OK; +} + +std::unique_ptr<FakePredictor> FakePredictor::Create() { + std::unique_ptr<FakePredictor> predictor( + new FakePredictor(kTestPredictorConfig)); + auto ranker_model = base::MakeUnique<RankerModel>(); + auto fake_model_loader = base::MakeUnique<FakeRankerModelLoader>( + base::BindRepeating(&FakePredictor::ValidateModel), + base::BindRepeating(&FakePredictor::OnModelAvailable, + base::Unretained(predictor.get())), + std::move(ranker_model)); + predictor->LoadModel(std::move(fake_model_loader)); + return predictor; +} + +} // namespace + +class BasePredictorTest : public ::testing::Test { + protected: + BasePredictorTest() = default; + // Disables Query for the test predictor. + void DisableQuery(); + + ukm::SourceId GetSourceId(); + + ukm::TestUkmRecorder* GetTestUkmRecorder() { return &test_ukm_recorder_; } + + private: + // Sets up the task scheduling/task-runner environment for each test. + base::test::ScopedTaskEnvironment scoped_task_environment_; + + // Sets itself as the global UkmRecorder on construction. + ukm::TestAutoSetUkmRecorder test_ukm_recorder_; + + // Manages the enabling/disabling of features within the scope of a test. + base::test::ScopedFeatureList scoped_feature_list_; + + DISALLOW_COPY_AND_ASSIGN(BasePredictorTest); +}; + +ukm::SourceId BasePredictorTest::GetSourceId() { + ukm::SourceId source_id = ukm::UkmRecorder::GetNewSourceID(); + test_ukm_recorder_.UpdateSourceURL(source_id, GURL(kTestNavigationUrl)); + return source_id; +} + +void BasePredictorTest::DisableQuery() { + scoped_feature_list_.InitWithFeatures({}, {kTestRankerQuery}); +} + +TEST_F(BasePredictorTest, BaseTest) { + auto predictor = FakePredictor::Create(); + EXPECT_EQ(kTestModelName, predictor->GetModelName()); + EXPECT_EQ(kTestDefaultModelUrl, predictor->GetModelUrl()); + EXPECT_TRUE(predictor->is_query_enabled()); + EXPECT_TRUE(predictor->IsReady()); +} + +TEST_F(BasePredictorTest, QueryDisabled) { + DisableQuery(); + auto predictor = FakePredictor::Create(); + EXPECT_EQ(kTestModelName, predictor->GetModelName()); + EXPECT_EQ(kTestDefaultModelUrl, predictor->GetModelUrl()); + EXPECT_FALSE(predictor->is_query_enabled()); + EXPECT_FALSE(predictor->IsReady()); +} + +TEST_F(BasePredictorTest, LogExampleToUkm) { + auto predictor = FakePredictor::Create(); + RankerExample example; + auto& features = *example.mutable_features(); + features[kBoolFeature].set_bool_value(true); + features[kIntFeature].set_int32_value(42); + features[kFloatFeature].set_float_value(42.0f); + features[kStringFeature].set_string_value("42"); + features[kStringListFeature].mutable_string_list()->add_string_value("42"); + + // This feature will not be logged. + features[kFeatureNotWhitelisted].set_bool_value(false); + + predictor->LogExampleToUkm(example, GetSourceId()); + + EXPECT_EQ(1U, GetTestUkmRecorder()->sources_count()); + EXPECT_EQ(1U, GetTestUkmRecorder()->entries_count()); + std::vector<const ukm::mojom::UkmEntry*> entries = + GetTestUkmRecorder()->GetEntriesByName(kTestLoggingName); + EXPECT_EQ(1U, entries.size()); + GetTestUkmRecorder()->ExpectEntryMetric(entries[0], kBoolFeature, + 72057594037927937); + GetTestUkmRecorder()->ExpectEntryMetric(entries[0], kIntFeature, + 216172782113783850); + GetTestUkmRecorder()->ExpectEntryMetric(entries[0], kFloatFeature, + 144115189185773568); + GetTestUkmRecorder()->ExpectEntryMetric(entries[0], kStringFeature, + 288230377208836903); + GetTestUkmRecorder()->ExpectEntryMetric(entries[0], kStringListFeature, + 360287971246764839); + + EXPECT_FALSE( + GetTestUkmRecorder()->EntryHasMetric(entries[0], kFeatureNotWhitelisted)); +} + +} // namespace assist_ranker diff --git a/chromium/components/assist_ranker/binary_classifier_predictor.cc b/chromium/components/assist_ranker/binary_classifier_predictor.cc index b37c616fa18..cc595bf0a62 100644 --- a/chromium/components/assist_ranker/binary_classifier_predictor.cc +++ b/chromium/components/assist_ranker/binary_classifier_predictor.cc @@ -15,26 +15,33 @@ #include "components/assist_ranker/ranker_model.h" #include "components/assist_ranker/ranker_model_loader_impl.h" #include "net/url_request/url_request_context_getter.h" -#include "url/gurl.h" namespace assist_ranker { -BinaryClassifierPredictor::BinaryClassifierPredictor(){}; +BinaryClassifierPredictor::BinaryClassifierPredictor( + const PredictorConfig& config) + : BasePredictor(config){}; BinaryClassifierPredictor::~BinaryClassifierPredictor(){}; // static std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create( - net::URLRequestContextGetter* request_context_getter, + const PredictorConfig& config, const base::FilePath& model_path, - GURL model_url, - const std::string& uma_prefix) { + net::URLRequestContextGetter* request_context_getter) { std::unique_ptr<BinaryClassifierPredictor> predictor( - new BinaryClassifierPredictor()); + new BinaryClassifierPredictor(config)); + if (!predictor->is_query_enabled()) { + DVLOG(1) << "Query disabled, bypassing model loading."; + return predictor; + } + const GURL& model_url = predictor->GetModelUrl(); + DVLOG(1) << "Creating predictor instance for " << predictor->GetModelName(); + DVLOG(1) << "Model URL: " << model_url; auto model_loader = base::MakeUnique<RankerModelLoaderImpl>( - base::Bind(&BinaryClassifierPredictor::ValidateModel), - base::Bind(&BinaryClassifierPredictor::OnModelAvailable, - base::Unretained(predictor.get())), - request_context_getter, model_path, model_url, uma_prefix); + base::BindRepeating(&BinaryClassifierPredictor::ValidateModel), + base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable, + base::Unretained(predictor.get())), + request_context_getter, model_path, model_url, config.uma_prefix); predictor->LoadModel(std::move(model_loader)); return predictor; } @@ -42,18 +49,23 @@ std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create( bool BinaryClassifierPredictor::Predict(const RankerExample& example, bool* prediction) { if (!IsReady()) { + DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction."; return false; } - *prediction = inference_module_->Predict(example); + + *prediction = inference_module_->Predict(PreprocessExample(example)); + DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << *prediction; return true; } bool BinaryClassifierPredictor::PredictScore(const RankerExample& example, float* prediction) { if (!IsReady()) { + DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction."; return false; } - *prediction = inference_module_->PredictScore(example); + *prediction = inference_module_->PredictScore(PreprocessExample(example)); + DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << prediction; return true; } @@ -61,17 +73,22 @@ bool BinaryClassifierPredictor::PredictScore(const RankerExample& example, RankerModelStatus BinaryClassifierPredictor::ValidateModel( const RankerModel& model) { if (model.proto().model_case() != RankerModelProto::kLogisticRegression) { + DVLOG(0) << "Model is incompatible."; return RankerModelStatus::INCOMPATIBLE; } return RankerModelStatus::OK; } bool BinaryClassifierPredictor::Initialize() { - // TODO(hamelphi): move the GLRM proto up one layer in the proto in order to - // be independent of the client feature. - inference_module_.reset(new GenericLogisticRegressionInference( - ranker_model_->proto().logistic_regression())); - return true; + if (ranker_model_->proto().model_case() == + RankerModelProto::kLogisticRegression) { + inference_module_ = std::make_unique<GenericLogisticRegressionInference>( + ranker_model_->proto().logistic_regression()); + return true; + } + + DVLOG(0) << "Could not initialize inference module."; + return false; } } // namespace assist_ranker diff --git a/chromium/components/assist_ranker/binary_classifier_predictor.h b/chromium/components/assist_ranker/binary_classifier_predictor.h index 1f960b01f4e..e932c91388f 100644 --- a/chromium/components/assist_ranker/binary_classifier_predictor.h +++ b/chromium/components/assist_ranker/binary_classifier_predictor.h @@ -9,8 +9,6 @@ #include "components/assist_ranker/base_predictor.h" #include "components/assist_ranker/proto/ranker_example.pb.h" -class GURL; - namespace base { class FilePath; } @@ -28,11 +26,13 @@ class BinaryClassifierPredictor : public BasePredictor { public: ~BinaryClassifierPredictor() override; + // Returns an new predictor instance with the given |config| and initialize + // its model loader. The |request_context getter| is passed to the + // predictor's model_loader which holds it as scoped_refptr. static std::unique_ptr<BinaryClassifierPredictor> Create( - net::URLRequestContextGetter* request_context_getter, + const PredictorConfig& config, const base::FilePath& model_path, - GURL model_url, - const std::string& uma_prefix); + net::URLRequestContextGetter* request_context_getter) WARN_UNUSED_RESULT; // Fills in a boolean decision given a RankerExample. Returns false if a // prediction could not be made (e.g. the model is not loaded yet). @@ -53,7 +53,7 @@ class BinaryClassifierPredictor : public BasePredictor { private: friend class BinaryClassifierPredictorTest; - BinaryClassifierPredictor(); + BinaryClassifierPredictor(const PredictorConfig& config); // TODO(hamelphi): Use an abstract BinaryClassifierInferenceModule in order to // generalize to other models. diff --git a/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc b/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc index fbb2c622d59..9249301d472 100644 --- a/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc +++ b/chromium/components/assist_ranker/binary_classifier_predictor_unittest.cc @@ -8,6 +8,7 @@ #include "base/bind.h" #include "base/bind_helpers.h" +#include "base/feature_list.h" #include "base/memory/ptr_util.h" #include "components/assist_ranker/fake_ranker_model_loader.h" #include "components/assist_ranker/proto/ranker_model.pb.h" @@ -21,11 +22,14 @@ using ::assist_ranker::testing::FakeRankerModelLoader; class BinaryClassifierPredictorTest : public ::testing::Test { public: std::unique_ptr<BinaryClassifierPredictor> InitPredictor( - std::unique_ptr<RankerModel> ranker_model); + std::unique_ptr<RankerModel> ranker_model, + const PredictorConfig& config); // This model will return the value of |feature| as a prediction. GenericLogisticRegressionModel GetSimpleLogisticRegressionModel(); + PredictorConfig GetConfig(); + protected: const std::string feature_ = "feature"; const float threshold_ = 0.5; @@ -33,10 +37,11 @@ class BinaryClassifierPredictorTest : public ::testing::Test { std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictorTest::InitPredictor( - std::unique_ptr<RankerModel> ranker_model) { + std::unique_ptr<RankerModel> ranker_model, + const PredictorConfig& config) { std::unique_ptr<BinaryClassifierPredictor> predictor( - new BinaryClassifierPredictor()); - auto fake_model_loader = base::MakeUnique<FakeRankerModelLoader>( + new BinaryClassifierPredictor(config)); + auto fake_model_loader = std::make_unique<FakeRankerModelLoader>( base::Bind(&BinaryClassifierPredictor::ValidateModel), base::Bind(&BinaryClassifierPredictor::OnModelAvailable, base::Unretained(predictor.get())), @@ -45,6 +50,20 @@ BinaryClassifierPredictorTest::InitPredictor( return predictor; } +const base::Feature kTestRankerQuery{"TestRankerQuery", + base::FEATURE_ENABLED_BY_DEFAULT}; + +const base::FeatureParam<std::string> kTestRankerUrl{ + &kTestRankerQuery, "url-param-name", "https://default.model.url"}; + +PredictorConfig BinaryClassifierPredictorTest::GetConfig() { + PredictorConfig config("model_name", "logging_name", "uma_prefix", LOG_NONE, + GetEmptyWhitelist(), &kTestRankerQuery, + &kTestRankerUrl); + + return config; +} + GenericLogisticRegressionModel BinaryClassifierPredictorTest::GetSimpleLogisticRegressionModel() { GenericLogisticRegressionModel lr_model; @@ -58,7 +77,7 @@ BinaryClassifierPredictorTest::GetSimpleLogisticRegressionModel() { TEST_F(BinaryClassifierPredictorTest, EmptyRankerModel) { auto ranker_model = base::MakeUnique<RankerModel>(); - auto predictor = InitPredictor(std::move(ranker_model)); + auto predictor = InitPredictor(std::move(ranker_model), GetConfig()); EXPECT_FALSE(predictor->IsReady()); RankerExample ranker_example; @@ -78,7 +97,7 @@ TEST_F(BinaryClassifierPredictorTest, NoInferenceModuleForModel) { ->mutable_translate() ->mutable_translate_logistic_regression_model() ->set_bias(1); - auto predictor = InitPredictor(std::move(ranker_model)); + auto predictor = InitPredictor(std::move(ranker_model), GetConfig()); EXPECT_FALSE(predictor->IsReady()); RankerExample ranker_example; @@ -94,7 +113,7 @@ TEST_F(BinaryClassifierPredictorTest, GenericLogisticRegressionModel) { auto ranker_model = base::MakeUnique<RankerModel>(); *ranker_model->mutable_proto()->mutable_logistic_regression() = GetSimpleLogisticRegressionModel(); - auto predictor = InitPredictor(std::move(ranker_model)); + auto predictor = InitPredictor(std::move(ranker_model), GetConfig()); EXPECT_TRUE(predictor->IsReady()); RankerExample ranker_example; diff --git a/chromium/components/assist_ranker/generic_logistic_regression_inference.cc b/chromium/components/assist_ranker/generic_logistic_regression_inference.cc index f2da478a433..1248b69772c 100644 --- a/chromium/components/assist_ranker/generic_logistic_regression_inference.cc +++ b/chromium/components/assist_ranker/generic_logistic_regression_inference.cc @@ -29,7 +29,7 @@ float GenericLogisticRegressionInference::PredictScore( const FeatureWeight& feature_weight = weight_it.second; switch (feature_weight.feature_type_case()) { case FeatureWeight::FEATURE_TYPE_NOT_SET: { - DVLOG(1) << "Feature type not set for " << feature_name; + DVLOG(0) << "Feature type not set for " << feature_name; break; } case FeatureWeight::kScalar: { @@ -37,6 +37,8 @@ float GenericLogisticRegressionInference::PredictScore( if (GetFeatureValueAsFloat(feature_name, example, &value)) { const float weight = feature_weight.scalar(); activation += value * weight; + } else { + DVLOG(1) << "Feature not in example: " << feature_name; } break; } @@ -50,19 +52,22 @@ float GenericLogisticRegressionInference::PredictScore( } else { // If the category is not found, use the default weight. activation += feature_weight.one_hot().default_weight(); + DVLOG(1) << "Unknown feature value for " << feature_name << ": " + << value; } } else { // If the feature is missing, use the default weight. activation += feature_weight.one_hot().default_weight(); + DVLOG(1) << "Feature not in example: " << feature_name; } break; } case FeatureWeight::kSparse: { - DVLOG(1) << "Sparse features not implemented yet."; + DVLOG(0) << "Sparse features not implemented yet."; break; } case FeatureWeight::kBucketized: { - DVLOG(1) << "Bucketized features not implemented yet."; + DVLOG(0) << "Bucketized features not implemented yet."; break; } } diff --git a/chromium/components/assist_ranker/predictor_config.cc b/chromium/components/assist_ranker/predictor_config.cc new file mode 100644 index 00000000000..57434e52390 --- /dev/null +++ b/chromium/components/assist_ranker/predictor_config.cc @@ -0,0 +1,14 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/assist_ranker/predictor_config.h" + +namespace assist_ranker { + +const base::flat_set<std::string>* GetEmptyWhitelist() { + static auto* whitelist = new base::flat_set<std::string>(); + return whitelist; +} + +} // namespace assist_ranker diff --git a/chromium/components/assist_ranker/predictor_config.h b/chromium/components/assist_ranker/predictor_config.h new file mode 100644 index 00000000000..6a545889f4f --- /dev/null +++ b/chromium/components/assist_ranker/predictor_config.h @@ -0,0 +1,52 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_H_ +#define COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_H_ + +#include <string> + +#include "base/containers/flat_set.h" +#include "base/metrics/field_trial_params.h" + +namespace assist_ranker { + +// TODO(chrome-ranker-team): Implement other logging types. +enum LogType { + LOG_NONE = 0, + LOG_UKM = 1, +}; + +// Empty feature whitelist used for testing. +const base::flat_set<std::string>* GetEmptyWhitelist(); + +// This struct holds the config options for logging, loading and field trial +// for a predictor. +struct PredictorConfig { + PredictorConfig(const char* model_name, + const char* logging_name, + const char* uma_prefix, + const LogType log_type, + const base::flat_set<std::string>* feature_whitelist, + const base::Feature* field_trial, + const base::FeatureParam<std::string>* field_trial_url_param) + : model_name(model_name), + logging_name(logging_name), + uma_prefix(uma_prefix), + log_type(log_type), + feature_whitelist(feature_whitelist), + field_trial(field_trial), + field_trial_url_param(field_trial_url_param) {} + const char* model_name; + const char* logging_name; + const char* uma_prefix; + const LogType log_type; + const base::flat_set<std::string>* feature_whitelist; + const base::Feature* field_trial; + const base::FeatureParam<std::string>* field_trial_url_param; +}; + +} // namespace assist_ranker + +#endif // COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_H_ diff --git a/chromium/components/assist_ranker/predictor_config_definitions.cc b/chromium/components/assist_ranker/predictor_config_definitions.cc new file mode 100644 index 00000000000..771593c3d26 --- /dev/null +++ b/chromium/components/assist_ranker/predictor_config_definitions.cc @@ -0,0 +1,75 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/assist_ranker/predictor_config_definitions.h" + +namespace assist_ranker { + +#if defined(OS_ANDROID) +const base::Feature kContextualSearchRankerQuery{ + "ContextualSearchRankerQuery", base::FEATURE_DISABLED_BY_DEFAULT}; + +namespace { + +const char kContextualSearchModelName[] = "contextual_search_model"; +const char kContextualSearchLoggingName[] = "ContextualSearch"; +const char kContextualSearchUmaPrefixName[] = "Search.ContextualSearch.Ranker"; + +const char kContextualSearchDefaultModelUrl[] = + "https://www.gstatic.com/chrome/intelligence/assist/ranker/models/" + "contextual_search/test_ranker_model_20171109_short_words_v2.pb.bin"; + +const base::FeatureParam<std::string>* +GetContextualSearchRankerUrlFeatureParam() { + static auto* kContextualSearchRankerUrl = new base::FeatureParam<std::string>( + &kContextualSearchRankerQuery, "contextual-search-ranker-model-url", + kContextualSearchDefaultModelUrl); + return kContextualSearchRankerUrl; +} + +// NOTE: This list needs to be kept in sync with tools/metrics/ukm/ukm.xml! +// Only features within this list will be logged to UKM. +// TODO(chrome-ranker-team) Deprecate the whitelist once it is available through +// the UKM generated API. +const base::flat_set<std::string>* GetContextualSearchFeatureWhitelist() { + static auto* kContextualSearchFeatureWhitelist = + new base::flat_set<std::string>({"DidOptIn", + "DurationAfterScrollMs", + "IsEntity", + "IsEntityEligible", + "IsHttp", + "IsLanguageMismatch", + "IsLongWord", + "IsSecondTapOverride", + "IsShortWord", + "IsWordEdge", + "OutcomeRankerDidPredict", + "OutcomeRankerPrediction", + "OutcomeWasCardsDataShown", + "OutcomeWasPanelOpened", + "OutcomeWasQuickActionClicked", + "OutcomeWasQuickAnswerSeen", + "Previous28DayCtrPercent", + "Previous28DayImpressionsCount", + "PreviousWeekCtrPercent", + "PreviousWeekImpressionsCount", + "ScreenTopDps", + "TapDurationMs", + "WasScreenBottom"}); + return kContextualSearchFeatureWhitelist; +} + +} // namespace + +const PredictorConfig GetContextualSearchPredictorConfig() { + static auto kContextualSearchPredictorConfig = *(new PredictorConfig( + kContextualSearchModelName, kContextualSearchLoggingName, + kContextualSearchUmaPrefixName, LOG_UKM, + GetContextualSearchFeatureWhitelist(), &kContextualSearchRankerQuery, + GetContextualSearchRankerUrlFeatureParam())); + return kContextualSearchPredictorConfig; +} +#endif // OS_ANDROID + +} // namespace assist_ranker diff --git a/chromium/components/assist_ranker/predictor_config_definitions.h b/chromium/components/assist_ranker/predictor_config_definitions.h new file mode 100644 index 00000000000..431e0b4e754 --- /dev/null +++ b/chromium/components/assist_ranker/predictor_config_definitions.h @@ -0,0 +1,26 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_DEFINITIONS_H_ +#define COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_DEFINITIONS_H_ + +#include <memory> +#include <string> +#include <unordered_map> + +#include "base/feature_list.h" +#include "base/metrics/field_trial_params.h" +#include "build/build_config.h" +#include "components/assist_ranker/predictor_config.h" + +namespace assist_ranker { + +#if defined(OS_ANDROID) +extern const base::Feature kContextualSearchRankerQuery; +const PredictorConfig GetContextualSearchPredictorConfig(); +#endif // OS_ANDROID + +} // namespace assist_ranker + +#endif // COMPONENTS_ASSIST_RANKER_PREDICTOR_CONFIG_DEFINITIONS_H_ diff --git a/chromium/components/assist_ranker/proto/ranker_example.proto b/chromium/components/assist_ranker/proto/ranker_example.proto index 45ba2368c77..36d1a2c6e38 100644 --- a/chromium/components/assist_ranker/proto/ranker_example.proto +++ b/chromium/components/assist_ranker/proto/ranker_example.proto @@ -10,6 +10,10 @@ option optimize_for = LITE_RUNTIME; package assist_ranker; +message StringList { + repeated bytes string_value = 1; +} + // Generic message that can contain a variety of data types. message Feature { oneof feature_type { @@ -20,6 +24,8 @@ message Feature { int32 int32_value = 3; // String values are used for one-hot features. bytes string_value = 4; + // String list are used for sparse features. + StringList string_list = 5; } } @@ -32,5 +38,9 @@ message RankerExample { // This field represents the ground truth that the ranker is // expected to predict, and is typically derived from user feedback. It is // used for training only and is not required for inference. + // NOTE: this field will not be logged. If you want to log an outcome, add it + // to the features field before calling LogExample. + // TODO(chrome-ranker-team) Add a metadata field to log metrics that are not + // used as model input. optional Feature target = 2; -}
\ No newline at end of file +} diff --git a/chromium/components/assist_ranker/ranker_example_util.cc b/chromium/components/assist_ranker/ranker_example_util.cc index 45c2dc3f163..54d4dbd58f7 100644 --- a/chromium/components/assist_ranker/ranker_example_util.cc +++ b/chromium/components/assist_ranker/ranker_example_util.cc @@ -3,9 +3,54 @@ // found in the LICENSE file. #include "components/assist_ranker/ranker_example_util.h" +#include "base/bit_cast.h" +#include "base/format_macros.h" #include "base/logging.h" +#include "base/metrics/metrics_hashes.h" +#include "base/strings/stringprintf.h" namespace assist_ranker { +namespace { +const uint64_t MASK32Bits = (1LL << 32) - 1; +constexpr int kFloatMainDigits = 23; +// Returns lower 32 bits of the hash of the input. +int32_t StringToIntBits(const std::string& str) { + return base::HashMetricName(str) & MASK32Bits; +} + +// Converts float to int32 +int32_t FloatToIntBits(float f) { + if (std::numeric_limits<float>::is_iec559) { + // Directly bit_cast if float follows ieee754 standard. + return bit_cast<int32_t>(f); + } else { + // Otherwise, manually calculate sign, exp and mantissa. + // For sign. + const uint32_t sign = f < 0; + + // For exponent. + int exp; + f = std::abs(std::frexp(f, &exp)); + // Add 126 to get non-negative format of exp. + // This should not be 127 because the return of frexp is different from + // ieee754 with a multiple of 2. + const uint32_t exp_u = exp + 126; + + // Get mantissa. + const uint32_t mantissa = std::ldexp(f * 2.0f - 1.0f, kFloatMainDigits); + // Set each bits and return. + return (sign << 31) | (exp_u << kFloatMainDigits) | mantissa; + } +} + +// Pair type, value and index into one int64. +int64_t PairInt(const uint64_t type, + const uint32_t value, + const uint64_t index) { + return (type << 56) | (index << 32) | static_cast<uint64_t>(value); +} + +} // namespace bool SafeGetFeature(const std::string& key, const RankerExample& example, @@ -42,6 +87,42 @@ bool GetFeatureValueAsFloat(const std::string& key, return true; } +bool FeatureToInt64(const Feature& feature, + int64_t* const res, + const int index) { + int32_t value = -1; + int32_t type = feature.feature_type_case(); + switch (type) { + case Feature::kBoolValue: + value = static_cast<int32_t>(feature.bool_value()); + break; + case Feature::kFloatValue: + value = FloatToIntBits(feature.float_value()); + break; + case Feature::kInt32Value: + value = feature.int32_value(); + break; + case Feature::kStringValue: + value = StringToIntBits(feature.string_value()); + break; + case Feature::kStringList: + if (index >= 0 && index < feature.string_list().string_value_size()) { + value = StringToIntBits(feature.string_list().string_value(index)); + } else { + DVLOG(3) << "Invalid index for string list: " << index; + NOTREACHED(); + return false; + } + break; + default: + DVLOG(3) << "Feature type is supported for logging: " << type; + NOTREACHED(); + return false; + } + *res = PairInt(type, value, index); + return true; + } + bool GetOneHotValue(const std::string& key, const RankerExample& example, std::string* value) { @@ -60,4 +141,20 @@ bool GetOneHotValue(const std::string& key, return true; } +// Converts string to a hex hash string. +std::string HashFeatureName(const std::string& feature_name) { + uint64_t feature_key = base::HashMetricName(feature_name); + return base::StringPrintf("%016" PRIx64, feature_key); +} + +RankerExample HashExampleFeatureNames(const RankerExample& example) { + RankerExample hashed_example; + auto& output_features = *hashed_example.mutable_features(); + for (const auto& feature : example.features()) { + output_features[HashFeatureName(feature.first)] = feature.second; + } + *hashed_example.mutable_target() = example.target(); + return hashed_example; +} + } // namespace assist_ranker diff --git a/chromium/components/assist_ranker/ranker_example_util.h b/chromium/components/assist_ranker/ranker_example_util.h index 4f49a251409..c75bb393522 100644 --- a/chromium/components/assist_ranker/ranker_example_util.h +++ b/chromium/components/assist_ranker/ranker_example_util.h @@ -25,6 +25,20 @@ bool GetFeatureValueAsFloat(const std::string& key, const RankerExample& example, float* value) WARN_UNUSED_RESULT; +// Converts a Ranker Feature to an int64. For feature list, this converts the +// index-th value of the list. +// A feature is converted to an int64 by: +// (a) use low32 bits represent the value of the feature. +// a.1) bool_value, int32_value is directly converted to an int32. +// a.2) string_value is hashed to an int32. +// a.3) float_value is directly bit_cast into int32 if it follows ieee754 +// standard; otherwise manually calculate sign, exponent and mantissa. +// (b) use high32 bits represent the type of the feature. +// b.1) use high8 bits represent the feature_type_case. +// b.2) use low24 bits represent the index if the feature is a list. +// Returns true if the feature is converted successfully; false otherwise. +bool FeatureToInt64(const Feature& feature, int64_t* res, int index = 0); + // Extract category from one-hot feature. Returns true and fills // in |value| if the feature is found and is of type string_value. Returns false // otherwise. @@ -32,6 +46,15 @@ bool GetOneHotValue(const std::string& key, const RankerExample& example, std::string* value) WARN_UNUSED_RESULT; +// Converts a string to a hex ahsh string. +std::string HashFeatureName(const std::string& feature_name); + +// Hashes feature names to an hex string. +// Features logged through UKM will apply this transformation when logging +// features, so models trained on UKM data are expected to have hashed input +// feature names. +RankerExample HashExampleFeatureNames(const RankerExample& example); + } // namespace assist_ranker #endif // COMPONENTS_ASSIST_RANKER_RANKER_EXAMPLE_UTIL_H_ diff --git a/chromium/components/assist_ranker/ranker_example_util_unittest.cc b/chromium/components/assist_ranker/ranker_example_util_unittest.cc index ac9f34c5e1a..348dadc3666 100644 --- a/chromium/components/assist_ranker/ranker_example_util_unittest.cc +++ b/chromium/components/assist_ranker/ranker_example_util_unittest.cc @@ -103,4 +103,81 @@ TEST_F(RankerExampleUtilTest, GetOneHotValue) { EXPECT_FALSE(GetOneHotValue("foo", example_, &value)); } +TEST_F(RankerExampleUtilTest, ScalarFeatureInt64Conversion) { + Feature feature; + int64_t int64_value; + + feature.set_bool_value(true); + EXPECT_TRUE(FeatureToInt64(feature, &int64_value)); + EXPECT_EQ(int64_value, 72057594037927937LL); + + feature.set_int32_value(std::numeric_limits<int32_t>::max()); + EXPECT_TRUE(FeatureToInt64(feature, &int64_value)); + EXPECT_EQ(int64_value, 216172784261267455LL); + + feature.set_int32_value(std::numeric_limits<int32_t>::lowest()); + EXPECT_TRUE(FeatureToInt64(feature, &int64_value)); + EXPECT_EQ(int64_value, 216172784261267456LL); + + feature.set_string_value("foo"); + EXPECT_TRUE(FeatureToInt64(feature, &int64_value)); + EXPECT_EQ(int64_value, 288230377439557724LL); +} + +TEST_F(RankerExampleUtilTest, FloatFeatureInt64Conversion) { + Feature feature; + int64_t int64_value; + + feature.set_float_value(std::numeric_limits<float>::epsilon()); + EXPECT_TRUE(FeatureToInt64(feature, &int64_value)); + EXPECT_EQ(int64_value, 144115188948271104LL); + + feature.set_float_value(-std::numeric_limits<float>::epsilon()); + EXPECT_TRUE(FeatureToInt64(feature, &int64_value)); + EXPECT_EQ(int64_value, 144115191095754752LL); + + feature.set_float_value(std::numeric_limits<float>::max()); + EXPECT_TRUE(FeatureToInt64(feature, &int64_value)); + EXPECT_EQ(int64_value, 144115190214950911LL); + + feature.set_float_value(std::numeric_limits<float>::lowest()); + EXPECT_TRUE(FeatureToInt64(feature, &int64_value)); + EXPECT_EQ(int64_value, 144115192362434559LL); +} + +TEST_F(RankerExampleUtilTest, StringListInt64Conversion) { + Feature feature; + int64_t int64_value; + + feature.mutable_string_list()->add_string_value(""); + feature.mutable_string_list()->add_string_value("TEST"); + EXPECT_TRUE(FeatureToInt64(feature, &int64_value, 1)); + EXPECT_EQ(int64_value, 360287974776690660LL); +} + +TEST_F(RankerExampleUtilTest, HashExampleFeatureNames) { + auto hashed_example = HashExampleFeatureNames(example_); + // Hashed example has the same number of features. + EXPECT_EQ(example_.features().size(), hashed_example.features().size()); + + // But the feature names have changed. + EXPECT_FALSE(SafeGetFeature(bool_name_, hashed_example, nullptr)); + EXPECT_FALSE(SafeGetFeature(int32_name_, hashed_example, nullptr)); + EXPECT_FALSE(SafeGetFeature(float_name_, hashed_example, nullptr)); + EXPECT_FALSE(SafeGetFeature(one_hot_name_, hashed_example, nullptr)); + + EXPECT_TRUE( + SafeGetFeature(HashFeatureName(bool_name_), hashed_example, nullptr)); + + // Values have not changed. + float float_value; + EXPECT_TRUE(GetFeatureValueAsFloat(HashFeatureName(float_name_), + hashed_example, &float_value)); + EXPECT_EQ(float_value_, float_value); + std::string string_value; + EXPECT_TRUE(GetOneHotValue(HashFeatureName(one_hot_name_), hashed_example, + &string_value)); + EXPECT_EQ(one_hot_value_, string_value); +} + } // namespace assist_ranker |