summaryrefslogtreecommitdiff
path: root/chromium/components/optimization_guide
diff options
context:
space:
mode:
Diffstat (limited to 'chromium/components/optimization_guide')
-rw-r--r--chromium/components/optimization_guide/DEPS2
-rw-r--r--chromium/components/optimization_guide/content/browser/BUILD.gn8
-rw-r--r--chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.cc102
-rw-r--r--chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.h23
-rw-r--r--chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager_unittest.cc260
-rw-r--r--chromium/components/optimization_guide/content/browser/page_content_annotations_service.cc194
-rw-r--r--chromium/components/optimization_guide/content/browser/page_content_annotations_service.h103
-rw-r--r--chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.cc101
-rw-r--r--chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.h3
-rw-r--r--chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer_unittest.cc62
-rw-r--r--chromium/components/optimization_guide/content/browser/test_page_content_annotator.cc4
-rw-r--r--chromium/components/optimization_guide/content/browser/test_page_content_annotator.h4
-rw-r--r--chromium/components/optimization_guide/core/BUILD.gn127
-rw-r--r--chromium/components/optimization_guide/core/DEPS2
-rw-r--r--chromium/components/optimization_guide/core/base_model_executor.h6
-rw-r--r--chromium/components/optimization_guide/core/base_model_executor_helpers.h29
-rw-r--r--chromium/components/optimization_guide/core/bert_model_executor.cc17
-rw-r--r--chromium/components/optimization_guide/core/bloom_filter_unittest.cc2
-rw-r--r--chromium/components/optimization_guide/core/decision_tree_prediction_model.cc237
-rw-r--r--chromium/components/optimization_guide/core/decision_tree_prediction_model.h97
-rw-r--r--chromium/components/optimization_guide/core/decision_tree_prediction_model_unittest.cc434
-rw-r--r--chromium/components/optimization_guide/core/entity_annotator_native_library.cc445
-rw-r--r--chromium/components/optimization_guide/core/entity_annotator_native_library.h143
-rw-r--r--chromium/components/optimization_guide/core/entity_annotator_native_library_unittest.cc22
-rw-r--r--chromium/components/optimization_guide/core/entity_metadata.cc1
-rw-r--r--chromium/components/optimization_guide/core/entity_metadata_provider.h1
-rw-r--r--chromium/components/optimization_guide/core/hints_fetcher.cc106
-rw-r--r--chromium/components/optimization_guide/core/hints_fetcher.h20
-rw-r--r--chromium/components/optimization_guide/core/hints_fetcher_factory.cc12
-rw-r--r--chromium/components/optimization_guide/core/hints_fetcher_factory.h12
-rw-r--r--chromium/components/optimization_guide/core/hints_fetcher_unittest.cc67
-rw-r--r--chromium/components/optimization_guide/core/hints_manager.cc130
-rw-r--r--chromium/components/optimization_guide/core/hints_manager.h13
-rw-r--r--chromium/components/optimization_guide/core/hints_manager_unittest.cc132
-rw-r--r--chromium/components/optimization_guide/core/local_page_entities_metadata_provider.cc93
-rw-r--r--chromium/components/optimization_guide/core/local_page_entities_metadata_provider.h67
-rw-r--r--chromium/components/optimization_guide/core/local_page_entities_metadata_provider_unittest.cc134
-rw-r--r--chromium/components/optimization_guide/core/model_enums.h57
-rw-r--r--chromium/components/optimization_guide/core/model_executor.h28
-rw-r--r--chromium/components/optimization_guide/core/model_handler.h74
-rw-r--r--chromium/components/optimization_guide/core/model_info.cc4
-rw-r--r--chromium/components/optimization_guide/core/model_util.cc86
-rw-r--r--chromium/components/optimization_guide/core/model_util.h37
-rw-r--r--chromium/components/optimization_guide/core/model_validator.cc14
-rw-r--r--chromium/components/optimization_guide/core/model_validator.h6
-rw-r--r--chromium/components/optimization_guide/core/model_validator_unittest.cc18
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_constants.cc3
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_constants.h3
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_enums.h45
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_features.cc194
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_features.h50
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_features_unittest.cc103
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_logger.cc32
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_logger.h45
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_permissions_util.cc13
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_permissions_util_unittest.cc38
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_store.cc272
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_store.h99
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_store_unittest.cc553
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_switches.cc33
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_switches.h14
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_switches_unittest.cc2
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_test_util.cc2
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_util.cc69
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_util.h32
-rw-r--r--chromium/components/optimization_guide/core/optimization_guide_util_unittest.cc2
-rw-r--r--chromium/components/optimization_guide/core/optimization_hints_component_update_listener.cc1
-rw-r--r--chromium/components/optimization_guide/core/page_content_annotation_job.cc49
-rw-r--r--chromium/components/optimization_guide/core/page_content_annotation_job.h13
-rw-r--r--chromium/components/optimization_guide/core/page_content_annotation_job_executor.cc2
-rw-r--r--chromium/components/optimization_guide/core/page_content_annotation_job_executor_unittest.cc2
-rw-r--r--chromium/components/optimization_guide/core/page_content_annotations_common.cc30
-rw-r--r--chromium/components/optimization_guide/core/page_content_annotations_common.h30
-rw-r--r--chromium/components/optimization_guide/core/page_entities_model_executor.h2
-rw-r--r--chromium/components/optimization_guide/core/page_entities_model_executor_impl.cc230
-rw-r--r--chromium/components/optimization_guide/core/page_entities_model_executor_impl.h115
-rw-r--r--chromium/components/optimization_guide/core/page_entities_model_executor_impl_unittest.cc268
-rw-r--r--chromium/components/optimization_guide/core/page_topics_model_executor.cc33
-rw-r--r--chromium/components/optimization_guide/core/page_topics_model_executor.h3
-rw-r--r--chromium/components/optimization_guide/core/page_topics_model_executor_unittest.cc34
-rw-r--r--chromium/components/optimization_guide/core/prediction_model.cc82
-rw-r--r--chromium/components/optimization_guide/core/prediction_model.h70
-rw-r--r--chromium/components/optimization_guide/core/prediction_model_fetcher_impl.cc12
-rw-r--r--chromium/components/optimization_guide/core/prediction_model_fetcher_impl.h8
-rw-r--r--chromium/components/optimization_guide/core/prediction_model_fetcher_unittest.cc38
-rw-r--r--chromium/components/optimization_guide/core/prediction_model_unittest.cc134
-rw-r--r--chromium/components/optimization_guide/core/store_update_data.cc69
-rw-r--r--chromium/components/optimization_guide/core/store_update_data.h16
-rw-r--r--chromium/components/optimization_guide/core/store_update_data_unittest.cc39
-rw-r--r--chromium/components/optimization_guide/core/test_model_executor.cc4
-rw-r--r--chromium/components/optimization_guide/core/test_model_executor.h5
-rw-r--r--chromium/components/optimization_guide/core/test_model_info_builder.cc2
-rw-r--r--chromium/components/optimization_guide/core/test_tflite_model_executor.cc12
-rw-r--r--chromium/components/optimization_guide/core/test_tflite_model_executor.h6
-rw-r--r--chromium/components/optimization_guide/core/tflite_model_executor.h79
-rw-r--r--chromium/components/optimization_guide/core/tflite_model_executor_unittest.cc15
-rw-r--r--chromium/components/optimization_guide/core/tflite_op_resolver.cc4
-rw-r--r--chromium/components/optimization_guide/features.gni14
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/resources/BUILD.gn71
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/resources/OWNERS1
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals.html29
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals.ts80
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals_browser_proxy.ts26
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/resources/tsconfig_base.json6
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/webui/BUILD.gn30
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/webui/DEPS7
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/webui/OWNERS4
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals.mojom24
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.cc31
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.h44
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.cc38
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.h59
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/webui/url_constants.cc12
-rw-r--r--chromium/components/optimization_guide/optimization_guide_internals/webui/url_constants.h15
-rw-r--r--chromium/components/optimization_guide/proto/BUILD.gn1
-rw-r--r--chromium/components/optimization_guide/proto/hint_cache.proto2
-rw-r--r--chromium/components/optimization_guide/proto/models.proto37
-rw-r--r--chromium/components/optimization_guide/proto/page_entities_metadata.proto11
-rw-r--r--chromium/components/optimization_guide/proto/page_entities_model_metadata.proto25
-rw-r--r--chromium/components/optimization_guide/proto/page_topics_model_metadata.proto19
120 files changed, 4186 insertions, 2950 deletions
diff --git a/chromium/components/optimization_guide/DEPS b/chromium/components/optimization_guide/DEPS
index c22edf8e983..ba9c08904b4 100644
--- a/chromium/components/optimization_guide/DEPS
+++ b/chromium/components/optimization_guide/DEPS
@@ -1,6 +1,4 @@
include_rules = [
- "+components/data_reduction_proxy/core/browser",
- "+components/data_reduction_proxy/core/common",
"+components/leveldb_proto",
"+components/prefs",
"+components/sync_preferences",
diff --git a/chromium/components/optimization_guide/content/browser/BUILD.gn b/chromium/components/optimization_guide/content/browser/BUILD.gn
index 0d26cc4b455..09a22e1ea29 100644
--- a/chromium/components/optimization_guide/content/browser/BUILD.gn
+++ b/chromium/components/optimization_guide/content/browser/BUILD.gn
@@ -54,10 +54,6 @@ static_library("browser") {
"//third_party/tflite_support",
"//third_party/tflite_support:tflite_support_proto",
]
-
- if (build_with_internal_optimization_guide) {
- deps = [ "//components/optimization_guide/internal" ]
- }
}
}
@@ -83,7 +79,9 @@ source_set("unit_tests") {
"page_text_dump_result_unittest.cc",
"page_text_observer_unittest.cc",
]
- if (build_with_tflite_lib) {
+
+ # crbug.com/1279884 Flaky on CrOS
+ if (!is_chromeos && build_with_tflite_lib) {
sources += [ "page_content_annotations_model_manager_unittest.cc" ]
}
deps = [
diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.cc b/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.cc
index 0deb1b9d4aa..52339bd11a7 100644
--- a/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.cc
+++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.cc
@@ -4,6 +4,7 @@
#include "components/optimization_guide/content/browser/page_content_annotations_model_manager.h"
+#include "base/metrics/histogram_functions.h"
#include "base/metrics/histogram_macros_local.h"
#include "base/strings/string_number_conversions.h"
#include "base/task/sequenced_task_runner.h"
@@ -17,7 +18,7 @@
#include "content/public/browser/browser_thread.h"
#if BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
-#include "components/optimization_guide/internal/page_entities_model_executor_impl.h"
+#include "components/optimization_guide/core/page_entities_model_executor_impl.h"
#endif
namespace optimization_guide {
@@ -47,34 +48,23 @@ GetOrCreateCurrentContentModelAnnotations(
return std::make_unique<history::VisitContentModelAnnotations>();
}
-void PretendToExecuteJob(base::OnceClosure callback,
- std::unique_ptr<PageContentAnnotationJob> job) {
- while (absl::optional<std::string> input = job->GetNextInput()) {
- job->PostNewResult(
- BatchAnnotationResult::CreatePageTopicsResult(*input, absl::nullopt));
- }
- // Note to future self: The ordering of these callbacks being run will be
- // important once actually being run on an executor.
- job->OnComplete();
- std::move(callback).Run();
-}
-
} // namespace
PageContentAnnotationsModelManager::PageContentAnnotationsModelManager(
const std::string& application_locale,
- OptimizationGuideModelProvider* optimization_guide_model_provider) {
- for (auto opt_target :
- features::GetPageContentModelsToExecute(application_locale)) {
- if (opt_target == proto::OPTIMIZATION_TARGET_PAGE_TOPICS) {
- SetUpPageTopicsModel(optimization_guide_model_provider);
- ordered_models_to_execute_.push_back(opt_target);
- } else if (opt_target == proto::OPTIMIZATION_TARGET_PAGE_ENTITIES) {
- SetUpPageEntitiesModel(optimization_guide_model_provider);
- ordered_models_to_execute_.push_back(opt_target);
- } else {
- // TODO(crbug/1228790): Add histogram for if this happens.
- }
+ OptimizationGuideModelProvider* optimization_guide_model_provider)
+ : optimization_guide_model_provider_(optimization_guide_model_provider) {
+ if (features::ShouldExecutePageVisibilityModelOnPageContent(
+ application_locale)) {
+ SetUpPageTopicsModel(optimization_guide_model_provider);
+ ordered_models_to_execute_.push_back(
+ proto::OPTIMIZATION_TARGET_PAGE_TOPICS);
+ }
+ if (features::ShouldExecutePageEntitiesModelOnPageContent(
+ application_locale)) {
+ SetUpPageEntitiesModel(optimization_guide_model_provider);
+ ordered_models_to_execute_.push_back(
+ proto::OPTIMIZATION_TARGET_PAGE_ENTITIES);
}
}
@@ -234,6 +224,9 @@ void PageContentAnnotationsModelManager::SetUpPageTopicsV2Model(
if (!features::PageTopicsBatchAnnotationsEnabled())
return;
+ if (on_demand_page_topics_model_executor_)
+ return;
+
on_demand_page_topics_model_executor_ =
std::make_unique<PageTopicsModelExecutor>(
optimization_guide_model_provider,
@@ -247,6 +240,9 @@ void PageContentAnnotationsModelManager::SetUpPageVisibilityModel(
if (!features::PageVisibilityBatchAnnotationsEnabled())
return;
+ if (page_visibility_model_executor_)
+ return;
+
page_visibility_model_executor_ =
std::make_unique<PageVisibilityModelExecutor>(
optimization_guide_model_provider,
@@ -477,24 +473,32 @@ void PageContentAnnotationsModelManager::
out_content_annotations->categories = final_categories;
}
-void PageContentAnnotationsModelManager::NotifyWhenModelAvailable(
+void PageContentAnnotationsModelManager::RequestAndNotifyWhenModelAvailable(
AnnotationType type,
base::OnceCallback<void(bool)> callback) {
- if (type == AnnotationType::kPageTopics &&
- on_demand_page_topics_model_executor_) {
- on_demand_page_topics_model_executor_->AddOnModelUpdatedCallback(
- base::BindOnce(std::move(callback), true));
- return;
+ if (type == AnnotationType::kPageTopics) {
+ // No-op if the executor is already setup.
+ SetUpPageTopicsV2Model(optimization_guide_model_provider_);
+
+ if (on_demand_page_topics_model_executor_) {
+ on_demand_page_topics_model_executor_->AddOnModelUpdatedCallback(
+ base::BindOnce(std::move(callback), true));
+ return;
+ }
}
- if (type == AnnotationType::kContentVisibility &&
- page_visibility_model_executor_) {
- page_visibility_model_executor_->AddOnModelUpdatedCallback(
- base::BindOnce(std::move(callback), true));
- return;
+ if (type == AnnotationType::kContentVisibility) {
+ // No-op if the executor is already setup.
+ SetUpPageVisibilityModel(optimization_guide_model_provider_);
+
+ if (page_visibility_model_executor_) {
+ page_visibility_model_executor_->AddOnModelUpdatedCallback(
+ base::BindOnce(std::move(callback), true));
+ return;
+ }
}
- // TODO(crbug/1249632): Add support for page entities.
+ // TODO(crbug/1278828): Add support for page entities.
std::move(callback).Run(false);
}
@@ -506,6 +510,11 @@ PageContentAnnotationsModelManager::GetModelInfoForType(
on_demand_page_topics_model_executor_) {
return on_demand_page_topics_model_executor_->GetModelInfo();
}
+ if (type == AnnotationType::kContentVisibility &&
+ page_visibility_model_executor_) {
+ return page_visibility_model_executor_->GetModelInfo();
+ }
+ // TODO(crbug/1278828): Add support for page entities.
return absl::nullopt;
}
@@ -513,6 +522,11 @@ void PageContentAnnotationsModelManager::Annotate(
BatchAnnotationCallback callback,
const std::vector<std::string>& inputs,
AnnotationType annotation_type) {
+ base::UmaHistogramCounts100(
+ "OptimizationGuide.PageContentAnnotations.BatchRequestedSize." +
+ AnnotationTypeToString(annotation_type),
+ inputs.size());
+
std::unique_ptr<PageContentAnnotationJob> job =
std::make_unique<PageContentAnnotationJob>(std::move(callback), inputs,
annotation_type);
@@ -592,11 +606,15 @@ void PageContentAnnotationsModelManager::MaybeStartNextAnnotationJob() {
return;
}
- // TODO(crbug/1249632): Actually run the model instead.
- content::GetUIThreadTaskRunner({})->PostTask(
- FROM_HERE,
- base::BindOnce(&PretendToExecuteJob, std::move(on_job_complete_callback),
- std::move(job)));
+ // TODO(crbug/1278828): Add support for page entities.
+ if (job->type() == AnnotationType::kPageEntities) {
+ job->FillWithNullOutputs();
+ job->OnComplete();
+ job.reset();
+ std::move(on_job_complete_callback).Run();
+ return;
+ }
+ NOTREACHED();
}
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.h b/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.h
index 44ce32c15d3..dc43f852d23 100644
--- a/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.h
+++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.h
@@ -5,6 +5,7 @@
#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CONTENT_BROWSER_PAGE_CONTENT_ANNOTATIONS_MODEL_MANAGER_H_
#define COMPONENTS_OPTIMIZATION_GUIDE_CONTENT_BROWSER_PAGE_CONTENT_ANNOTATIONS_MODEL_MANAGER_H_
+#include "base/memory/raw_ptr.h"
#include "components/history/core/browser/url_row.h"
#include "components/optimization_guide/content/browser/page_content_annotator.h"
#include "components/optimization_guide/core/bert_model_handler.h"
@@ -49,6 +50,7 @@ class PageContentAnnotationsModelManager : public PageContentAnnotator {
// This will execute all supported models of the PageContentAnnotationsService
// feature and is only used by the History service code path. See the below
// |Annotate| for the publicly available Annotation code path.
+ // TODO(crbug/1278833): Remove this.
void Annotate(const std::string& text, PageContentAnnotatedCallback callback);
// PageContentAnnotator:
@@ -56,16 +58,16 @@ class PageContentAnnotationsModelManager : public PageContentAnnotator {
const std::vector<std::string>& inputs,
AnnotationType annotation_type) override;
- // Runs |callback| with true when the model that powers |BatchAnnotate| for
- // the given annotation type is ready to execute. If the model is ready now,
- // the callback is run immediately. If the model will never become ready, due
- // to feature flags for example, the callback run with false.
- void NotifyWhenModelAvailable(AnnotationType type,
- base::OnceCallback<void(bool)> callback);
+ // Requests that the given model for |type| be loaded in the background and
+ // then runs |callback| with true when the model is ready to execute. If the
+ // model is ready now, the callback is run immediately. If the model file will
+ // never be available, the callback is run with false.
+ void RequestAndNotifyWhenModelAvailable(
+ AnnotationType type,
+ base::OnceCallback<void(bool)> callback);
// Returns the model info associated with the given AnnotationType, if it is
// available and loaded.
- // TODO(crbug/1249632): Add support for more than just page topics.
absl::optional<ModelInfo> GetModelInfoForType(AnnotationType type) const;
// Returns the version of the page topics model that is currently being used
@@ -92,7 +94,7 @@ class PageContentAnnotationsModelManager : public PageContentAnnotator {
// All publicly posted jobs will have this priority level.
kNormal = 1,
- // TODO(crbug/1249632): Add a kHigh value for internal jobs.
+ // TODO(crbug/1278833): Add a kHigh value for internal jobs.
// Always keep this last and as the highest priority + 1. This value is
// passed to the priority queue ctor as "how many level of priorities are
@@ -215,7 +217,7 @@ class PageContentAnnotationsModelManager : public PageContentAnnotator {
// Runs the next job in |job_queue_| if there is any.
void MaybeStartNextAnnotationJob();
- // Called when a job finishes executing.
+ // Called when a |job| finishes executing, just before it is deleted.
void OnJobExecutionComplete();
// The model executor responsible for executing the page topics model.
@@ -254,6 +256,9 @@ class PageContentAnnotationsModelManager : public PageContentAnnotator {
// The current state of the running job, if any.
JobExecutionState job_state_ = JobExecutionState::kIdle;
+ // The model provider, not owned.
+ raw_ptr<OptimizationGuideModelProvider> optimization_guide_model_provider_;
+
base::WeakPtrFactory<PageContentAnnotationsModelManager> weak_ptr_factory_{
this};
};
diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager_unittest.cc b/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager_unittest.cc
index caae58be379..10237a060e6 100644
--- a/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager_unittest.cc
+++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager_unittest.cc
@@ -11,6 +11,7 @@
#include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/scoped_run_loop_timeout.h"
+#include "build/build_config.h"
#include "components/optimization_guide/core/execution_status.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/page_entities_model_executor.h"
@@ -117,9 +118,8 @@ class FakePageEntitiesModelExecutor : public PageEntitiesModelExecutor {
class PageContentAnnotationsModelManagerTest : public testing::Test {
public:
PageContentAnnotationsModelManagerTest() {
- scoped_feature_list_.InitAndEnableFeatureWithParameters(
- features::kPageContentAnnotations,
- {{"models_to_execute_v2", "OPTIMIZATION_TARGET_PAGE_TOPICS"}});
+ scoped_feature_list_.InitAndEnableFeature(
+ features::kPageVisibilityPageContentAnnotations);
}
~PageContentAnnotationsModelManagerTest() override = default;
@@ -172,7 +172,8 @@ class PageContentAnnotationsModelManagerTest : public testing::Test {
}
void SetupPageTopicsV2ModelExecutor() {
- model_manager()->SetUpPageTopicsV2Model(model_observer_tracker());
+ model_manager()->RequestAndNotifyWhenModelAvailable(
+ AnnotationType::kPageTopics, base::DoNothing());
// If the feature flag is disabled, the executor won't have been created so
// skip everything else.
if (!model_manager()->on_demand_page_topics_model_executor_)
@@ -204,7 +205,8 @@ class PageContentAnnotationsModelManagerTest : public testing::Test {
void SendPageVisibilityModelToExecutor(
const absl::optional<proto::Any>& model_metadata) {
- model_manager()->SetUpPageVisibilityModel(model_observer_tracker());
+ model_manager()->RequestAndNotifyWhenModelAvailable(
+ AnnotationType::kContentVisibility, base::DoNothing());
// If the feature flag is disabled, the executor won't have been created so
// skip everything else.
if (!model_manager()->page_visibility_model_executor_)
@@ -562,7 +564,13 @@ TEST_F(PageContentAnnotationsModelManagerTest,
EXPECT_FALSE(GetMetadataForEntityId("someid").has_value());
}
-TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageTopics) {
+// TODO(crbug.com/1286473): Flaky on Chrome OS.
+#if BUILDFLAG(IS_CHROMEOS)
+#define MAYBE_BatchAnnotate_PageTopics DISABLED_BatchAnnotate_PageTopics
+#else
+#define MAYBE_BatchAnnotate_PageTopics BatchAnnotate_PageTopics
+#endif
+TEST_F(PageContentAnnotationsModelManagerTest, MAYBE_BatchAnnotate_PageTopics) {
SetupPageTopicsV2ModelExecutor();
// Running the actual model can take a while.
@@ -591,6 +599,18 @@ TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageTopics) {
"OptimizationGuide.ModelExecutor.ExecutionStatus.PageTopicsV2",
ExecutionStatus::kSuccess, 1);
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.PageContentAnnotations.BatchRequestedSize.PageTopics",
+ 1, 1);
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.PageContentAnnotations.BatchSuccess.PageTopics", true,
+ 1);
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.PageContentAnnotations.JobExecutionTime.PageTopics",
+ 1);
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.PageContentAnnotations.JobScheduleTime.PageTopics", 1);
+
EXPECT_TRUE(model_observer_tracker()->DidRegisterForTarget(
proto::OptimizationTarget::OPTIMIZATION_TARGET_PAGE_TOPICS_V2, nullptr));
@@ -628,6 +648,17 @@ TEST_F(PageContentAnnotationsModelManagerTest,
base::RunLoop().RunUntilIdle();
histogram_tester.ExpectTotalCount(
"OptimizationGuide.ModelExecutor.ExecutionStatus.PageTopicsV2", 0);
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.PageContentAnnotations.BatchRequestedSize.PageTopics",
+ 1, 1);
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.PageContentAnnotations.BatchSuccess.PageTopics", false,
+ 1);
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.PageContentAnnotations.JobExecutionTime.PageTopics",
+ 1);
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.PageContentAnnotations.JobScheduleTime.PageTopics", 1);
EXPECT_FALSE(model_observer_tracker()->DidRegisterForTarget(
proto::OptimizationTarget::OPTIMIZATION_TARGET_PAGE_TOPICS_V2, nullptr));
@@ -641,6 +672,7 @@ TEST_F(PageContentAnnotationsModelManagerTest,
}
TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageEntities) {
+ base::HistogramTester histogram_tester;
base::RunLoop run_loop;
std::vector<BatchAnnotationResult> result;
BatchAnnotationCallback callback = base::BindOnce(
@@ -654,10 +686,22 @@ TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageEntities) {
model_manager()->Annotate(std::move(callback), {"input"},
AnnotationType::kPageEntities);
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.PageContentAnnotations.BatchRequestedSize."
+ "PageEntities",
+ 1, 1);
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.PageContentAnnotations.BatchSuccess.PageEntities",
+ false, 1);
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.PageContentAnnotations.JobExecutionTime.PageEntities",
+ 1);
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.PageContentAnnotations.JobScheduleTime.PageEntities",
+ 1);
+
run_loop.Run();
- // TODO(crbug/1249632): Check the corresponding output once the model is being
- // run.
ASSERT_EQ(result.size(), 1U);
EXPECT_EQ(result[0].input(), "input");
EXPECT_EQ(result[0].topics(), absl::nullopt);
@@ -665,7 +709,15 @@ TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageEntities) {
EXPECT_EQ(result[0].visibility_score(), absl::nullopt);
}
-TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageVisibility) {
+// TODO(crbug.com/1286473): Flaky on Chrome OS.
+#if BUILDFLAG(IS_CHROMEOS)
+#define MAYBE_BatchAnnotate_PageVisibility DISABLED_BatchAnnotate_PageVisibility
+#else
+#define MAYBE_BatchAnnotate_PageVisibility BatchAnnotate_PageVisibility
+#endif
+TEST_F(PageContentAnnotationsModelManagerTest,
+ MAYBE_BatchAnnotate_PageVisibility) {
+ base::HistogramTester histogram_tester;
proto::Any any_metadata;
any_metadata.set_type_url(
"type.googleapis.com/com.foo.PageTopicsModelMetadata");
@@ -697,6 +749,21 @@ TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageVisibility) {
EXPECT_TRUE(model_observer_tracker()->DidRegisterForTarget(
proto::OptimizationTarget::OPTIMIZATION_TARGET_PAGE_VISIBILITY, nullptr));
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.PageContentAnnotations.BatchRequestedSize."
+ "ContentVisibility",
+ 1, 1);
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.PageContentAnnotations.BatchSuccess.ContentVisibility",
+ true, 1);
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.PageContentAnnotations.JobExecutionTime."
+ "ContentVisibility",
+ 1);
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.PageContentAnnotations.JobScheduleTime."
+ "ContentVisibility",
+ 1);
ASSERT_EQ(result.size(), 1U);
EXPECT_EQ(result[0].input(), "input");
@@ -707,6 +774,7 @@ TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageVisibility) {
TEST_F(PageContentAnnotationsModelManagerTest,
BatchAnnotate_PageVisibilityDisabled) {
+ base::HistogramTester histogram_tester;
base::test::ScopedFeatureList scoped_feature_list;
scoped_feature_list.InitAndDisableFeature(
features::kPageVisibilityBatchAnnotations);
@@ -739,6 +807,21 @@ TEST_F(PageContentAnnotationsModelManagerTest,
EXPECT_FALSE(model_observer_tracker()->DidRegisterForTarget(
proto::OptimizationTarget::OPTIMIZATION_TARGET_PAGE_VISIBILITY, nullptr));
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.PageContentAnnotations.BatchRequestedSize."
+ "ContentVisibility",
+ 1, 1);
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.PageContentAnnotations.BatchSuccess.ContentVisibility",
+ false, 1);
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.PageContentAnnotations.JobExecutionTime."
+ "ContentVisibility",
+ 1);
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.PageContentAnnotations.JobScheduleTime."
+ "ContentVisibility",
+ 1);
ASSERT_EQ(result.size(), 1U);
EXPECT_EQ(result[0].input(), "input");
@@ -747,7 +830,14 @@ TEST_F(PageContentAnnotationsModelManagerTest,
EXPECT_EQ(result[0].visibility_score(), absl::nullopt);
}
-TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_CalledTwice) {
+// TODO(crbug.com/1286473): Flaky on Chrome OS.
+#if BUILDFLAG(IS_CHROMEOS)
+#define MAYBE_BatchAnnotate_CalledTwice DISABLED_BatchAnnotate_CalledTwice
+#else
+#define MAYBE_BatchAnnotate_CalledTwice BatchAnnotate_CalledTwice
+#endif
+TEST_F(PageContentAnnotationsModelManagerTest,
+ MAYBE_BatchAnnotate_CalledTwice) {
SetupPageTopicsV2ModelExecutor();
base::HistogramTester histogram_tester;
@@ -790,6 +880,18 @@ TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_CalledTwice) {
EXPECT_TRUE(model_observer_tracker()->DidRegisterForTarget(
proto::OptimizationTarget::OPTIMIZATION_TARGET_PAGE_TOPICS_V2, nullptr));
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.PageContentAnnotations.BatchRequestedSize.PageTopics",
+ 1, 2);
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.PageContentAnnotations.BatchSuccess.PageTopics", true,
+ 2);
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.PageContentAnnotations.JobExecutionTime.PageTopics",
+ 2);
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.PageContentAnnotations.JobScheduleTime.PageTopics", 2);
+
// The model should have only been loaded once and then used for both jobs.
histogram_tester.ExpectUniqueSample(
"OptimizationGuide.ModelExecutor.ModelAvailableToLoad.PageTopicsV2", true,
@@ -816,6 +918,8 @@ TEST_F(PageContentAnnotationsModelManagerTest, GetModelInfoForType) {
model_manager()->GetModelInfoForType(AnnotationType::kContentVisibility));
SetupPageTopicsV2ModelExecutor();
+ EXPECT_TRUE(
+ model_manager()->GetModelInfoForType(AnnotationType::kPageTopics));
proto::Any any_metadata;
any_metadata.set_type_url(
@@ -829,59 +933,61 @@ TEST_F(PageContentAnnotationsModelManagerTest, GetModelInfoForType) {
SendPageVisibilityModelToExecutor(any_metadata);
EXPECT_TRUE(
- model_manager()->GetModelInfoForType(AnnotationType::kPageTopics));
- EXPECT_FALSE(
model_manager()->GetModelInfoForType(AnnotationType::kContentVisibility));
}
TEST_F(PageContentAnnotationsModelManagerTest,
- NotifyWhenModelAvailable_NotAvailable) {
- absl::optional<bool> topics_callback_success;
- absl::optional<bool> visibility_callback_success;
+ NotifyWhenModelAvailable_TopicsOnly) {
+ SetupPageTopicsV2ModelExecutor();
- model_manager()->NotifyWhenModelAvailable(
+ base::RunLoop topics_run_loop;
+ bool topics_callback_success = false;
+
+ model_manager()->RequestAndNotifyWhenModelAvailable(
AnnotationType::kPageTopics,
- base::BindOnce([](absl::optional<bool>* out_success,
- bool success) { *out_success = success; },
- &topics_callback_success));
- model_manager()->NotifyWhenModelAvailable(
- AnnotationType::kContentVisibility,
- base::BindOnce([](absl::optional<bool>* out_success,
- bool success) { *out_success = success; },
- &visibility_callback_success));
-
- ASSERT_TRUE(topics_callback_success);
- ASSERT_TRUE(visibility_callback_success);
- EXPECT_FALSE(*topics_callback_success);
- EXPECT_FALSE(*visibility_callback_success);
+ base::BindOnce(
+ [](base::RunLoop* run_loop, bool* out_success, bool success) {
+ *out_success = success;
+ run_loop->Quit();
+ },
+ &topics_run_loop, &topics_callback_success));
+
+ topics_run_loop.Run();
+
+ EXPECT_TRUE(topics_callback_success);
}
TEST_F(PageContentAnnotationsModelManagerTest,
- NotifyWhenModelAvailable_TopicsOnly) {
- SetupPageTopicsV2ModelExecutor();
+ NotifyWhenModelAvailable_VisibilityOnly) {
+ proto::Any any_metadata;
+ any_metadata.set_type_url(
+ "type.googleapis.com/com.foo.PageTopicsModelMetadata");
+ proto::PageTopicsModelMetadata page_topics_model_metadata;
+ page_topics_model_metadata.set_version(123);
+ page_topics_model_metadata.mutable_output_postprocessing_params()
+ ->mutable_visibility_params()
+ ->set_category_name("DO NOT EVALUATE");
+ page_topics_model_metadata.SerializeToString(any_metadata.mutable_value());
+ SendPageVisibilityModelToExecutor(any_metadata);
- absl::optional<bool> topics_callback_success;
- absl::optional<bool> visibility_callback_success;
+ base::RunLoop visibility_run_loop;
+ bool visibility_callback_success = false;
- model_manager()->NotifyWhenModelAvailable(
- AnnotationType::kPageTopics,
- base::BindOnce([](absl::optional<bool>* out_success,
- bool success) { *out_success = success; },
- &topics_callback_success));
- model_manager()->NotifyWhenModelAvailable(
+ model_manager()->RequestAndNotifyWhenModelAvailable(
AnnotationType::kContentVisibility,
- base::BindOnce([](absl::optional<bool>* out_success,
- bool success) { *out_success = success; },
- &visibility_callback_success));
-
- ASSERT_TRUE(topics_callback_success);
- ASSERT_TRUE(visibility_callback_success);
- EXPECT_TRUE(*topics_callback_success);
- EXPECT_FALSE(*visibility_callback_success);
+ base::BindOnce(
+ [](base::RunLoop* run_loop, bool* out_success, bool success) {
+ *out_success = success;
+ run_loop->Quit();
+ },
+ &visibility_run_loop, &visibility_callback_success));
+
+ visibility_run_loop.Run();
+
+ EXPECT_TRUE(visibility_callback_success);
}
-TEST_F(PageContentAnnotationsModelManagerTest,
- NotifyWhenModelAvailable_VisibilityOnly) {
+TEST_F(PageContentAnnotationsModelManagerTest, NotifyWhenModelAvailable_Both) {
proto::Any any_metadata;
any_metadata.set_type_url(
"type.googleapis.com/com.foo.PageTopicsModelMetadata");
@@ -893,33 +999,44 @@ TEST_F(PageContentAnnotationsModelManagerTest,
page_topics_model_metadata.SerializeToString(any_metadata.mutable_value());
SendPageVisibilityModelToExecutor(any_metadata);
- absl::optional<bool> topics_callback_success;
- absl::optional<bool> visibility_callback_success;
+ SetupPageTopicsV2ModelExecutor();
+
+ base::RunLoop topics_run_loop;
+ base::RunLoop visibility_run_loop;
+ bool topics_callback_success = false;
+ bool visibility_callback_success = false;
- model_manager()->NotifyWhenModelAvailable(
+ model_manager()->RequestAndNotifyWhenModelAvailable(
AnnotationType::kPageTopics,
- base::BindOnce([](absl::optional<bool>* out_success,
- bool success) { *out_success = success; },
- &topics_callback_success));
- model_manager()->NotifyWhenModelAvailable(
+ base::BindOnce(
+ [](base::RunLoop* run_loop, bool* out_success, bool success) {
+ *out_success = success;
+ run_loop->Quit();
+ },
+ &topics_run_loop, &topics_callback_success));
+ model_manager()->RequestAndNotifyWhenModelAvailable(
AnnotationType::kContentVisibility,
- base::BindOnce([](absl::optional<bool>* out_success,
- bool success) { *out_success = success; },
- &visibility_callback_success));
-
- ASSERT_TRUE(topics_callback_success);
- ASSERT_TRUE(visibility_callback_success);
- EXPECT_FALSE(*topics_callback_success);
- EXPECT_TRUE(*visibility_callback_success);
+ base::BindOnce(
+ [](base::RunLoop* run_loop, bool* out_success, bool success) {
+ *out_success = success;
+ run_loop->Quit();
+ },
+ &visibility_run_loop, &visibility_callback_success));
+
+ topics_run_loop.Run();
+ visibility_run_loop.Run();
+
+ EXPECT_TRUE(topics_callback_success);
+ EXPECT_TRUE(visibility_callback_success);
}
class PageContentAnnotationsModelManagerEntitiesOnlyTest
: public PageContentAnnotationsModelManagerTest {
public:
PageContentAnnotationsModelManagerEntitiesOnlyTest() {
- scoped_feature_list_.InitAndEnableFeatureWithParameters(
- features::kPageContentAnnotations,
- {{"models_to_execute_v2", "OPTIMIZATION_TARGET_PAGE_ENTITIES"}});
+ scoped_feature_list_.InitWithFeatures(
+ {features::kPageEntitiesPageContentAnnotations},
+ {features::kPageVisibilityPageContentAnnotations});
}
private:
@@ -999,11 +1116,10 @@ class PageContentAnnotationsModelManagerMultipleModelsTest
: public PageContentAnnotationsModelManagerTest {
public:
PageContentAnnotationsModelManagerMultipleModelsTest() {
- scoped_feature_list_.InitAndEnableFeatureWithParameters(
- features::kPageContentAnnotations,
- {{"models_to_execute_v2",
- "OPTIMIZATION_TARGET_PAGE_ENTITIES,OPTIMIZATION_TARGET_PAGE_"
- "TOPICS"}});
+ scoped_feature_list_.InitWithFeatures(
+ {features::kPageEntitiesPageContentAnnotations,
+ features::kPageVisibilityPageContentAnnotations},
+ {});
}
private:
diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_service.cc b/chromium/components/optimization_guide/content/browser/page_content_annotations_service.cc
index 8d1421253f2..a88b7181f85 100644
--- a/chromium/components/optimization_guide/content/browser/page_content_annotations_service.cc
+++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_service.cc
@@ -4,14 +4,22 @@
#include "components/optimization_guide/content/browser/page_content_annotations_service.h"
+#include "base/callback_helpers.h"
#include "base/metrics/histogram_functions.h"
+#include "base/metrics/histogram_macros_local.h"
+#include "base/rand_util.h"
#include "base/strings/string_util.h"
#include "base/strings/utf_string_conversions.h"
+#include "base/time/default_tick_clock.h"
+#include "base/timer/timer.h"
#include "components/history/core/browser/history_service.h"
+#include "components/leveldb_proto/public/proto_database_provider.h"
+#include "components/optimization_guide/core/local_page_entities_metadata_provider.h"
#include "components/optimization_guide/core/noisy_metrics_recorder.h"
#include "components/optimization_guide/core/optimization_guide_enums.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_model_provider.h"
+#include "components/optimization_guide/core/optimization_guide_switches.h"
#include "content/public/browser/navigation_entry.h"
#include "content/public/browser/web_contents.h"
#include "services/metrics/public/cpp/metrics_utils.h"
@@ -71,14 +79,27 @@ void MaybeRecordVisibilityUKM(
}
#endif /* BUILDFLAG(BUILD_WITH_TFLITE_LIB) */
+const char kDummyTextBlob[] =
+ "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod "
+ "tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim "
+ "veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea "
+ "commodo consequat. Duis aute irure dolor in reprehenderit in voluptate "
+ "velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint "
+ "occaecat cupidatat non proident, sunt in culpa qui officia deserunt "
+ "mollit anim id est laborum";
+
} // namespace
PageContentAnnotationsService::PageContentAnnotationsService(
const std::string& application_locale,
OptimizationGuideModelProvider* optimization_guide_model_provider,
- history::HistoryService* history_service)
+ history::HistoryService* history_service,
+ leveldb_proto::ProtoDatabaseProvider* database_provider,
+ const base::FilePath& database_dir,
+ scoped_refptr<base::SequencedTaskRunner> background_task_runner)
: last_annotated_history_visits_(
- features::MaxContentAnnotationRequestsCached()) {
+ features::MaxContentAnnotationRequestsCached()),
+ annotated_text_cache_(features::MaxVisitAnnotationCacheSize()) {
DCHECK(optimization_guide_model_provider);
DCHECK(history_service);
history_service_ = history_service;
@@ -87,12 +108,28 @@ PageContentAnnotationsService::PageContentAnnotationsService(
application_locale, optimization_guide_model_provider);
annotator_ = model_manager_.get();
#endif
+
+ if (features::UseLocalPageEntitiesMetadataProvider()) {
+ local_page_entities_metadata_provider_ =
+ std::make_unique<LocalPageEntitiesMetadataProvider>();
+ local_page_entities_metadata_provider_->Initialize(
+ database_provider, database_dir, background_task_runner);
+ }
+
+ if (features::BatchAnnotationsValidationEnabled()) {
+ validation_timer_ = std::make_unique<base::OneShotTimer>(
+ base::DefaultTickClock::GetInstance());
+ validation_timer_->Start(
+ FROM_HERE, features::BatchAnnotationValidationStartupDelay(),
+ base::BindRepeating(
+ &PageContentAnnotationsService::RunBatchAnnotationValidation,
+ weak_ptr_factory_.GetWeakPtr()));
+ }
}
PageContentAnnotationsService::~PageContentAnnotationsService() = default;
-void PageContentAnnotationsService::Annotate(const HistoryVisit& visit,
- const std::string& text) {
+void PageContentAnnotationsService::Annotate(const HistoryVisit& visit) {
if (last_annotated_history_visits_.Peek(visit) !=
last_annotated_history_visits_.end()) {
// We have already been requested to annotate this visit, so don't submit
@@ -102,13 +139,88 @@ void PageContentAnnotationsService::Annotate(const HistoryVisit& visit,
last_annotated_history_visits_.Put(visit, true);
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
- model_manager_->Annotate(
- text,
- base::BindOnce(&PageContentAnnotationsService::OnPageContentAnnotated,
- weak_ptr_factory_.GetWeakPtr(), visit));
+ if (!visit.text_to_annotate)
+ return;
+ // Used for testing.
+ LOCAL_HISTOGRAM_BOOLEAN(
+ "PageContentAnnotations.AnnotateVisit.AnnotationRequested", true);
+
+ auto it = annotated_text_cache_.Peek(*visit.text_to_annotate);
+ if (it != annotated_text_cache_.end()) {
+ // We have annotations the text for this visit, so return that immediately
+ // rather than re-executing the model.
+ //
+ // TODO(crbug.com/1291275): If the model was updated, the cached value could
+ // be stale so we should invalidate the cache on model updates.
+ OnPageContentAnnotated(visit, it->second);
+ base::UmaHistogramBoolean(
+ "OptimizationGuide.PageContentAnnotations.AnnotateVisitResultCached",
+ true);
+ return;
+ }
+ visits_to_annotate_.emplace_back(visit);
+ base::UmaHistogramBoolean(
+ "OptimizationGuide.PageContentAnnotations.AnnotateVisitResultCached",
+ false);
+ if (visits_to_annotate_.size() >= features::AnnotateVisitBatchSize()) {
+ if (current_visit_annotation_batch_.empty()) {
+ // Used for testing.
+ LOCAL_HISTOGRAM_BOOLEAN(
+ "PageContentAnnotations.AnnotateVisit.BatchAnnotationStarted", true);
+ current_visit_annotation_batch_ = std::move(visits_to_annotate_);
+ AnnotateVisitBatch();
+ return;
+ }
+ // The queue is full and an batch annotation is actively being done so
+ // we will remove the "oldest" visit.
+ visits_to_annotate_.erase(visits_to_annotate_.begin());
+ // Used for testing.
+ LOCAL_HISTOGRAM_BOOLEAN(
+ "PageContentAnnotations.AnnotateVisit.QueueFullVisitDropped", true);
+ }
+ // Used for testing.
+ LOCAL_HISTOGRAM_BOOLEAN(
+ "PageContentAnnotations.AnnotateVisit.AnnotationRequestQueued", true);
#endif
}
+#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
+void PageContentAnnotationsService::AnnotateVisitBatch() {
+ DCHECK(!current_visit_annotation_batch_.empty());
+
+ if (switches::StopHistoryVisitBatchAnnotateForTesting()) {
+ // Code beyond this is tested in multiple places. This just ensures the
+ // calls up to this point can be more easily configured.
+ return;
+ }
+
+ if (current_visit_annotation_batch_.empty()) {
+ return;
+ }
+ auto visit = current_visit_annotation_batch_.back();
+ DCHECK(visit.text_to_annotate);
+ if (visit.text_to_annotate) {
+ model_manager_->Annotate(
+ *(visit.text_to_annotate),
+ base::BindOnce(&PageContentAnnotationsService::OnBatchVisitAnnotated,
+ weak_ptr_factory_.GetWeakPtr(), visit));
+ }
+}
+
+void PageContentAnnotationsService::OnBatchVisitAnnotated(
+ const HistoryVisit& visit,
+ const absl::optional<history::VisitContentModelAnnotations>&
+ content_annotations) {
+ OnPageContentAnnotated(visit, content_annotations);
+ DCHECK_EQ(visit.navigation_id,
+ current_visit_annotation_batch_.back().navigation_id);
+ current_visit_annotation_batch_.pop_back();
+ if (!current_visit_annotation_batch_.empty()) {
+ AnnotateVisitBatch();
+ }
+}
+#endif
+
void PageContentAnnotationsService::OverridePageContentAnnotatorForTesting(
PageContentAnnotator* annotator) {
annotator_ = annotator;
@@ -135,17 +247,27 @@ absl::optional<ModelInfo> PageContentAnnotationsService::GetModelInfoForType(
#endif
}
-void PageContentAnnotationsService::NotifyWhenModelAvailable(
+void PageContentAnnotationsService::RequestAndNotifyWhenModelAvailable(
AnnotationType type,
base::OnceCallback<void(bool)> callback) {
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
DCHECK(model_manager_);
- model_manager_->NotifyWhenModelAvailable(type, std::move(callback));
+ model_manager_->RequestAndNotifyWhenModelAvailable(type, std::move(callback));
#else
std::move(callback).Run(false);
#endif
}
+void PageContentAnnotationsService::PersistSearchMetadata(
+ const HistoryVisit& visit,
+ const SearchMetadata& search_metadata) {
+ QueryURL(visit,
+ base::BindOnce(&history::HistoryService::AddSearchMetadataForVisit,
+ history_service_->AsWeakPtr(),
+ search_metadata.normalized_url,
+ search_metadata.search_terms));
+}
+
void PageContentAnnotationsService::ExtractRelatedSearches(
const HistoryVisit& visit,
content::WebContents* web_contents) {
@@ -166,6 +288,10 @@ void PageContentAnnotationsService::OnPageContentAnnotated(
if (!content_annotations)
return;
+ if (annotated_text_cache_.Peek(*visit.text_to_annotate) ==
+ annotated_text_cache_.end()) {
+ annotated_text_cache_.Put(*visit.text_to_annotate, *content_annotations);
+ }
MaybeRecordVisibilityUKM(visit, content_annotations);
if (!features::ShouldWriteContentAnnotationsToHistoryService())
@@ -258,6 +384,13 @@ void PageContentAnnotationsService::OnURLQueried(
void PageContentAnnotationsService::GetMetadataForEntityId(
const std::string& entity_id,
EntityMetadataRetrievedCallback callback) {
+ if (features::UseLocalPageEntitiesMetadataProvider()) {
+ DCHECK(local_page_entities_metadata_provider_);
+ local_page_entities_metadata_provider_->GetMetadataForEntityId(
+ entity_id, std::move(callback));
+ return;
+ }
+
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
model_manager_->GetMetadataForEntityId(entity_id, std::move(callback));
#else
@@ -277,14 +410,51 @@ void PageContentAnnotationsService::PersistRemotePageEntities(
history_service_->AsWeakPtr(), annotations));
}
+void PageContentAnnotationsService::RunBatchAnnotationValidation() {
+ DCHECK(features::BatchAnnotationsValidationEnabled());
+ DCHECK(validation_timer_);
+ validation_timer_.reset();
+
+ std::vector<std::string> dummy_inputs;
+ dummy_inputs.reserve(features::BatchAnnotationsValidationBatchSize());
+ for (size_t i = 0; i < features::BatchAnnotationsValidationBatchSize(); i++) {
+ // Pick a random substring of the dummy blob so that we can't do any caching
+ // or deduping.
+ size_t half_length = std::strlen(kDummyTextBlob) / 2;
+ size_t rand_start = base::RandInt(0, half_length - 1);
+ dummy_inputs.emplace_back(
+ std::string(kDummyTextBlob + rand_start, half_length));
+ }
+
+ LOCAL_HISTOGRAM_COUNTS_100(
+ "OptimizationGuide.PageContentAnnotationsService.ValidationRun",
+ dummy_inputs.size());
+
+ BatchAnnotate(base::DoNothing(), dummy_inputs,
+ AnnotationType::kContentVisibility);
+}
+
// static
HistoryVisit PageContentAnnotationsService::CreateHistoryVisitFromWebContents(
content::WebContents* web_contents,
int64_t navigation_id) {
- HistoryVisit visit = {
+ HistoryVisit visit(
web_contents->GetController().GetLastCommittedEntry()->GetTimestamp(),
- web_contents->GetLastCommittedURL(), navigation_id};
+ web_contents->GetLastCommittedURL(), navigation_id);
return visit;
}
+HistoryVisit::HistoryVisit() = default;
+
+HistoryVisit::HistoryVisit(base::Time nav_entry_timestamp,
+ GURL url,
+ int64_t navigation_id) {
+ this->nav_entry_timestamp = nav_entry_timestamp;
+ this->url = url;
+ this->navigation_id = navigation_id;
+}
+
+HistoryVisit::~HistoryVisit() = default;
+HistoryVisit::HistoryVisit(const HistoryVisit&) = default;
+
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_service.h b/chromium/components/optimization_guide/content/browser/page_content_annotations_service.h
index 4e361811f85..7931d37b30e 100644
--- a/chromium/components/optimization_guide/content/browser/page_content_annotations_service.h
+++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_service.h
@@ -9,12 +9,14 @@
#include "base/callback_forward.h"
#include "base/containers/lru_cache.h"
+#include "base/files/file_path.h"
#include "base/hash/hash.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"
#include "base/task/cancelable_task_tracker.h"
+#include "base/task/sequenced_task_runner.h"
#include "components/continuous_search/browser/search_result_extractor_client.h"
#include "components/continuous_search/browser/search_result_extractor_client_status.h"
#include "components/continuous_search/common/public/mojom/continuous_search.mojom.h"
@@ -26,8 +28,13 @@
#include "components/optimization_guide/core/model_info.h"
#include "components/optimization_guide/core/page_content_annotations_common.h"
#include "components/optimization_guide/machine_learning_tflite_buildflags.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
#include "url/gurl.h"
+namespace base {
+class OneShotTimer;
+} // namespace base
+
namespace content {
class WebContents;
} // namespace content
@@ -36,8 +43,13 @@ namespace history {
class HistoryService;
} // namespace history
+namespace leveldb_proto {
+class ProtoDatabaseProvider;
+} // namespace leveldb_proto
+
namespace optimization_guide {
+class LocalPageEntitiesMetadataProvider;
class OptimizationGuideModelProvider;
class PageContentAnnotationsModelManager;
class PageContentAnnotationsServiceBrowserTest;
@@ -45,9 +57,15 @@ class PageContentAnnotationsWebContentsObserver;
// The information used by HistoryService to identify a visit to a URL.
struct HistoryVisit {
+ HistoryVisit();
+ HistoryVisit(base::Time nav_entry_timestamp, GURL url, int64_t navigation_id);
+ ~HistoryVisit();
+ HistoryVisit(const HistoryVisit&);
+
base::Time nav_entry_timestamp;
GURL url;
- int64_t navigation_id;
+ int64_t navigation_id = 0;
+ absl::optional<std::string> text_to_annotate;
struct Comp {
bool operator()(const HistoryVisit& lhs, const HistoryVisit& rhs) const {
@@ -58,6 +76,12 @@ struct HistoryVisit {
};
};
+// The information about a search visit to store in HistoryService.
+struct SearchMetadata {
+ GURL normalized_url;
+ std::u16string search_terms;
+};
+
// A KeyedService that annotates page content.
class PageContentAnnotationsService : public KeyedService,
public EntityMetadataProvider {
@@ -65,38 +89,44 @@ class PageContentAnnotationsService : public KeyedService,
PageContentAnnotationsService(
const std::string& application_locale,
OptimizationGuideModelProvider* optimization_guide_model_provider,
- history::HistoryService* history_service);
+ history::HistoryService* history_service,
+ leveldb_proto::ProtoDatabaseProvider* database_provider,
+ const base::FilePath& database_dir,
+ scoped_refptr<base::SequencedTaskRunner> background_task_runner);
~PageContentAnnotationsService() override;
PageContentAnnotationsService(const PageContentAnnotationsService&) = delete;
PageContentAnnotationsService& operator=(
const PageContentAnnotationsService&) = delete;
// This is the main entry point for page content annotations by external
- // callers.
+ // callers. Callers must call |RequestAndNotifyWhenModelAvailable| as close to
+ // session start as possible to allow time for the model file to be
+ // downloaded.
void BatchAnnotate(BatchAnnotationCallback callback,
const std::vector<std::string>& inputs,
AnnotationType annotation_type);
- // Overrides the PageContentAnnotator for testing. See
- // test_page_content_annotator.h for an implementation designed for testing.
- void OverridePageContentAnnotatorForTesting(PageContentAnnotator* annotator);
+ // Requests that the given model for |type| be loaded in the background and
+ // then runs |callback| with true when the model is ready to execute. If the
+ // model is ready now, the callback is run immediately. If the model file will
+ // never be available, the callback is run with false.
+ void RequestAndNotifyWhenModelAvailable(
+ AnnotationType type,
+ base::OnceCallback<void(bool)> callback);
// Returns the model info for the given annotation type, if the model file is
// available.
absl::optional<ModelInfo> GetModelInfoForType(AnnotationType type) const;
- // Runs |callback| with true when the model that powers |BatchAnnotate| for
- // the given annotation type is ready to execute. If the model is ready now,
- // the callback is run immediately. If the model file will never be available,
- // the callback is run with false.
- void NotifyWhenModelAvailable(AnnotationType type,
- base::OnceCallback<void(bool)> callback);
-
// EntityMetadataProvider:
void GetMetadataForEntityId(
const std::string& entity_id,
EntityMetadataRetrievedCallback callback) override;
+ // Overrides the PageContentAnnotator for testing. See
+ // test_page_content_annotator.h for an implementation designed for testing.
+ void OverridePageContentAnnotatorForTesting(PageContentAnnotator* annotator);
+
private:
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
// Callback invoked when |visit| has been annotated.
@@ -105,7 +135,20 @@ class PageContentAnnotationsService : public KeyedService,
const absl::optional<history::VisitContentModelAnnotations>&
content_annotations);
+ // Runs the page annotation models available to |model_manager_| on all the
+ // visits within |current_visit_annotation_batch_|.
+ void AnnotateVisitBatch();
+
+ // Callback run after the annotations for a |visit| of a batch has been
+ // determined. |current_visit_annotation_batch_| is updated to remove
+ // the annotated visit and will trigger the next visit to be annotated.
+ void OnBatchVisitAnnotated(
+ const HistoryVisit& visit,
+ const absl::optional<history::VisitContentModelAnnotations>&
+ content_annotations);
+
std::unique_ptr<PageContentAnnotationsModelManager> model_manager_;
+
#endif
// The annotator to use for requests to |BatchAnnotate|. In prod, this is
@@ -123,13 +166,17 @@ class PageContentAnnotationsService : public KeyedService,
friend class PageContentAnnotationsWebContentsObserver;
friend class PageContentAnnotationsServiceBrowserTest;
// Virtualized for testing.
- virtual void Annotate(const HistoryVisit& visit, const std::string& text);
+ virtual void Annotate(const HistoryVisit& visit);
// Creates a HistoryVisit based on the current state of |web_contents|.
static HistoryVisit CreateHistoryVisitFromWebContents(
content::WebContents* web_contents,
int64_t navigation_id);
+ // Persist |search_metadata| for |visit| in |history_service_|.
+ virtual void PersistSearchMetadata(const HistoryVisit& visit,
+ const SearchMetadata& search_metadata);
+
// Requests |search_result_extractor_client_| to extract related searches from
// the Google SRP DOM associated with |web_contents|.
//
@@ -165,6 +212,17 @@ class PageContentAnnotationsService : public KeyedService,
PersistAnnotationsCallback callback,
history::QueryURLResult url_result);
+ // Runs a batch annotation validation, that is calls |BatchAnnotate| with
+ // dummy input and discards the output.
+ void RunBatchAnnotationValidation();
+
+ // A metadata-only provider for page entities (as opposed to |model_manager_|
+ // which does both entity model execution and metadata providing) that uses a
+ // local database to provide the metadata for a given entity id. This is only
+ // non-null and initialized when its feature flag is enabled.
+ std::unique_ptr<LocalPageEntitiesMetadataProvider>
+ local_page_entities_metadata_provider_;
+
// The history service to write content annotations to. Not owned. Guaranteed
// to outlive |this|.
raw_ptr<history::HistoryService> history_service_;
@@ -180,6 +238,23 @@ class PageContentAnnotationsService : public KeyedService,
base::LRUCache<HistoryVisit, bool, HistoryVisit::Comp>
last_annotated_history_visits_;
+ // A LRU cache of the annotation results for visits. If the text of the visit
+ // is in the cache, the cached model annotations will be used.
+ base::HashingLRUCache<std::string, history::VisitContentModelAnnotations>
+ annotated_text_cache_;
+
+ // The set of visits to be annotated, this is added to by Annotate requests
+ // from the web content observer. These will be annotated when the set is full
+ // and annotations can be scheduled with minimal impact to browsing.
+ std::vector<HistoryVisit> visits_to_annotate_;
+
+ // The batch of visits being annotated. If this is empty, it is assumed that
+ // no visits are actively be annotated and a new batch can be started.
+ std::vector<HistoryVisit> current_visit_annotation_batch_;
+
+ // Is only ever set when the feature is enabled.
+ std::unique_ptr<base::OneShotTimer> validation_timer_;
+
base::WeakPtrFactory<PageContentAnnotationsService> weak_ptr_factory_{this};
};
diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.cc b/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.cc
index 8eea9f2744e..d91df2d5163 100644
--- a/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.cc
+++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.cc
@@ -11,6 +11,7 @@
#include "components/optimization_guide/content/browser/optimization_guide_decider.h"
#include "components/optimization_guide/content/browser/page_content_annotations_service.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
+#include "components/optimization_guide/core/optimization_guide_switches.h"
#include "components/optimization_guide/proto/page_entities_metadata.pb.h"
#include "components/search_engines/template_url_service.h"
#include "content/public/browser/navigation_entry.h"
@@ -21,24 +22,38 @@ namespace optimization_guide {
namespace {
-// Returns the search query if |url| is a valid Search URL according to
+// Returns search metadata if |url| is a valid Search URL according to
// |template_url_service|.
-absl::optional<std::u16string> ExtractSearchTerms(
+absl::optional<SearchMetadata> ExtractSearchMetadata(
const TemplateURLService* template_url_service,
const GURL& url) {
- const TemplateURL* default_search_provider =
- template_url_service->GetDefaultSearchProvider();
+ if (!template_url_service)
+ return absl::nullopt;
+
+ const TemplateURL* template_url =
+ template_url_service->GetTemplateURLForHost(url.host());
const SearchTermsData& search_terms_data =
template_url_service->search_terms_data();
std::u16string search_terms;
- if (default_search_provider &&
- default_search_provider->ExtractSearchTermsFromURL(url, search_terms_data,
- &search_terms) &&
- !search_terms.empty()) {
- return search_terms;
- }
- return absl::nullopt;
+ bool is_valid_search_url = template_url &&
+ template_url->ExtractSearchTermsFromURL(
+ url, search_terms_data, &search_terms) &&
+ !search_terms.empty();
+ if (!is_valid_search_url)
+ return absl::nullopt;
+
+ const std::u16string& normalized_search_query =
+ base::i18n::ToLower(base::CollapseWhitespace(search_terms, false));
+ TemplateURLRef::SearchTermsArgs search_terms_args(normalized_search_query);
+ const TemplateURLRef& search_url_ref = template_url->url_ref();
+ if (!search_url_ref.SupportsReplacement(search_terms_data))
+ return absl::nullopt;
+
+ return SearchMetadata{
+ GURL(search_url_ref.ReplaceSearchTerms(search_terms_args,
+ search_terms_data)),
+ base::i18n::ToLower(base::CollapseWhitespace(search_terms, false))};
}
// Data scoped to a single page. PageData has the same lifetime as the page's
@@ -147,16 +162,24 @@ void PageContentAnnotationsWebContentsObserver::DidFinishNavigation(
web_contents());
}
- absl::optional<std::u16string> search_terms =
- ExtractSearchTerms(template_url_service_, navigation_handle->GetURL());
- if (search_terms) {
+ absl::optional<SearchMetadata> search_metadata = ExtractSearchMetadata(
+ template_url_service_, navigation_handle->GetURL());
+ if (search_metadata) {
if (page_data) {
page_data->set_annotation_was_requested();
}
- const std::u16string& normalized_search_query =
- base::i18n::ToLower(base::CollapseWhitespace(*search_terms, false));
- page_content_annotations_service_->Annotate(
- history_visit, base::UTF16ToUTF8(normalized_search_query));
+ history_visit.text_to_annotate =
+ base::UTF16ToUTF8(search_metadata->search_terms);
+ page_content_annotations_service_->Annotate(history_visit);
+ page_content_annotations_service_->PersistSearchMetadata(
+ history_visit, *search_metadata);
+
+ if (switches::ShouldLogPageContentAnnotationsInput()) {
+ LOG(ERROR) << "Annotating search terms: \n"
+ << "URL: " << navigation_handle->GetURL() << "\n"
+ << "Text: " << *(history_visit.text_to_annotate);
+ }
+
return;
}
}
@@ -168,8 +191,15 @@ void PageContentAnnotationsWebContentsObserver::DidFinishNavigation(
page_data->set_annotation_was_requested();
}
// Annotate the title instead.
- page_content_annotations_service_->Annotate(
- history_visit, base::UTF16ToUTF8(web_contents()->GetTitle()));
+ history_visit.text_to_annotate =
+ base::UTF16ToUTF8(web_contents()->GetTitle());
+ page_content_annotations_service_->Annotate(history_visit);
+
+ if (switches::ShouldLogPageContentAnnotationsInput()) {
+ LOG(ERROR) << "Annotating same document navigation: \n"
+ << "URL: " << navigation_handle->GetURL() << "\n"
+ << "Text: " << *(history_visit.text_to_annotate);
+ }
}
}
@@ -191,8 +221,15 @@ void PageContentAnnotationsWebContentsObserver::TitleWasSet(
optimization_guide::HistoryVisit history_visit = optimization_guide::
PageContentAnnotationsService::CreateHistoryVisitFromWebContents(
web_contents(), page_data->navigation_id());
- page_content_annotations_service_->Annotate(
- history_visit, base::UTF16ToUTF8(entry->GetTitleForDisplay()));
+ history_visit.text_to_annotate =
+ base::UTF16ToUTF8(entry->GetTitleForDisplay());
+ page_content_annotations_service_->Annotate(history_visit);
+
+ if (switches::ShouldLogPageContentAnnotationsInput()) {
+ LOG(ERROR) << "Annotating main frame navigation: \n"
+ << "URL: " << entry->GetURL() << "\n"
+ << "Text: " << *(history_visit.text_to_annotate);
+ }
}
std::unique_ptr<PageTextObserver::ConsumerTextDumpRequest>
@@ -230,7 +267,7 @@ PageContentAnnotationsWebContentsObserver::MaybeRequestFrameTextDump(
}
void PageContentAnnotationsWebContentsObserver::OnTextDumpReceived(
- const HistoryVisit& visit,
+ HistoryVisit visit,
const PageTextDumpResult& result) {
DCHECK(!features::ShouldAnnotateTitleInsteadOfPageContent());
@@ -241,12 +278,22 @@ void PageContentAnnotationsWebContentsObserver::OnTextDumpReceived(
// If the page had AMP frames, then only use that content. Otherwise, use the
// mainframe.
if (result.GetAMPTextContent()) {
- page_content_annotations_service_->Annotate(visit,
- *result.GetAMPTextContent());
+ visit.text_to_annotate = *result.GetAMPTextContent();
+ page_content_annotations_service_->Annotate(visit);
+ if (switches::ShouldLogPageContentAnnotationsInput()) {
+ LOG(ERROR) << "Annotating AMP text content: \n"
+ << "URL: " << visit.url << "\n"
+ << "Text: " << *(visit.text_to_annotate);
+ }
return;
}
- page_content_annotations_service_->Annotate(
- visit, *result.GetMainFrameTextContent());
+ visit.text_to_annotate = *result.GetMainFrameTextContent();
+ page_content_annotations_service_->Annotate(visit);
+ if (switches::ShouldLogPageContentAnnotationsInput()) {
+ LOG(ERROR) << "Annotating main frame text content: \n"
+ << "URL: " << visit.url << "\n"
+ << "Text: " << *(visit.text_to_annotate);
+ }
}
void PageContentAnnotationsWebContentsObserver::OnRemotePageEntitiesReceived(
diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.h b/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.h
index c0417821386..de084693c39 100644
--- a/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.h
+++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.h
@@ -62,8 +62,7 @@ class PageContentAnnotationsWebContentsObserver
content::NavigationHandle* navigation_handle) override;
// Callback invoked when a text dump has been received for the |visit|.
- void OnTextDumpReceived(const HistoryVisit& visit,
- const PageTextDumpResult& result);
+ void OnTextDumpReceived(HistoryVisit visit, const PageTextDumpResult& result);
// Callback invoked when the page entities have been received from
// |optimization_guide_decider_| for |visit|.
diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer_unittest.cc b/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer_unittest.cc
index c08c20f7276..49397fca3e7 100644
--- a/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer_unittest.cc
+++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer_unittest.cc
@@ -62,11 +62,14 @@ class FakePageContentAnnotationsService : public PageContentAnnotationsService {
history::HistoryService* history_service)
: PageContentAnnotationsService("en-US",
optimization_guide_model_provider,
- history_service) {}
+ history_service,
+ nullptr,
+ base::FilePath(),
+ nullptr) {}
~FakePageContentAnnotationsService() override = default;
- void Annotate(const HistoryVisit& visit, const std::string& text) override {
- last_annotation_request_.emplace(std::make_pair(visit, text));
+ void Annotate(const HistoryVisit& visit) override {
+ last_annotation_request_.emplace(visit);
}
void ExtractRelatedSearches(const HistoryVisit& visit,
@@ -75,8 +78,7 @@ class FakePageContentAnnotationsService : public PageContentAnnotationsService {
std::make_pair(visit, web_contents));
}
- absl::optional<std::pair<HistoryVisit, std::string>> last_annotation_request()
- const {
+ absl::optional<HistoryVisit> last_annotation_request() const {
return last_annotation_request_;
}
@@ -103,14 +105,24 @@ class FakePageContentAnnotationsService : public PageContentAnnotationsService {
return last_entities_persistence_request_;
}
+ void PersistSearchMetadata(const HistoryVisit& visit,
+ const SearchMetadata& search_metadata) override {
+ last_search_metadata_ = search_metadata;
+ }
+
+ absl::optional<SearchMetadata> last_search_metadata_persisted() const {
+ return last_search_metadata_;
+ }
+
private:
- absl::optional<std::pair<HistoryVisit, std::string>> last_annotation_request_;
+ absl::optional<HistoryVisit> last_annotation_request_;
absl::optional<std::pair<HistoryVisit, content::WebContents*>>
last_related_searches_extraction_request_;
absl::optional<
std::pair<HistoryVisit,
std::vector<history::VisitContentModelAnnotations::Category>>>
last_entities_persistence_request_;
+ absl::optional<SearchMetadata> last_search_metadata_;
};
class FakeOptimizationGuideDecider : public TestOptimizationGuideDecider {
@@ -322,11 +334,11 @@ TEST_F(PageContentAnnotationsWebContentsObserverTest,
result.AddFrameTextDumpResult(frame_result);
std::move(request->callback).Run(std::move(result));
- absl::optional<std::pair<HistoryVisit, std::string>> last_annotation_request =
+ absl::optional<HistoryVisit> last_annotation_request =
service()->last_annotation_request();
EXPECT_TRUE(last_annotation_request.has_value());
- EXPECT_EQ(last_annotation_request->first.url, GURL("http://test.com"));
- EXPECT_EQ(last_annotation_request->second, "some text");
+ EXPECT_EQ(last_annotation_request->url, GURL("http://test.com"));
+ EXPECT_EQ(last_annotation_request->text_to_annotate, "some text");
service()->ClearLastAnnotationRequest();
@@ -355,11 +367,11 @@ TEST_F(PageContentAnnotationsWebContentsObserverTest,
navigation_simulator->CommitSameDocument();
// The title should be what is requested to be annotated.
- absl::optional<std::pair<HistoryVisit, std::string>> last_annotation_request =
+ absl::optional<HistoryVisit> last_annotation_request =
service()->last_annotation_request();
EXPECT_TRUE(last_annotation_request.has_value());
- EXPECT_EQ(last_annotation_request->first.url, url2);
- EXPECT_EQ(last_annotation_request->second, "Title");
+ EXPECT_EQ(last_annotation_request->url, url2);
+ EXPECT_EQ(last_annotation_request->text_to_annotate, "Title");
}
TEST_F(PageContentAnnotationsWebContentsObserverTest,
@@ -369,12 +381,19 @@ TEST_F(PageContentAnnotationsWebContentsObserverTest,
web_contents(), GURL("http://default-engine.com/search?q=a"));
// The search query should be what is requested to be annotated.
- absl::optional<std::pair<HistoryVisit, std::string>> last_annotation_request =
+ absl::optional<HistoryVisit> last_annotation_request =
service()->last_annotation_request();
ASSERT_TRUE(last_annotation_request.has_value());
- EXPECT_EQ(last_annotation_request->first.url,
+ EXPECT_EQ(last_annotation_request->url,
+ GURL("http://default-engine.com/search?q=a"));
+ EXPECT_EQ(last_annotation_request->text_to_annotate, "a");
+
+ absl::optional<SearchMetadata> last_search_metadata_persisted =
+ service()->last_search_metadata_persisted();
+ ASSERT_TRUE(last_search_metadata_persisted.has_value());
+ EXPECT_EQ(last_search_metadata_persisted->normalized_url,
GURL("http://default-engine.com/search?q=a"));
- EXPECT_EQ(last_annotation_request->second, "a");
+ EXPECT_EQ(last_search_metadata_persisted->search_terms, u"a");
}
TEST_F(PageContentAnnotationsWebContentsObserverTest,
@@ -460,11 +479,11 @@ TEST_F(PageContentAnnotationsWebContentsObserverAnnotateTitleTest,
navigation_simulator->CommitSameDocument();
// The title should be what is requested to be annotated.
- absl::optional<std::pair<HistoryVisit, std::string>> last_annotation_request =
+ absl::optional<HistoryVisit> last_annotation_request =
service()->last_annotation_request();
EXPECT_TRUE(last_annotation_request.has_value());
- EXPECT_EQ(last_annotation_request->first.url, url2);
- EXPECT_EQ(last_annotation_request->second, "Title");
+ EXPECT_EQ(last_annotation_request->url, url2);
+ EXPECT_EQ(last_annotation_request->text_to_annotate, "Title");
service()->ClearLastAnnotationRequest();
@@ -489,12 +508,11 @@ TEST_F(PageContentAnnotationsWebContentsObserverAnnotateTitleTest,
title);
// The title should be what is requested to be annotated.
- absl::optional<std::pair<HistoryVisit, std::string>> last_annotation_request =
+ absl::optional<HistoryVisit> last_annotation_request =
service()->last_annotation_request();
EXPECT_TRUE(last_annotation_request.has_value());
- EXPECT_EQ(last_annotation_request->first.url,
- GURL("http://www.foo.com/someurl"));
- EXPECT_EQ(last_annotation_request->second, "Title");
+ EXPECT_EQ(last_annotation_request->url, GURL("http://www.foo.com/someurl"));
+ EXPECT_EQ(last_annotation_request->text_to_annotate, "Title");
service()->ClearLastAnnotationRequest();
diff --git a/chromium/components/optimization_guide/content/browser/test_page_content_annotator.cc b/chromium/components/optimization_guide/content/browser/test_page_content_annotator.cc
index 5df206b2994..f1c12bed097 100644
--- a/chromium/components/optimization_guide/content/browser/test_page_content_annotator.cc
+++ b/chromium/components/optimization_guide/content/browser/test_page_content_annotator.cc
@@ -17,7 +17,7 @@ void TestPageContentAnnotator::Annotate(BatchAnnotationCallback callback,
if (annotation_type == AnnotationType::kPageTopics) {
for (const std::string& input : inputs) {
auto it = topics_by_input_.find(input);
- absl::optional<std::vector<WeightedString>> output;
+ absl::optional<std::vector<WeightedIdentifier>> output;
if (it != topics_by_input_.end()) {
output = it->second;
}
@@ -54,7 +54,7 @@ void TestPageContentAnnotator::Annotate(BatchAnnotationCallback callback,
}
void TestPageContentAnnotator::UsePageTopics(
- const base::flat_map<std::string, std::vector<WeightedString>>&
+ const base::flat_map<std::string, std::vector<WeightedIdentifier>>&
topics_by_input) {
topics_by_input_ = topics_by_input;
}
diff --git a/chromium/components/optimization_guide/content/browser/test_page_content_annotator.h b/chromium/components/optimization_guide/content/browser/test_page_content_annotator.h
index dc52ac14b03..a38e1ccd5c4 100644
--- a/chromium/components/optimization_guide/content/browser/test_page_content_annotator.h
+++ b/chromium/components/optimization_guide/content/browser/test_page_content_annotator.h
@@ -23,7 +23,7 @@ class TestPageContentAnnotator : public PageContentAnnotator {
// The given page topics are used for the matching BatchAnnotationResults by
// input string. If the input is not found, the output is left as nullopt.
void UsePageTopics(
- const base::flat_map<std::string, std::vector<WeightedString>>&
+ const base::flat_map<std::string, std::vector<WeightedIdentifier>>&
topics_by_input);
// The given page entities are used for the matching BatchAnnotationResults by
@@ -43,7 +43,7 @@ class TestPageContentAnnotator : public PageContentAnnotator {
AnnotationType annotation_type) override;
private:
- base::flat_map<std::string, std::vector<WeightedString>> topics_by_input_;
+ base::flat_map<std::string, std::vector<WeightedIdentifier>> topics_by_input_;
base::flat_map<std::string, std::vector<ScoredEntityMetadata>>
entities_by_input_;
base::flat_map<std::string, double> visibility_scores_for_input_;
diff --git a/chromium/components/optimization_guide/core/BUILD.gn b/chromium/components/optimization_guide/core/BUILD.gn
index e66d20fc9ba..90eaa54e934 100644
--- a/chromium/components/optimization_guide/core/BUILD.gn
+++ b/chromium/components/optimization_guide/core/BUILD.gn
@@ -1,5 +1,5 @@
-# Copyright 2017 The Chromium Authors. All rights reserved.
-# Use of this source code is governed by a BSD-style license that can be
+# Copyright 2017 The Chromium Authors.All rights reserved.
+# Use of this source code is governed by a BSD - style license that can be
# found in the LICENSE file.
if (is_android) {
@@ -32,14 +32,66 @@ static_library("entities") {
]
}
+static_library("model_executor") {
+ sources = [
+ "execution_status.cc",
+ "execution_status.h",
+ "model_enums.h",
+ "model_executor.h",
+ "model_info.cc",
+ "model_info.h",
+ "model_util.cc",
+ "model_util.h",
+ ]
+ if (build_with_tflite_lib) {
+ sources += [
+ "base_model_executor.h",
+ "base_model_executor_helpers.h",
+ "bert_model_executor.cc",
+ "bert_model_executor.h",
+ "tflite_model_executor.h",
+ ]
+ }
+
+ public_deps = [
+ "//components/optimization_guide:machine_learning_tflite_buildflags",
+ "//third_party/re2",
+ ]
+ if (build_with_tflite_lib) {
+ public_deps += [
+ "//components/optimization_guide/core:machine_learning",
+ "//third_party/abseil-cpp:absl",
+ "//third_party/tflite",
+ "//third_party/tflite:tflite_public_headers",
+ "//third_party/tflite_support",
+ "//third_party/tflite_support:tflite_support_proto",
+ ]
+ }
+ deps = [
+ "//base",
+ "//components/optimization_guide/proto:optimization_guide_proto",
+ "//net",
+ "//url",
+ ]
+}
+
+if (build_with_tflite_lib) {
+ static_library("machine_learning") {
+ sources = [
+ "tflite_op_resolver.cc",
+ "tflite_op_resolver.h",
+ ]
+ deps = [
+ "//third_party/tflite",
+ "//third_party/tflite:tflite_public_headers",
+ ]
+ }
+}
+
static_library("core") {
sources = [
"command_line_top_host_provider.cc",
"command_line_top_host_provider.h",
- "decision_tree_prediction_model.cc",
- "decision_tree_prediction_model.h",
- "execution_status.cc",
- "execution_status.h",
"hint_cache.cc",
"hint_cache.h",
"hints_component_info.h",
@@ -54,12 +106,11 @@ static_library("core") {
"hints_processing_util.cc",
"hints_processing_util.h",
"insertion_ordered_set.h",
+ "local_page_entities_metadata_provider.cc",
+ "local_page_entities_metadata_provider.h",
"memory_hint.cc",
"memory_hint.h",
- "model_executor.h",
"model_handler.h",
- "model_info.cc",
- "model_info.h",
"noisy_metrics_recorder.cc",
"noisy_metrics_recorder.h",
"optimization_filter.cc",
@@ -70,6 +121,8 @@ static_library("core") {
"optimization_guide_enums.h",
"optimization_guide_features.cc",
"optimization_guide_features.h",
+ "optimization_guide_logger.cc",
+ "optimization_guide_logger.h",
"optimization_guide_model_provider.h",
"optimization_guide_navigation_data.cc",
"optimization_guide_navigation_data.h",
@@ -93,8 +146,6 @@ static_library("core") {
"page_content_annotation_job.h",
"page_content_annotations_common.cc",
"page_content_annotations_common.h",
- "prediction_model.cc",
- "prediction_model.h",
"prediction_model_fetcher.h",
"prediction_model_fetcher_impl.cc",
"prediction_model_fetcher_impl.h",
@@ -108,10 +159,6 @@ static_library("core") {
]
if (build_with_tflite_lib) {
sources += [
- "base_model_executor.h",
- "base_model_executor_helpers.h",
- "bert_model_executor.cc",
- "bert_model_executor.h",
"bert_model_handler.cc",
"bert_model_handler.h",
"model_validator.cc",
@@ -123,12 +170,20 @@ static_library("core") {
"page_topics_model_executor.h",
"page_visibility_model_executor.cc",
"page_visibility_model_executor.h",
- "tflite_model_executor.h",
]
+ if (build_with_internal_optimization_guide) {
+ sources += [
+ "entity_annotator_native_library.cc",
+ "entity_annotator_native_library.h",
+ "page_entities_model_executor_impl.cc",
+ "page_entities_model_executor_impl.h",
+ ]
+ }
}
public_deps = [
":entities",
+ ":model_executor",
"//components/optimization_guide:machine_learning_tflite_buildflags",
"//third_party/re2",
]
@@ -146,7 +201,6 @@ static_library("core") {
deps = [
":bloomfilter",
"//base",
- "//components/data_reduction_proxy/core/browser",
"//components/leveldb_proto",
"//components/optimization_guide/proto:optimization_guide_proto",
"//components/prefs",
@@ -161,17 +215,9 @@ static_library("core") {
"//ui/base:base",
"//url:url",
]
-}
-
-if (build_with_tflite_lib) {
- static_library("machine_learning") {
- sources = [
- "tflite_op_resolver.cc",
- "tflite_op_resolver.h",
- ]
- deps = [
- "//third_party/tflite",
- "//third_party/tflite:tflite_public_headers",
+ if (build_with_tflite_lib && build_with_internal_optimization_guide) {
+ data_deps = [
+ "//components/optimization_guide/internal:optimization_guide_internal",
]
}
}
@@ -244,13 +290,13 @@ source_set("unit_tests") {
"batch_entity_metadata_task_unittest.cc",
"bloom_filter_unittest.cc",
"command_line_top_host_provider_unittest.cc",
- "decision_tree_prediction_model_unittest.cc",
"hint_cache_unittest.cc",
"hints_component_util_unittest.cc",
"hints_fetcher_unittest.cc",
"hints_manager_unittest.cc",
"hints_processing_util_unittest.cc",
"insertion_ordered_set_unittest.cc",
+ "local_page_entities_metadata_provider_unittest.cc",
"model_handler_unittest.cc",
"noisy_metrics_recorder_unittest.cc",
"optimization_filter_unittest.cc",
@@ -264,7 +310,6 @@ source_set("unit_tests") {
"optimization_metadata_unittest.cc",
"page_content_annotation_job_unittest.cc",
"prediction_model_fetcher_unittest.cc",
- "prediction_model_unittest.cc",
"store_update_data_unittest.cc",
"url_pattern_with_wildcards_unittest.cc",
]
@@ -277,6 +322,12 @@ source_set("unit_tests") {
"page_visibility_model_executor_unittest.cc",
"tflite_model_executor_unittest.cc",
]
+ if (build_with_internal_optimization_guide) {
+ sources += [
+ "entity_annotator_native_library_unittest.cc",
+ "page_entities_model_executor_impl_unittest.cc",
+ ]
+ }
}
deps = [
@@ -286,8 +337,6 @@ source_set("unit_tests") {
":test_support",
"//base",
"//base/test:test_support",
- "//components/data_reduction_proxy/core/browser",
- "//components/data_reduction_proxy/core/common",
"//components/leveldb_proto:test_support",
"//components/optimization_guide/proto:optimization_guide_proto",
"//components/prefs:test_support",
@@ -319,3 +368,17 @@ if (is_android) {
visibility = [ "//chrome/browser/optimization_guide/android:*" ]
}
}
+
+if (is_mac && build_with_internal_optimization_guide) {
+ # We need to copy the optimization guide shared library so that the
+ # bundle_data dependencies have a "copy" target type.Otherwise for
+ # "shared_library" target types it will try to link things into
+ # Chromium Framework when we want to keep it separate instead.
+ copy("optimization_guide_internal_library_copy") {
+ sources = [ "$root_out_dir/liboptimization_guide_internal.dylib" ]
+ outputs = [ "$root_out_dir/og_intermediates/{{source_file_part}}" ]
+ deps = [
+ "//components/optimization_guide/internal:optimization_guide_internal",
+ ]
+ }
+}
diff --git a/chromium/components/optimization_guide/core/DEPS b/chromium/components/optimization_guide/core/DEPS
index 3ab01bf2788..74dc5dfb565 100644
--- a/chromium/components/optimization_guide/core/DEPS
+++ b/chromium/components/optimization_guide/core/DEPS
@@ -1,7 +1,5 @@
include_rules = [
"+components/ukm/test_ukm_recorder.h",
"+services/metrics/public/cpp",
- "+third_party/abseil-cpp/absl/types/optional.h",
- "+third_party/abseil-cpp/absl/status/status.h",
"+ui/base/l10n",
]
diff --git a/chromium/components/optimization_guide/core/base_model_executor.h b/chromium/components/optimization_guide/core/base_model_executor.h
index 76bd8b96a46..78e064b4256 100644
--- a/chromium/components/optimization_guide/core/base_model_executor.h
+++ b/chromium/components/optimization_guide/core/base_model_executor.h
@@ -69,9 +69,9 @@ class BaseModelExecutor : public TFLiteModelExecutor<OutputType, InputTypes...>,
}
// InferenceDelegate:
- absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
- InputTypes... input) override = 0;
- OutputType Postprocess(
+ bool Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
+ InputTypes... input) override = 0;
+ absl::optional<OutputType> Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors) override = 0;
};
diff --git a/chromium/components/optimization_guide/core/base_model_executor_helpers.h b/chromium/components/optimization_guide/core/base_model_executor_helpers.h
index 0b9717eee2c..986ce04589d 100644
--- a/chromium/components/optimization_guide/core/base_model_executor_helpers.h
+++ b/chromium/components/optimization_guide/core/base_model_executor_helpers.h
@@ -11,6 +11,7 @@
#include "base/check.h"
#include "base/memory/raw_ptr.h"
#include "components/optimization_guide/core/execution_status.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h"
namespace optimization_guide {
@@ -18,13 +19,13 @@ namespace optimization_guide {
template <class OutputType, class... InputTypes>
class InferenceDelegate {
public:
- // Preprocesses |args| into |input_tensors|.
- virtual absl::Status Preprocess(
- const std::vector<TfLiteTensor*>& input_tensors,
- InputTypes... args) = 0;
+ // Preprocesses |args| into |input_tensors|. Returns true on success.
+ virtual bool Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
+ InputTypes... args) = 0;
- // Postprocesses |output_tensors| into the desired |OutputType|.
- virtual OutputType Postprocess(
+ // Postprocesses |output_tensors| into the desired |OutputType|, returning
+ // absl::nullopt on error.
+ virtual absl::optional<OutputType> Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors) = 0;
};
@@ -59,12 +60,24 @@ class GenericModelExecutionTask
// BaseTaskApi:
absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
InputTypes... args) override {
- return delegate_->Preprocess(input_tensors, args...);
+ bool success = delegate_->Preprocess(input_tensors, args...);
+ if (success) {
+ return absl::OkStatus();
+ }
+ return absl::InternalError(
+ "error during preprocessing. See stderr for more information if "
+ "available");
}
tflite::support::StatusOr<OutputType> Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors,
InputTypes... api_inputs) override {
- return delegate_->Postprocess(output_tensors);
+ absl::optional<OutputType> output = delegate_->Postprocess(output_tensors);
+ if (!output) {
+ return absl::InternalError(
+ "error during postprocessing. See stderr for more infomation if "
+ "available");
+ }
+ return *output;
}
private:
diff --git a/chromium/components/optimization_guide/core/bert_model_executor.cc b/chromium/components/optimization_guide/core/bert_model_executor.cc
index ab2b7306726..fcd4df0e96d 100644
--- a/chromium/components/optimization_guide/core/bert_model_executor.cc
+++ b/chromium/components/optimization_guide/core/bert_model_executor.cc
@@ -5,9 +5,9 @@
#include "components/optimization_guide/core/bert_model_executor.h"
#include "base/trace_event/trace_event.h"
-#include "components/optimization_guide/core/optimization_guide_util.h"
+#include "components/optimization_guide/core/model_util.h"
#include "components/optimization_guide/core/tflite_op_resolver.h"
-#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h"
+#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h"
namespace optimization_guide {
@@ -28,18 +28,21 @@ BertModelExecutor::Execute(ModelExecutionTask* execution_task,
GetStringNameForOptimizationTarget(optimization_target_),
"input_length", input.size());
*out_status = ExecutionStatus::kSuccess;
- return static_cast<tflite::task::text::nlclassifier::BertNLClassifier*>(
- execution_task)
+ return static_cast<tflite::task::text::BertNLClassifier*>(execution_task)
->Classify(input);
}
std::unique_ptr<BertModelExecutor::ModelExecutionTask>
BertModelExecutor::BuildModelExecutionTask(base::MemoryMappedFile* model_file,
ExecutionStatus* out_status) {
+ tflite::task::text::BertNLClassifierOptions options;
+ *options.mutable_base_options()
+ ->mutable_model_file()
+ ->mutable_file_content() = std::string(
+ reinterpret_cast<const char*>(model_file->data()), model_file->length());
auto maybe_nl_classifier =
- tflite::task::text::nlclassifier::BertNLClassifier::CreateFromBuffer(
- reinterpret_cast<const char*>(model_file->data()),
- model_file->length(), std::make_unique<TFLiteOpResolver>());
+ tflite::task::text::BertNLClassifier::CreateFromOptions(
+ std::move(options), std::make_unique<TFLiteOpResolver>());
if (maybe_nl_classifier.ok())
return std::move(maybe_nl_classifier.value());
*out_status = ExecutionStatus::kErrorModelFileNotValid;
diff --git a/chromium/components/optimization_guide/core/bloom_filter_unittest.cc b/chromium/components/optimization_guide/core/bloom_filter_unittest.cc
index a2551a522fe..e0cf1884b67 100644
--- a/chromium/components/optimization_guide/core/bloom_filter_unittest.cc
+++ b/chromium/components/optimization_guide/core/bloom_filter_unittest.cc
@@ -103,7 +103,7 @@ TEST(BloomFilterTest, EverythingMatches) {
}
// Disable this test in configurations that don't print CHECK failures.
-#if !defined(OS_IOS) && !(defined(OFFICIAL_BUILD) && defined(NDEBUG))
+#if !BUILDFLAG(IS_IOS) && !(defined(OFFICIAL_BUILD) && defined(NDEBUG))
TEST(BloomFilterTest, ByteVectorTooSmall) {
std::string data(1023, 0xff);
EXPECT_DEATH(
diff --git a/chromium/components/optimization_guide/core/decision_tree_prediction_model.cc b/chromium/components/optimization_guide/core/decision_tree_prediction_model.cc
deleted file mode 100644
index f8f5906c072..00000000000
--- a/chromium/components/optimization_guide/core/decision_tree_prediction_model.cc
+++ /dev/null
@@ -1,237 +0,0 @@
-// Copyright 2020 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "components/optimization_guide/core/decision_tree_prediction_model.h"
-
-#include <utility>
-
-namespace optimization_guide {
-
-DecisionTreePredictionModel::DecisionTreePredictionModel(
- const proto::PredictionModel& prediction_model)
- : PredictionModel(prediction_model) {}
-
-DecisionTreePredictionModel::~DecisionTreePredictionModel() = default;
-
-bool DecisionTreePredictionModel::ValidatePredictionModel() const {
- // Only the top-level ensemble or decision tree must have a threshold. Any
- // submodels of an ensemble will have model weights but no threshold.
- if (!model_.has_threshold())
- return false;
- return ValidateModel(model_);
-}
-
-bool DecisionTreePredictionModel::ValidateModel(
- const proto::Model& model) const {
- if (model.has_ensemble()) {
- return ValidateEnsembleModel(model.ensemble());
- }
- if (model.has_decision_tree()) {
- return ValidateDecisionTree(model.decision_tree());
- }
- return false;
-}
-
-bool DecisionTreePredictionModel::ValidateEnsembleModel(
- const proto::Ensemble& ensemble) const {
- if (ensemble.members_size() == 0)
- return false;
-
- for (const auto& member : ensemble.members()) {
- if (!ValidateModel(member.submodel())) {
- return false;
- }
- }
- return true;
-}
-
-bool DecisionTreePredictionModel::ValidateDecisionTree(
- const proto::DecisionTree& tree) const {
- if (tree.nodes_size() == 0)
- return false;
- return ValidateTreeNode(tree, tree.nodes(0), 0);
-}
-
-bool DecisionTreePredictionModel::ValidateLeaf(const proto::Leaf& leaf) const {
- return leaf.has_vector() && leaf.vector().value_size() == 1 &&
- leaf.vector().value(0).has_double_value();
-}
-
-bool DecisionTreePredictionModel::ValidateInequalityTest(
- const proto::InequalityTest& inequality_test) const {
- if (!inequality_test.has_threshold())
- return false;
- if (!inequality_test.threshold().has_float_value())
- return false;
- if (!inequality_test.has_feature_id())
- return false;
- if (!inequality_test.feature_id().has_id())
- return false;
- if (!inequality_test.has_type())
- return false;
- return true;
-}
-
-bool DecisionTreePredictionModel::ValidateTreeNode(
- const proto::DecisionTree& tree,
- const proto::TreeNode& node,
- int node_index) const {
- if (node.has_leaf())
- return ValidateLeaf(node.leaf());
-
- if (!node.has_binary_node())
- return false;
-
- proto::BinaryNode binary_node = node.binary_node();
- if (!binary_node.has_inequality_left_child_test())
- return false;
-
- if (!ValidateInequalityTest(binary_node.inequality_left_child_test()))
- return false;
-
- if (!binary_node.left_child_id().has_value())
- return false;
- if (!binary_node.right_child_id().has_value())
- return false;
-
- if (binary_node.left_child_id().value() >= tree.nodes_size())
- return false;
- if (binary_node.right_child_id().value() >= tree.nodes_size())
- return false;
-
- // Assure that no parent has an child index less than itself in order to
- // prevent loops.
- if (node_index >= binary_node.left_child_id().value())
- return false;
- if (node_index >= binary_node.right_child_id().value())
- return false;
-
- if (!ValidateTreeNode(tree, tree.nodes(binary_node.left_child_id().value()),
- binary_node.left_child_id().value())) {
- return false;
- }
- if (!ValidateTreeNode(tree, tree.nodes(binary_node.right_child_id().value()),
- binary_node.right_child_id().value())) {
- return false;
- }
- return true;
-}
-
-OptimizationTargetDecision DecisionTreePredictionModel::Predict(
- const base::flat_map<std::string, float>& model_features,
- double* prediction_score) {
- *prediction_score = 0.0;
- // TODO(mcrouse): Add metrics to record if the model evaluation fails.
- if (!EvaluateModel(model_, model_features, prediction_score))
- return OptimizationTargetDecision::kUnknown;
- if (*prediction_score > model_.threshold().value())
- return OptimizationTargetDecision::kPageLoadMatches;
- return OptimizationTargetDecision::kPageLoadDoesNotMatch;
-}
-
-bool DecisionTreePredictionModel::TraverseTree(
- const proto::DecisionTree& tree,
- const proto::TreeNode& node,
- const base::flat_map<std::string, float>& model_features,
- double* result) {
- if (node.has_leaf()) {
- *result = node.leaf().vector().value(0).double_value();
- return true;
- }
-
- proto::BinaryNode binary_node = node.binary_node();
- float threshold =
- binary_node.inequality_left_child_test().threshold().float_value();
- std::string feature_name =
- binary_node.inequality_left_child_test().feature_id().id().value();
- auto it = model_features.find(feature_name);
- if (it == model_features.end())
- return false;
- switch (binary_node.inequality_left_child_test().type()) {
- case proto::InequalityTest::LESS_OR_EQUAL:
- if (it->second <= threshold)
- return TraverseTree(tree,
- tree.nodes(binary_node.left_child_id().value()),
- model_features, result);
- return TraverseTree(tree,
- tree.nodes(binary_node.right_child_id().value()),
- model_features, result);
- case proto::InequalityTest::LESS_THAN:
- if (it->second < threshold)
- return TraverseTree(tree,
- tree.nodes(binary_node.left_child_id().value()),
- model_features, result);
- return TraverseTree(tree,
- tree.nodes(binary_node.right_child_id().value()),
- model_features, result);
- case proto::InequalityTest::GREATER_OR_EQUAL:
- if (it->second >= threshold)
- return TraverseTree(tree,
- tree.nodes(binary_node.left_child_id().value()),
- model_features, result);
- return TraverseTree(tree,
- tree.nodes(binary_node.right_child_id().value()),
- model_features, result);
- case proto::InequalityTest::GREATER_THAN:
- if (it->second > threshold)
- return TraverseTree(tree,
- tree.nodes(binary_node.left_child_id().value()),
- model_features, result);
- return TraverseTree(tree,
- tree.nodes(binary_node.right_child_id().value()),
- model_features, result);
- default:
- return false;
- }
-}
-
-bool DecisionTreePredictionModel::EvaluateDecisionTree(
- const proto::DecisionTree& tree,
- const base::flat_map<std::string, float>& model_features,
- double* result) {
- if (TraverseTree(tree, tree.nodes(0), model_features, result)) {
- *result *= tree.weight();
- return true;
- }
- return false;
-}
-
-bool DecisionTreePredictionModel::EvaluateEnsembleModel(
- const proto::Ensemble& ensemble,
- const base::flat_map<std::string, float>& model_features,
- double* result) {
- if (ensemble.members_size() == 0)
- return false;
-
- double score = 0.0;
- for (const auto& member : ensemble.members()) {
- if (!EvaluateModel(member.submodel(), model_features, &score)) {
- *result = 0.0;
- return false;
- }
-
- *result += score;
- }
- *result = *result / ensemble.members_size();
- return true;
-}
-
-bool DecisionTreePredictionModel::EvaluateModel(
- const proto::Model& model,
- const base::flat_map<std::string, float>& model_features,
- double* result) {
- DCHECK(result);
- // Clear the result value.
- *result = 0.0;
-
- if (model.has_ensemble()) {
- return EvaluateEnsembleModel(model.ensemble(), model_features, result);
- }
- if (model.has_decision_tree()) {
- return EvaluateDecisionTree(model.decision_tree(), model_features, result);
- }
- return false;
-}
-
-} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/decision_tree_prediction_model.h b/chromium/components/optimization_guide/core/decision_tree_prediction_model.h
deleted file mode 100644
index a31ef93d1cd..00000000000
--- a/chromium/components/optimization_guide/core/decision_tree_prediction_model.h
+++ /dev/null
@@ -1,97 +0,0 @@
-// Copyright 2020 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_DECISION_TREE_PREDICTION_MODEL_H_
-#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_DECISION_TREE_PREDICTION_MODEL_H_
-
-#include <memory>
-#include <string>
-
-#include "base/containers/flat_map.h"
-#include "components/optimization_guide/core/prediction_model.h"
-#include "components/optimization_guide/proto/models.pb.h"
-
-namespace optimization_guide {
-
-// A concrete PredictionModel capable of evaluating the decision tree model type
-// supported by the optimization guide.
-class DecisionTreePredictionModel : public PredictionModel {
- public:
- explicit DecisionTreePredictionModel(
- const proto::PredictionModel& prediction_model);
-
- DecisionTreePredictionModel(const DecisionTreePredictionModel&) = delete;
- DecisionTreePredictionModel& operator=(const DecisionTreePredictionModel&) =
- delete;
-
- ~DecisionTreePredictionModel() override;
-
- // PredictionModel implementation:
- OptimizationTargetDecision Predict(
- const base::flat_map<std::string, float>& model_features,
- double* prediction_score) override;
-
- private:
- // Evaluates the provided model, either an ensemble or decision tree model,
- // with the |model_features| and stores the output in |result|. Returns false
- // if evaluation fails.
- bool EvaluateModel(const proto::Model& model,
- const base::flat_map<std::string, float>& model_features,
- double* result);
-
- // Evaluates the decision tree model with the |model_features| and
- // stores the output in |result|. Returns false if the evaluation fails.
- bool EvaluateDecisionTree(
- const proto::DecisionTree& tree,
- const base::flat_map<std::string, float>& model_features,
- double* result);
-
- // Evaluates an ensemble model with the |model_features| and
- // stores the output in |result|. Returns false if the evaluation fails.
- bool EvaluateEnsembleModel(
- const proto::Ensemble& ensemble,
- const base::flat_map<std::string, float>& model_features,
- double* result);
-
- // Performs a depth first traversal the |tree| based on |model_features|
- // and stores the value of the leaf in |result|. Returns false if the
- // traversal or node evaluation fails.
- bool TraverseTree(const proto::DecisionTree& tree,
- const proto::TreeNode& node,
- const base::flat_map<std::string, float>& model_features,
- double* result);
-
- // PredictionModel implementation:
- bool ValidatePredictionModel() const override;
-
- // Validates a model or submodel of an ensemble. Returns
- // false if the model is invalid.
- bool ValidateModel(const proto::Model& model) const;
-
- // Validates an ensemble model. Returns false if the ensemble
- // if invalid.
- bool ValidateEnsembleModel(const proto::Ensemble& ensemble) const;
-
- // Validates a decision tree model. Returns false if the
- // decision tree model is invalid.
- bool ValidateDecisionTree(const proto::DecisionTree& tree) const;
-
- // Validates a leaf. Returns false if the leaf is invalid.
- bool ValidateLeaf(const proto::Leaf& leaf) const;
-
- // Validates an inequality test. Returns false if the
- // inequality test is invalid.
- bool ValidateInequalityTest(
- const proto::InequalityTest& inequality_test) const;
-
- // Validates each node of a decision tree by traversing every
- // node of the |tree|. Returns false if any part of the tree is invalid.
- bool ValidateTreeNode(const proto::DecisionTree& tree,
- const proto::TreeNode& node,
- int node_index) const;
-};
-
-} // namespace optimization_guide
-
-#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_DECISION_TREE_PREDICTION_MODEL_H_
diff --git a/chromium/components/optimization_guide/core/decision_tree_prediction_model_unittest.cc b/chromium/components/optimization_guide/core/decision_tree_prediction_model_unittest.cc
deleted file mode 100644
index 4a479d11038..00000000000
--- a/chromium/components/optimization_guide/core/decision_tree_prediction_model_unittest.cc
+++ /dev/null
@@ -1,434 +0,0 @@
-// Copyright 2020 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "components/optimization_guide/core/decision_tree_prediction_model.h"
-
-#include <memory>
-#include <utility>
-
-#include "base/containers/flat_map.h"
-#include "base/containers/flat_set.h"
-#include "components/optimization_guide/core/prediction_model.h"
-#include "components/optimization_guide/proto/models.pb.h"
-#include "testing/gtest/include/gtest/gtest.h"
-
-namespace optimization_guide {
-
-proto::PredictionModel GetValidDecisionTreePredictionModel() {
- proto::PredictionModel prediction_model;
- prediction_model.mutable_model()->mutable_threshold()->set_value(5.0);
-
- proto::DecisionTree decision_tree_model = proto::DecisionTree();
- decision_tree_model.set_weight(2.0);
-
- proto::TreeNode* tree_node = decision_tree_model.add_nodes();
- tree_node->mutable_node_id()->set_value(0);
- tree_node->mutable_binary_node()->mutable_left_child_id()->set_value(1);
- tree_node->mutable_binary_node()->mutable_right_child_id()->set_value(2);
- tree_node->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->mutable_feature_id()
- ->mutable_id()
- ->set_value("agg1");
- tree_node->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->set_type(proto::InequalityTest::LESS_OR_EQUAL);
- tree_node->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->mutable_threshold()
- ->set_float_value(1.0);
-
- tree_node = decision_tree_model.add_nodes();
- tree_node->mutable_node_id()->set_value(1);
- tree_node->mutable_leaf()->mutable_vector()->add_value()->set_double_value(
- 2.);
-
- tree_node = decision_tree_model.add_nodes();
- tree_node->mutable_node_id()->set_value(2);
- tree_node->mutable_leaf()->mutable_vector()->add_value()->set_double_value(
- 4.);
-
- *prediction_model.mutable_model()->mutable_decision_tree() =
- decision_tree_model;
- return prediction_model;
-}
-
-proto::PredictionModel GetValidEnsemblePredictionModel() {
- proto::PredictionModel prediction_model;
- prediction_model.mutable_model()->mutable_threshold()->set_value(5.0);
- proto::Ensemble ensemble = proto::Ensemble();
- *ensemble.add_members()->mutable_submodel() =
- *GetValidDecisionTreePredictionModel().mutable_model();
-
- *ensemble.add_members()->mutable_submodel() =
- *GetValidDecisionTreePredictionModel().mutable_model();
-
- *prediction_model.mutable_model()->mutable_ensemble() = ensemble;
- return prediction_model;
-}
-
-TEST(DecisionTreePredictionModel, ValidDecisionTreeModel) {
- proto::PredictionModel prediction_model =
- GetValidDecisionTreePredictionModel();
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(prediction_model);
- EXPECT_TRUE(model);
-
- double prediction_score;
- EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch,
- model->Predict({{"agg1", 1.0}}, &prediction_score));
- EXPECT_EQ(4., prediction_score);
- EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches,
- model->Predict({{"agg1", 2.0}}, &prediction_score));
- EXPECT_EQ(8., prediction_score);
-}
-
-TEST(DecisionTreePredictionModel, InequalityLessThan) {
- proto::PredictionModel prediction_model =
- GetValidDecisionTreePredictionModel();
-
- prediction_model.mutable_model()
- ->mutable_decision_tree()
- ->mutable_nodes(0)
- ->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->set_type(proto::InequalityTest::LESS_THAN);
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(std::move(prediction_model));
- EXPECT_TRUE(model);
-
- double prediction_score;
- EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch,
- model->Predict({{"agg1", 0.5}}, &prediction_score));
- EXPECT_EQ(4., prediction_score);
- EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches,
- model->Predict({{"agg1", 2.0}}, &prediction_score));
- EXPECT_EQ(8., prediction_score);
-}
-
-TEST(DecisionTreePredictionModel, InequalityGreaterOrEqual) {
- proto::PredictionModel prediction_model =
- GetValidDecisionTreePredictionModel();
-
- prediction_model.mutable_model()
- ->mutable_decision_tree()
- ->mutable_nodes(0)
- ->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->set_type(proto::InequalityTest::GREATER_OR_EQUAL);
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(prediction_model);
- EXPECT_TRUE(model);
-
- double prediction_score;
- EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches,
- model->Predict({{"agg1", 0.5}}, &prediction_score));
- EXPECT_EQ(8., prediction_score);
- EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch,
- model->Predict({{"agg1", 1.0}}, &prediction_score));
- EXPECT_EQ(4., prediction_score);
-}
-
-TEST(DecisionTreePredictionModel, InequalityGreaterThan) {
- proto::PredictionModel prediction_model =
- GetValidDecisionTreePredictionModel();
-
- prediction_model.mutable_model()
- ->mutable_decision_tree()
- ->mutable_nodes(0)
- ->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->set_type(proto::InequalityTest::GREATER_THAN);
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(std::move(prediction_model));
- EXPECT_TRUE(model);
-
- double prediction_score;
- EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches,
- model->Predict({{"agg1", 0.5}}, &prediction_score));
- EXPECT_EQ(8., prediction_score);
- EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch,
- model->Predict({{"agg1", 2.0}}, &prediction_score));
- EXPECT_EQ(4., prediction_score);
-}
-
-TEST(DecisionTreePredictionModel, MissingInequalityTest) {
- proto::PredictionModel prediction_model =
- GetValidDecisionTreePredictionModel();
-
- prediction_model.mutable_model()
- ->mutable_decision_tree()
- ->mutable_nodes(0)
- ->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->Clear();
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(std::move(prediction_model));
- EXPECT_FALSE(model);
-}
-
-TEST(DecisionTreePredictionModel, NoDecisionTreeThreshold) {
- proto::PredictionModel prediction_model =
- GetValidDecisionTreePredictionModel();
-
- prediction_model.mutable_model()->clear_threshold();
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(prediction_model);
- EXPECT_FALSE(model);
-}
-
-TEST(DecisionTreePredictionModel, EmptyTree) {
- proto::PredictionModel prediction_model =
- GetValidDecisionTreePredictionModel();
-
- prediction_model.mutable_model()->mutable_decision_tree()->clear_nodes();
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(std::move(prediction_model));
- EXPECT_FALSE(model);
-}
-
-TEST(DecisionTreePredictionModel, ModelFeatureNotInFeatureMap) {
- proto::PredictionModel prediction_model =
- GetValidDecisionTreePredictionModel();
-
- prediction_model.mutable_model()->mutable_decision_tree()->clear_nodes();
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(prediction_model);
- EXPECT_FALSE(model);
-}
-
-TEST(DecisionTreePredictionModel, DecisionTreeMissingLeaf) {
- proto::PredictionModel prediction_model =
- GetValidDecisionTreePredictionModel();
-
- prediction_model.mutable_model()
- ->mutable_decision_tree()
- ->mutable_nodes(1)
- ->mutable_leaf()
- ->Clear();
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(prediction_model);
- EXPECT_FALSE(model);
-}
-
-TEST(DecisionTreePredictionModel, DecisionTreeLeftChildIndexInvalid) {
- proto::PredictionModel prediction_model =
- GetValidDecisionTreePredictionModel();
-
- prediction_model.mutable_model()
- ->mutable_decision_tree()
- ->mutable_nodes(0)
- ->mutable_binary_node()
- ->mutable_left_child_id()
- ->set_value(3);
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(std::move(prediction_model));
- EXPECT_FALSE(model);
-}
-
-TEST(DecisionTreePredictionModel, DecisionTreeRightChildIndexInvalid) {
- proto::PredictionModel prediction_model =
- GetValidDecisionTreePredictionModel();
-
- prediction_model.mutable_model()
- ->mutable_decision_tree()
- ->mutable_nodes(0)
- ->mutable_binary_node()
- ->mutable_right_child_id()
- ->set_value(3);
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(prediction_model);
- EXPECT_FALSE(model);
-}
-
-TEST(DecisionTreePredictionModel, DecisionTreeWithLoopOnLeftChild) {
- proto::PredictionModel prediction_model =
- GetValidDecisionTreePredictionModel();
-
- proto::TreeNode* tree_node =
- prediction_model.mutable_model()->mutable_decision_tree()->mutable_nodes(
- 1);
-
- tree_node->mutable_node_id()->set_value(0);
- tree_node->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->mutable_feature_id()
- ->mutable_id()
- ->set_value("agg1");
- tree_node->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->set_type(proto::InequalityTest::LESS_OR_EQUAL);
- tree_node->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->mutable_threshold()
- ->set_float_value(1.0);
-
- tree_node->mutable_binary_node()->mutable_left_child_id()->set_value(0);
- tree_node->mutable_binary_node()->mutable_right_child_id()->set_value(2);
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(prediction_model);
- EXPECT_FALSE(model);
-}
-
-TEST(DecisionTreePredictionModel, DecisionTreeWithLoopOnRightChild) {
- proto::PredictionModel prediction_model =
- GetValidDecisionTreePredictionModel();
-
- proto::TreeNode* tree_node =
- prediction_model.mutable_model()->mutable_decision_tree()->mutable_nodes(
- 1);
-
- tree_node->mutable_node_id()->set_value(0);
- tree_node->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->mutable_feature_id()
- ->mutable_id()
- ->set_value("agg1");
- tree_node->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->set_type(proto::InequalityTest::LESS_OR_EQUAL);
- tree_node->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->mutable_threshold()
- ->set_float_value(1.0);
-
- tree_node->mutable_binary_node()->mutable_left_child_id()->set_value(2);
- tree_node->mutable_binary_node()->mutable_right_child_id()->set_value(0);
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(prediction_model);
- EXPECT_FALSE(model);
-}
-
-TEST(DecisionTreePredictionModel, ValidEnsembleModel) {
- proto::PredictionModel prediction_model = GetValidEnsemblePredictionModel();
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(prediction_model);
- EXPECT_TRUE(model);
-
- double prediction_score;
- EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch,
- model->Predict({{"agg1", 1.0}}, &prediction_score));
- EXPECT_EQ(4., prediction_score);
- EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches,
- model->Predict({{"agg1", 2.0}}, &prediction_score));
- EXPECT_EQ(8., prediction_score);
-}
-
-TEST(DecisionTreePredictionModel, EnsembleWithNoMembers) {
- proto::PredictionModel prediction_model = GetValidEnsemblePredictionModel();
- prediction_model.mutable_model()
- ->mutable_ensemble()
- ->mutable_members()
- ->Clear();
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(prediction_model);
- EXPECT_FALSE(model);
-}
-
-} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/entity_annotator_native_library.cc b/chromium/components/optimization_guide/core/entity_annotator_native_library.cc
new file mode 100644
index 00000000000..1382ef228fc
--- /dev/null
+++ b/chromium/components/optimization_guide/core/entity_annotator_native_library.cc
@@ -0,0 +1,445 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/optimization_guide/core/entity_annotator_native_library.h"
+
+#include "base/base_paths.h"
+#include "base/compiler_specific.h"
+#include "base/logging.h"
+#include "base/memory/ptr_util.h"
+#include "base/path_service.h"
+#include "build/build_config.h"
+#include "components/optimization_guide/core/model_util.h"
+#include "components/optimization_guide/core/optimization_guide_util.h"
+#include "components/optimization_guide/proto/page_entities_model_metadata.pb.h"
+
+#if BUILDFLAG(IS_MAC)
+#include "base/mac/bundle_locations.h"
+#include "base/mac/foundation_util.h"
+#endif
+
+// IMPORTANT: All functions in this file that call dlsym()'ed
+// functions should be annotated with DISABLE_CFI_ICALL.
+
+namespace optimization_guide {
+
+namespace {
+
+const char kModelMetadataBaseName[] = "model_metadata.pb";
+const char kWordEmbeddingsBaseName[] = "word_embeddings";
+const char kNameTableBaseName[] = "entities_names";
+const char kMetadataTableBaseName[] = "entities_metadata";
+const char kNameFilterBaseName[] = "entities_names_filter";
+const char kPrefixFilterBaseName[] = "entities_prefixes_filter";
+
+// Sets |field_to_set| with the full file path of |base_name|'s entry in
+// |base_to_full_file_path|. Returns whether |base_name| is in
+// |base_to_full_file_path|.
+absl::optional<std::string> GetFilePathFromMap(
+ const std::string& base_name,
+ const base::flat_map<std::string, base::FilePath>& base_to_full_file_path) {
+ auto it = base_to_full_file_path.find(base_name);
+ return it == base_to_full_file_path.end()
+ ? absl::nullopt
+ : absl::make_optional(FilePathToString(it->second));
+}
+
+// Returns the expected base name for |slice|. Will be of the form
+// |slice|-|base_name|.
+std::string GetSliceBaseName(const std::string& slice,
+ const std::string& base_name) {
+ return slice + "-" + base_name;
+}
+
+} // namespace
+
+EntityAnnotatorNativeLibrary::EntityAnnotatorNativeLibrary(
+ base::NativeLibrary native_library)
+ : native_library_(std::move(native_library)) {
+ LoadFunctions();
+}
+EntityAnnotatorNativeLibrary::~EntityAnnotatorNativeLibrary() = default;
+
+// static
+std::unique_ptr<EntityAnnotatorNativeLibrary>
+EntityAnnotatorNativeLibrary::Create() {
+ base::FilePath base_dir;
+#if BUILDFLAG(IS_MAC)
+ if (base::mac::AmIBundled()) {
+ base_dir = base::mac::FrameworkBundlePath().Append("Libraries");
+ } else {
+#endif
+ if (!base::PathService::Get(base::DIR_MODULE, &base_dir)) {
+ LOG(ERROR) << "Error getting app dir";
+ return nullptr;
+ }
+#if BUILDFLAG(IS_MAC)
+ }
+#endif
+
+ base::NativeLibraryLoadError error;
+ base::NativeLibrary native_library = base::LoadNativeLibrary(
+ base_dir.AppendASCII(
+ base::GetNativeLibraryName("optimization_guide_internal")),
+ &error);
+ if (!native_library) {
+ LOG(ERROR) << "Failed to initialize optimization guide internal: "
+ << error.ToString();
+ return nullptr;
+ }
+
+ std::unique_ptr<EntityAnnotatorNativeLibrary>
+ entity_annotator_native_library =
+ base::WrapUnique<EntityAnnotatorNativeLibrary>(
+ new EntityAnnotatorNativeLibrary(std::move(native_library)));
+ if (entity_annotator_native_library->IsValid()) {
+ return entity_annotator_native_library;
+ }
+ LOG(ERROR) << "Could not find all required functions for optimization guide "
+ "internal library";
+ return nullptr;
+}
+
+DISABLE_CFI_ICALL
+void EntityAnnotatorNativeLibrary::LoadFunctions() {
+ get_max_supported_feature_flag_func_ =
+ reinterpret_cast<GetMaxSupportedFeatureFlagFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityAnnotatorGetMaxSupportedFeatureFlag"));
+
+ create_from_options_func_ = reinterpret_cast<CreateFromOptionsFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityAnnotatorCreateFromOptions"));
+ get_creation_error_func_ = reinterpret_cast<GetCreationErrorFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_, "OptimizationGuideEntityAnnotatorGetCreationError"));
+ delete_func_ =
+ reinterpret_cast<DeleteFunc>(base::GetFunctionPointerFromNativeLibrary(
+ native_library_, "OptimizationGuideEntityAnnotatorDelete"));
+
+ annotate_job_create_func_ = reinterpret_cast<AnnotateJobCreateFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityAnnotatorAnnotateJobCreate"));
+ annotate_job_delete_func_ = reinterpret_cast<AnnotateJobDeleteFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityAnnotatorAnnotateJobDelete"));
+ run_annotate_job_func_ = reinterpret_cast<RunAnnotateJobFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_, "OptimizationGuideEntityAnnotatorRunAnnotateJob"));
+ annotate_get_output_metadata_at_index_func_ = reinterpret_cast<
+ AnnotateGetOutputMetadataAtIndexFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityAnnotatorAnnotateGetOutputMetadataAtIndex"));
+ annotate_get_output_metadata_score_at_index_func_ =
+ reinterpret_cast<AnnotateGetOutputMetadataScoreAtIndexFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityAnnotatorAnnotateGetOutputMetadataScoreAt"
+ "Index"));
+
+ entity_metadata_job_create_func_ =
+ reinterpret_cast<EntityMetadataJobCreateFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityAnnotatorEntityMetadataJobCreate"));
+ entity_metadata_job_delete_func_ =
+ reinterpret_cast<EntityMetadataJobDeleteFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityAnnotatorEntityMetadataJobDelete"));
+ run_entity_metadata_job_func_ = reinterpret_cast<RunEntityMetadataJobFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityAnnotatorRunEntityMetadataJob"));
+
+ options_create_func_ = reinterpret_cast<OptionsCreateFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_, "OptimizationGuideEntityAnnotatorOptionsCreate"));
+ options_set_model_file_path_func_ =
+ reinterpret_cast<OptionsSetModelFilePathFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityAnnotatorOptionsSetModelFilePath"));
+ options_set_model_metadata_file_path_func_ = reinterpret_cast<
+ OptionsSetModelMetadataFilePathFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityAnnotatorOptionsSetModelMetadataFilePath"));
+ options_set_word_embeddings_file_path_func_ = reinterpret_cast<
+ OptionsSetWordEmbeddingsFilePathFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityAnnotatorOptionsSetWordEmbeddingsFilePath"));
+ options_add_model_slice_func_ = reinterpret_cast<OptionsAddModelSliceFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityAnnotatorOptionsAddModelSlice"));
+ options_delete_func_ = reinterpret_cast<OptionsDeleteFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_, "OptimizationGuideEntityAnnotatorOptionsDelete"));
+
+ entity_metadata_get_entity_id_func_ =
+ reinterpret_cast<EntityMetadataGetEntityIdFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_, "OptimizationGuideEntityMetadataGetEntityID"));
+ entity_metadata_get_human_readable_name_func_ =
+ reinterpret_cast<EntityMetadataGetHumanReadableNameFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityMetadataGetHumanReadableName"));
+ entity_metadata_get_human_readable_categories_count_func_ = reinterpret_cast<
+ EntityMetadataGetHumanReadableCategoriesCountFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityMetadataGetHumanReadableCategoriesCount"));
+ entity_metadata_get_human_readable_category_name_at_index_func_ =
+ reinterpret_cast<EntityMetadataGetHumanReadableCategoryNameAtIndexFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityMetadataGetHumanReadableCategoryNameAtInd"
+ "ex"));
+ entity_metadata_get_human_readable_category_score_at_index_func_ =
+ reinterpret_cast<EntityMetadataGetHumanReadableCategoryScoreAtIndexFunc>(
+ base::GetFunctionPointerFromNativeLibrary(
+ native_library_,
+ "OptimizationGuideEntityMetadataGetHumanReadableCategoryScoreAtIn"
+ "dex"));
+}
+
+DISABLE_CFI_ICALL
+bool EntityAnnotatorNativeLibrary::IsValid() const {
+ return get_max_supported_feature_flag_func_ && create_from_options_func_ &&
+ get_creation_error_func_ && delete_func_ &&
+ annotate_job_create_func_ && annotate_job_delete_func_ &&
+ run_annotate_job_func_ &&
+ annotate_get_output_metadata_at_index_func_ &&
+ annotate_get_output_metadata_score_at_index_func_ &&
+ entity_metadata_job_create_func_ && entity_metadata_job_delete_func_ &&
+ run_entity_metadata_job_func_ && options_create_func_ &&
+ options_set_model_file_path_func_ &&
+ options_set_model_metadata_file_path_func_ &&
+ options_set_word_embeddings_file_path_func_ &&
+ options_add_model_slice_func_ && options_delete_func_ &&
+ entity_metadata_get_entity_id_func_ &&
+ entity_metadata_get_human_readable_name_func_ &&
+ entity_metadata_get_human_readable_categories_count_func_ &&
+ entity_metadata_get_human_readable_category_name_at_index_func_ &&
+ entity_metadata_get_human_readable_category_score_at_index_func_;
+}
+
+DISABLE_CFI_ICALL
+int32_t EntityAnnotatorNativeLibrary::GetMaxSupportedFeatureFlag() {
+ DCHECK(IsValid());
+ if (!IsValid()) {
+ return -1;
+ }
+
+ return get_max_supported_feature_flag_func_();
+}
+
+DISABLE_CFI_ICALL
+void* EntityAnnotatorNativeLibrary::CreateEntityAnnotator(
+ const ModelInfo& model_info) {
+ DCHECK(IsValid());
+ if (!IsValid()) {
+ return nullptr;
+ }
+
+ void* options = options_create_func_();
+ if (!PopulateEntityAnnotatorOptionsFromModelInfo(options, model_info)) {
+ options_delete_func_(options);
+ return nullptr;
+ }
+
+ void* entity_annotator = create_from_options_func_(options);
+ const char* creation_error = get_creation_error_func_(entity_annotator);
+ if (creation_error) {
+ LOG(ERROR) << "Failed to create entity annotator: " << creation_error;
+ DeleteEntityAnnotator(entity_annotator);
+ entity_annotator = nullptr;
+ }
+ options_delete_func_(options);
+ return entity_annotator;
+}
+
+DISABLE_CFI_ICALL
+bool EntityAnnotatorNativeLibrary::PopulateEntityAnnotatorOptionsFromModelInfo(
+ void* options,
+ const ModelInfo& model_info) {
+ // We don't know which files are intended for use if we don't have model
+ // metadata, so return early.
+ if (!model_info.GetModelMetadata()) {
+ return false;
+ }
+
+ // // Validate the model metadata.
+ absl::optional<proto::PageEntitiesModelMetadata> entities_model_metadata =
+ ParsedAnyMetadata<proto::PageEntitiesModelMetadata>(
+ model_info.GetModelMetadata().value());
+ if (!entities_model_metadata) {
+ return false;
+ }
+ if (entities_model_metadata->slice_size() == 0) {
+ return false;
+ }
+
+ // Build the entity annotator options.
+ options_set_model_file_path_func_(
+ options, FilePathToString(model_info.GetModelFilePath()).c_str());
+
+ // Attach the additional files required by the model.
+ base::flat_map<std::string, base::FilePath> base_to_full_file_path;
+ for (const auto& model_file : model_info.GetAdditionalFiles()) {
+ base_to_full_file_path.insert(
+ {FilePathToString(model_file.BaseName()), model_file});
+ }
+ absl::optional<std::string> model_metadata_file_path =
+ GetFilePathFromMap(kModelMetadataBaseName, base_to_full_file_path);
+ if (!model_metadata_file_path) {
+ return false;
+ }
+ options_set_model_metadata_file_path_func_(options,
+ model_metadata_file_path->c_str());
+ absl::optional<std::string> word_embeddings_file_path =
+ GetFilePathFromMap(kWordEmbeddingsBaseName, base_to_full_file_path);
+ if (!word_embeddings_file_path) {
+ return false;
+ }
+ options_set_word_embeddings_file_path_func_(
+ options, word_embeddings_file_path->c_str());
+
+ base::flat_set<std::string> slices(entities_model_metadata->slice().begin(),
+ entities_model_metadata->slice().end());
+ for (const auto& slice_id : slices) {
+ absl::optional<std::string> name_filter_path =
+ GetFilePathFromMap(GetSliceBaseName(slice_id, kNameFilterBaseName),
+ base_to_full_file_path);
+ if (!name_filter_path) {
+ return false;
+ }
+ absl::optional<std::string> name_table_path = GetFilePathFromMap(
+ GetSliceBaseName(slice_id, kNameTableBaseName), base_to_full_file_path);
+ if (!name_table_path) {
+ return false;
+ }
+ absl::optional<std::string> prefix_filter_path =
+ GetFilePathFromMap(GetSliceBaseName(slice_id, kPrefixFilterBaseName),
+ base_to_full_file_path);
+ if (!prefix_filter_path) {
+ return false;
+ }
+ absl::optional<std::string> metadata_table_path =
+ GetFilePathFromMap(GetSliceBaseName(slice_id, kMetadataTableBaseName),
+ base_to_full_file_path);
+ if (!metadata_table_path) {
+ return false;
+ }
+ options_add_model_slice_func_(
+ options, slice_id.c_str(), name_filter_path->c_str(),
+ name_table_path->c_str(), prefix_filter_path->c_str(),
+ metadata_table_path->c_str());
+ }
+
+ return true;
+}
+
+DISABLE_CFI_ICALL
+void EntityAnnotatorNativeLibrary::DeleteEntityAnnotator(
+ void* entity_annotator) {
+ DCHECK(IsValid());
+ if (!IsValid()) {
+ return;
+ }
+
+ delete_func_(reinterpret_cast<void*>(entity_annotator));
+}
+
+DISABLE_CFI_ICALL
+absl::optional<std::vector<ScoredEntityMetadata>>
+EntityAnnotatorNativeLibrary::AnnotateText(void* annotator,
+ const std::string& text) {
+ DCHECK(IsValid());
+ if (!IsValid()) {
+ return absl::nullopt;
+ }
+
+ if (!annotator) {
+ return absl::nullopt;
+ }
+
+ void* job = annotate_job_create_func_(reinterpret_cast<void*>(annotator));
+ int32_t output_metadata_count = run_annotate_job_func_(job, text.c_str());
+ if (output_metadata_count <= 0) {
+ return absl::nullopt;
+ }
+ std::vector<ScoredEntityMetadata> scored_md;
+ scored_md.reserve(output_metadata_count);
+ for (int32_t i = 0; i < output_metadata_count; i++) {
+ ScoredEntityMetadata md;
+ md.score = annotate_get_output_metadata_score_at_index_func_(job, i);
+ md.metadata = GetEntityMetadataFromOptimizationGuideEntityMetadata(
+ annotate_get_output_metadata_at_index_func_(job, i));
+ scored_md.emplace_back(md);
+ }
+ annotate_job_delete_func_(job);
+ return scored_md;
+}
+
+DISABLE_CFI_ICALL
+absl::optional<EntityMetadata>
+EntityAnnotatorNativeLibrary::GetEntityMetadataForEntityId(
+ void* annotator,
+ const std::string& entity_id) {
+ DCHECK(IsValid());
+ if (!IsValid()) {
+ return absl::nullopt;
+ }
+ if (!annotator) {
+ return absl::nullopt;
+ }
+
+ void* job =
+ entity_metadata_job_create_func_(reinterpret_cast<void*>(annotator));
+ const void* entity_metadata =
+ run_entity_metadata_job_func_(job, entity_id.c_str());
+ if (!entity_metadata) {
+ return absl::nullopt;
+ }
+ EntityMetadata md =
+ GetEntityMetadataFromOptimizationGuideEntityMetadata(entity_metadata);
+ entity_metadata_job_delete_func_(job);
+ return md;
+}
+
+DISABLE_CFI_ICALL
+EntityMetadata EntityAnnotatorNativeLibrary::
+ GetEntityMetadataFromOptimizationGuideEntityMetadata(
+ const void* og_entity_metadata) {
+ EntityMetadata entity_metadata;
+ entity_metadata.entity_id =
+ entity_metadata_get_entity_id_func_(og_entity_metadata);
+ entity_metadata.human_readable_name =
+ entity_metadata_get_human_readable_name_func_(og_entity_metadata);
+
+ int32_t human_readable_categories_count =
+ entity_metadata_get_human_readable_categories_count_func_(
+ og_entity_metadata);
+ for (int32_t i = 0; i < human_readable_categories_count; i++) {
+ std::string category_name =
+ entity_metadata_get_human_readable_category_name_at_index_func_(
+ og_entity_metadata, i);
+ float category_score =
+ entity_metadata_get_human_readable_category_score_at_index_func_(
+ og_entity_metadata, i);
+ entity_metadata.human_readable_categories[category_name] = category_score;
+ }
+ return entity_metadata;
+}
+
+} // namespace optimization_guide \ No newline at end of file
diff --git a/chromium/components/optimization_guide/core/entity_annotator_native_library.h b/chromium/components/optimization_guide/core/entity_annotator_native_library.h
new file mode 100644
index 00000000000..a1fc4973269
--- /dev/null
+++ b/chromium/components/optimization_guide/core/entity_annotator_native_library.h
@@ -0,0 +1,143 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_ENTITY_ANNOTATOR_NATIVE_LIBRARY_H_
+#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_ENTITY_ANNOTATOR_NATIVE_LIBRARY_H_
+
+#include <memory>
+#include <vector>
+
+#include "base/native_library.h"
+#include "components/optimization_guide/core/entity_metadata.h"
+#include "components/optimization_guide/core/model_info.h"
+
+namespace optimization_guide {
+
+// Handles interactions with the native library that contains logic for the
+// entity annotator.
+class EntityAnnotatorNativeLibrary {
+ public:
+ // Creates an EntityAnnotatorNativeLibrary, which loads a native library and
+ // relevant functions required. Will return nullptr if fails.
+ static std::unique_ptr<EntityAnnotatorNativeLibrary> Create();
+
+ EntityAnnotatorNativeLibrary(const EntityAnnotatorNativeLibrary&) = delete;
+ EntityAnnotatorNativeLibrary& operator=(const EntityAnnotatorNativeLibrary&) =
+ delete;
+ ~EntityAnnotatorNativeLibrary();
+
+ // Returns whether this instance is valid (i.e. all necessary functions have
+ // been loaded.)
+ bool IsValid() const;
+
+ // Gets the max supported feature from this native library.
+ int32_t GetMaxSupportedFeatureFlag();
+
+ // Creates an entity annotator from |model_info|.
+ void* CreateEntityAnnotator(const ModelInfo& model_info);
+
+ // Deletes |entity_annotator|.
+ void DeleteEntityAnnotator(void* entity_annotator);
+
+ // Uses |annotator| to annotate entities present in |text|.
+ absl::optional<std::vector<ScoredEntityMetadata>> AnnotateText(
+ void* annotator,
+ const std::string& text);
+
+ // Returns entity metadata from |annotator| for |entity_id|.
+ absl::optional<EntityMetadata> GetEntityMetadataForEntityId(
+ void* annotator,
+ const std::string& entity_id);
+
+ private:
+ EntityAnnotatorNativeLibrary(base::NativeLibrary native_library);
+
+ // Loads the functions exposed by the native library.
+ void LoadFunctions();
+
+ // Populates |options| based on |model_info|. Returns false if |model_info|
+ // cannot construct a valid options object.
+ bool PopulateEntityAnnotatorOptionsFromModelInfo(void* options,
+ const ModelInfo& model_info);
+
+ // Returns an entity metadata from the C-API representation.
+ EntityMetadata GetEntityMetadataFromOptimizationGuideEntityMetadata(
+ const void* og_entity_metadata);
+
+ base::NativeLibrary native_library_;
+
+ // Functions exposed by native library.
+ using GetMaxSupportedFeatureFlagFunc = int32_t (*)();
+ GetMaxSupportedFeatureFlagFunc get_max_supported_feature_flag_func_ = nullptr;
+
+ using CreateFromOptionsFunc = void* (*)(const void*);
+ CreateFromOptionsFunc create_from_options_func_ = nullptr;
+ using GetCreationErrorFunc = const char* (*)(const void*);
+ GetCreationErrorFunc get_creation_error_func_ = nullptr;
+ using DeleteFunc = void (*)(void*);
+ DeleteFunc delete_func_ = nullptr;
+
+ using AnnotateJobCreateFunc = void* (*)(void*);
+ AnnotateJobCreateFunc annotate_job_create_func_ = nullptr;
+ using AnnotateJobDeleteFunc = void (*)(void*);
+ AnnotateJobDeleteFunc annotate_job_delete_func_ = nullptr;
+ using RunAnnotateJobFunc = int32_t (*)(void*, const char*);
+ RunAnnotateJobFunc run_annotate_job_func_ = nullptr;
+ using AnnotateGetOutputMetadataAtIndexFunc = const void* (*)(void*, int32_t);
+ AnnotateGetOutputMetadataAtIndexFunc
+ annotate_get_output_metadata_at_index_func_ = nullptr;
+ using AnnotateGetOutputMetadataScoreAtIndexFunc = float (*)(void*, int32_t);
+ AnnotateGetOutputMetadataScoreAtIndexFunc
+ annotate_get_output_metadata_score_at_index_func_ = nullptr;
+
+ using EntityMetadataJobCreateFunc = void* (*)(void*);
+ EntityMetadataJobCreateFunc entity_metadata_job_create_func_ = nullptr;
+ using EntityMetadataJobDeleteFunc = void (*)(void*);
+ EntityMetadataJobDeleteFunc entity_metadata_job_delete_func_ = nullptr;
+ using RunEntityMetadataJobFunc = const void* (*)(void*, const char*);
+ RunEntityMetadataJobFunc run_entity_metadata_job_func_ = nullptr;
+
+ using OptionsCreateFunc = void* (*)();
+ OptionsCreateFunc options_create_func_ = nullptr;
+ using OptionsSetModelFilePathFunc = void (*)(void*, const char*);
+ OptionsSetModelFilePathFunc options_set_model_file_path_func_ = nullptr;
+ using OptionsSetModelMetadataFilePathFunc = void (*)(void*, const char*);
+ OptionsSetModelMetadataFilePathFunc
+ options_set_model_metadata_file_path_func_ = nullptr;
+ using OptionsSetWordEmbeddingsFilePathFunc = void (*)(void*, const char*);
+ OptionsSetWordEmbeddingsFilePathFunc
+ options_set_word_embeddings_file_path_func_ = nullptr;
+ using OptionsAddModelSliceFunc = void (*)(void*,
+ const char*,
+ const char*,
+ const char*,
+ const char*,
+ const char*);
+ OptionsAddModelSliceFunc options_add_model_slice_func_ = nullptr;
+ using OptionsDeleteFunc = void (*)(void*);
+ OptionsDeleteFunc options_delete_func_ = nullptr;
+
+ using EntityMetadataGetEntityIdFunc = const char* (*)(const void*);
+ EntityMetadataGetEntityIdFunc entity_metadata_get_entity_id_func_ = nullptr;
+ using EntityMetadataGetHumanReadableNameFunc = const char* (*)(const void*);
+ EntityMetadataGetHumanReadableNameFunc
+ entity_metadata_get_human_readable_name_func_ = nullptr;
+ using EntityMetadataGetHumanReadableCategoriesCountFunc =
+ int32_t (*)(const void*);
+ EntityMetadataGetHumanReadableCategoriesCountFunc
+ entity_metadata_get_human_readable_categories_count_func_ = nullptr;
+ using EntityMetadataGetHumanReadableCategoryNameAtIndexFunc =
+ const char* (*)(const void*, int32_t);
+ EntityMetadataGetHumanReadableCategoryNameAtIndexFunc
+ entity_metadata_get_human_readable_category_name_at_index_func_ = nullptr;
+ using EntityMetadataGetHumanReadableCategoryScoreAtIndexFunc =
+ float (*)(const void*, int32_t);
+ EntityMetadataGetHumanReadableCategoryScoreAtIndexFunc
+ entity_metadata_get_human_readable_category_score_at_index_func_ =
+ nullptr;
+};
+
+} // namespace optimization_guide
+
+#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_ENTITY_ANNOTATOR_NATIVE_LIBRARY_H_ \ No newline at end of file
diff --git a/chromium/components/optimization_guide/core/entity_annotator_native_library_unittest.cc b/chromium/components/optimization_guide/core/entity_annotator_native_library_unittest.cc
new file mode 100644
index 00000000000..3155e9bf10c
--- /dev/null
+++ b/chromium/components/optimization_guide/core/entity_annotator_native_library_unittest.cc
@@ -0,0 +1,22 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/optimization_guide/core/entity_annotator_native_library.h"
+
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace optimization_guide {
+namespace {
+
+using EntityAnnotatorNativeLibraryTest = ::testing::Test;
+
+TEST_F(EntityAnnotatorNativeLibraryTest, CanCreateValidLibrary) {
+ std::unique_ptr<EntityAnnotatorNativeLibrary> lib =
+ EntityAnnotatorNativeLibrary::Create();
+ ASSERT_TRUE(lib);
+ EXPECT_TRUE(lib->IsValid());
+}
+
+} // namespace
+} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/entity_metadata.cc b/chromium/components/optimization_guide/core/entity_metadata.cc
index 6ca1930d1fd..2f8301579f3 100644
--- a/chromium/components/optimization_guide/core/entity_metadata.cc
+++ b/chromium/components/optimization_guide/core/entity_metadata.cc
@@ -4,6 +4,7 @@
#include "components/optimization_guide/core/entity_metadata.h"
+#include <ostream>
#include <string>
#include <vector>
diff --git a/chromium/components/optimization_guide/core/entity_metadata_provider.h b/chromium/components/optimization_guide/core/entity_metadata_provider.h
index 54ec81d7e01..14f7def09a8 100644
--- a/chromium/components/optimization_guide/core/entity_metadata_provider.h
+++ b/chromium/components/optimization_guide/core/entity_metadata_provider.h
@@ -6,6 +6,7 @@
#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_ENTITY_METADATA_PROVIDER_H_
#include "components/optimization_guide/core/entity_metadata.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
namespace optimization_guide {
diff --git a/chromium/components/optimization_guide/core/hints_fetcher.cc b/chromium/components/optimization_guide/core/hints_fetcher.cc
index 3909ae5c3f8..7d58512dc42 100644
--- a/chromium/components/optimization_guide/core/hints_fetcher.cc
+++ b/chromium/components/optimization_guide/core/hints_fetcher.cc
@@ -11,9 +11,11 @@
#include "base/feature_list.h"
#include "base/metrics/histogram_functions.h"
#include "base/metrics/histogram_macros.h"
+#include "base/strings/strcat.h"
#include "base/time/default_clock.h"
#include "components/optimization_guide/core/hints_processing_util.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
+#include "components/optimization_guide/core/optimization_guide_logger.h"
#include "components/optimization_guide/core/optimization_guide_prefs.h"
#include "components/optimization_guide/core/optimization_guide_switches.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
@@ -27,7 +29,6 @@
#include "net/http/http_response_headers.h"
#include "net/http/http_status_code.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
-#include "services/network/public/cpp/network_connection_tracker.h"
#include "services/network/public/cpp/resource_request.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "services/network/public/cpp/simple_url_loader.h"
@@ -62,23 +63,9 @@ std::string GetStringNameForRequestContext(
return std::string();
}
-// Returns the subset of URLs from |urls| for which the URL is considered
-// valid and can be included in a hints fetch.
-std::vector<GURL> GetValidURLsForFetching(const std::vector<GURL>& urls) {
- std::vector<GURL> valid_urls;
- for (const auto& url : urls) {
- if (valid_urls.size() >=
- features::MaxUrlsForOptimizationGuideServiceHintsFetch()) {
- break;
- }
- if (IsValidURLForURLKeyedHint(url))
- valid_urls.push_back(url);
- }
- return valid_urls;
-}
-
void RecordRequestStatusHistogram(proto::RequestContext request_context,
HintsFetcherRequestStatus status) {
+ DCHECK_NE(status, HintsFetcherRequestStatus::kDeprecatedNetworkOffline);
base::UmaHistogramEnumeration(
"OptimizationGuide.HintsFetcher.RequestStatus." +
GetStringNameForRequestContext(request_context),
@@ -91,14 +78,14 @@ HintsFetcher::HintsFetcher(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
const GURL& optimization_guide_service_url,
PrefService* pref_service,
- network::NetworkConnectionTracker* network_connection_tracker)
+ OptimizationGuideLogger* optimization_guide_logger)
: optimization_guide_service_url_(net::AppendOrReplaceQueryParameter(
optimization_guide_service_url,
"key",
features::GetOptimizationGuideServiceAPIKey())),
pref_service_(pref_service),
- network_connection_tracker_(network_connection_tracker),
- time_clock_(base::DefaultClock::GetInstance()) {
+ time_clock_(base::DefaultClock::GetInstance()),
+ optimization_guide_logger_(optimization_guide_logger) {
url_loader_factory_ = std::move(url_loader_factory);
// Allow non-https scheme only when it is overridden in command line. This is
// needed for iOS EG2 tests which don't support HTTPS embedded test servers
@@ -190,16 +177,13 @@ bool HintsFetcher::FetchOptimizationGuideServiceHints(
HintsFetchedCallback hints_fetched_callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK_GT(optimization_types.size(), 0u);
-
- if (network_connection_tracker_->IsOffline()) {
- RecordRequestStatusHistogram(request_context,
- HintsFetcherRequestStatus::kNetworkOffline);
- std::move(hints_fetched_callback).Run(absl::nullopt);
- return false;
- }
+ request_context_ = request_context;
if (active_url_loader_) {
- RecordRequestStatusHistogram(request_context,
+ OPTIMIZATION_GUIDE_LOG(
+ optimization_guide_logger_,
+ "No hints fetched: HintsFetcher busy in another fetch");
+ RecordRequestStatusHistogram(request_context_,
HintsFetcherRequestStatus::kFetcherBusy);
std::move(hints_fetched_callback).Run(absl::nullopt);
return false;
@@ -207,10 +191,12 @@ bool HintsFetcher::FetchOptimizationGuideServiceHints(
std::vector<std::string> filtered_hosts =
GetSizeLimitedHostsDueForHintsRefresh(hosts);
- std::vector<GURL> valid_urls = GetValidURLsForFetching(urls);
+ std::vector<GURL> valid_urls = GetSizeLimitedURLsForFetching(urls);
if (filtered_hosts.empty() && valid_urls.empty()) {
+ OPTIMIZATION_GUIDE_LOG(optimization_guide_logger_,
+ "No hints fetched: No hosts/URLs");
RecordRequestStatusHistogram(
- request_context, HintsFetcherRequestStatus::kNoHostsOrURLsToFetch);
+ request_context_, HintsFetcherRequestStatus::kNoHostsOrURLsToFetch);
std::move(hints_fetched_callback).Run(absl::nullopt);
return false;
}
@@ -221,15 +207,16 @@ bool HintsFetcher::FetchOptimizationGuideServiceHints(
valid_urls.size());
if (optimization_types.empty()) {
+ OPTIMIZATION_GUIDE_LOG(optimization_guide_logger_,
+ "No hints fetched: No supported optimization types");
RecordRequestStatusHistogram(
- request_context,
+ request_context_,
HintsFetcherRequestStatus::kNoSupportedOptimizationTypes);
std::move(hints_fetched_callback).Run(absl::nullopt);
return false;
}
hints_fetch_start_time_ = base::TimeTicks::Now();
- request_context_ = request_context;
proto::GetHintsRequest get_hints_request;
get_hints_request.add_supported_key_representations(proto::HOST);
@@ -445,6 +432,38 @@ void HintsFetcher::OnURLLoadComplete(
HandleResponse(response_body ? *response_body : "", net_error, response_code);
}
+// Returns the subset of URLs from |urls| for which the URL is considered
+// valid and can be included in a hints fetch.
+std::vector<GURL> HintsFetcher::GetSizeLimitedURLsForFetching(
+ const std::vector<GURL>& urls) const {
+ std::vector<GURL> valid_urls;
+ for (size_t i = 0; i < urls.size(); i++) {
+ if (valid_urls.size() >=
+ features::MaxUrlsForOptimizationGuideServiceHintsFetch()) {
+ base::UmaHistogramCounts100(
+ "OptimizationGuide.HintsFetcher.GetHintsRequest.DroppedUrls." +
+ GetStringNameForRequestContext(request_context_),
+ urls.size() - i);
+ OPTIMIZATION_GUIDE_LOG(
+ optimization_guide_logger_,
+ base::StrCat({"Skipped adding URL due to limit, context:",
+ GetStringNameForRequestContext(request_context_),
+ " URL:", urls[i].possibly_invalid_spec()}));
+ break;
+ }
+ if (IsValidURLForURLKeyedHint(urls[i])) {
+ valid_urls.push_back(urls[i]);
+ } else {
+ OPTIMIZATION_GUIDE_LOG(
+ optimization_guide_logger_,
+ base::StrCat({"Skipped adding invalid URL, context:",
+ GetStringNameForRequestContext(request_context_),
+ " URL:", urls[i].possibly_invalid_spec()}));
+ }
+ }
+ return valid_urls;
+}
+
std::vector<std::string> HintsFetcher::GetSizeLimitedHostsDueForHintsRefresh(
const std::vector<std::string>& hosts) const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
@@ -455,7 +474,17 @@ std::vector<std::string> HintsFetcher::GetSizeLimitedHostsDueForHintsRefresh(
std::vector<std::string> target_hosts;
target_hosts.reserve(hosts.size());
- for (const auto& host : hosts) {
+ for (size_t i = 0; i < hosts.size(); i++) {
+ if (target_hosts.size() >=
+ features::MaxHostsForOptimizationGuideServiceHintsFetch()) {
+ base::UmaHistogramCounts100(
+ "OptimizationGuide.HintsFetcher.GetHintsRequest.DroppedHosts." +
+ GetStringNameForRequestContext(request_context_),
+ hosts.size() - i);
+ break;
+ }
+
+ std::string host = hosts[i];
// Skip over localhosts, IP addresses, and invalid hosts.
if (net::HostStringIsLocalhost(host))
continue;
@@ -463,6 +492,9 @@ std::vector<std::string> HintsFetcher::GetSizeLimitedHostsDueForHintsRefresh(
std::string canonicalized_host(net::CanonicalizeHost(host, &host_info));
if (host_info.IsIPAddress() ||
!net::IsCanonicalizedHostCompliant(canonicalized_host)) {
+ OPTIMIZATION_GUIDE_LOG(
+ optimization_guide_logger_,
+ base::StrCat({"Skipped adding invalid host:", host}));
continue;
}
@@ -477,12 +509,12 @@ std::vector<std::string> HintsFetcher::GetSizeLimitedHostsDueForHintsRefresh(
(host_valid_time - features::GetHostHintsFetchRefreshDuration() <=
time_clock_->Now());
}
- if (host_hints_due_for_refresh)
+ if (host_hints_due_for_refresh) {
target_hosts.push_back(host);
-
- if (target_hosts.size() >=
- features::MaxHostsForOptimizationGuideServiceHintsFetch()) {
- break;
+ } else {
+ OPTIMIZATION_GUIDE_LOG(
+ optimization_guide_logger_,
+ base::StrCat({"Skipped refreshing hints for host:", host}));
}
}
DCHECK_GE(features::MaxHostsForOptimizationGuideServiceHintsFetch(),
diff --git a/chromium/components/optimization_guide/core/hints_fetcher.h b/chromium/components/optimization_guide/core/hints_fetcher.h
index 95c2000513c..0ef0a13dea7 100644
--- a/chromium/components/optimization_guide/core/hints_fetcher.h
+++ b/chromium/components/optimization_guide/core/hints_fetcher.h
@@ -20,10 +20,10 @@
#include "third_party/abseil-cpp/absl/types/optional.h"
#include "url/gurl.h"
+class OptimizationGuideLogger;
class PrefService;
namespace network {
-class NetworkConnectionTracker;
class SharedURLLoaderFactory;
class SimpleURLLoader;
} // namespace network
@@ -41,8 +41,8 @@ enum class HintsFetcherRequestStatus {
kSuccess,
// Fetch request was sent but no response received.
kResponseError,
- // Fetch request not sent because of offline network status.
- kNetworkOffline,
+ // DEPRECATED: Fetch request not sent because of offline network status.
+ kDeprecatedNetworkOffline,
// Fetch request not sent because fetcher was busy with another request.
kFetcherBusy,
// Fetch request not sent because the host and URL lists were empty.
@@ -71,7 +71,7 @@ class HintsFetcher {
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
const GURL& optimization_guide_service_url,
PrefService* pref_service,
- network::NetworkConnectionTracker* network_connection_tracker);
+ OptimizationGuideLogger* optimization_guide_logger);
HintsFetcher(const HintsFetcher&) = delete;
HintsFetcher& operator=(const HintsFetcher&) = delete;
@@ -142,6 +142,11 @@ class HintsFetcher {
// in the pref.
void UpdateHostsSuccessfullyFetched(base::TimeDelta valid_duration);
+ // Returns the subset of URLs from |urls| for which the URL is considered
+ // valid and can be included in a hints fetch.
+ std::vector<GURL> GetSizeLimitedURLsForFetching(
+ const std::vector<GURL>& urls) const;
+
// Returns the subset of hosts from |hosts| for which the hints should be
// refreshed. The count of returned hosts is limited to
// features::MaxHostsForOptimizationGuideServiceHintsFetch().
@@ -165,10 +170,6 @@ class HintsFetcher {
// A reference to the PrefService for this profile. Not owned.
raw_ptr<PrefService> pref_service_ = nullptr;
- // Listens to changes around the network connection. Not owned. Guaranteed to
- // outlive |this|.
- raw_ptr<network::NetworkConnectionTracker> network_connection_tracker_;
-
// Holds the hosts being requested by the hints fetcher.
std::vector<std::string> hosts_fetched_;
@@ -182,6 +183,9 @@ class HintsFetcher {
// retrieving hints from the remote Optimization Guide Service.
base::TimeTicks hints_fetch_start_time_;
+ // Owned by OptimizationGuideKeyedService and outlives |this|.
+ raw_ptr<OptimizationGuideLogger> optimization_guide_logger_;
+
SEQUENCE_CHECKER(sequence_checker_);
};
diff --git a/chromium/components/optimization_guide/core/hints_fetcher_factory.cc b/chromium/components/optimization_guide/core/hints_fetcher_factory.cc
index cd63684f63a..b6e454fa336 100644
--- a/chromium/components/optimization_guide/core/hints_fetcher_factory.cc
+++ b/chromium/components/optimization_guide/core/hints_fetcher_factory.cc
@@ -5,7 +5,6 @@
#include "components/optimization_guide/core/hints_fetcher.h"
#include "components/prefs/pref_service.h"
-#include "services/network/public/cpp/network_connection_tracker.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
namespace optimization_guide {
@@ -13,19 +12,18 @@ namespace optimization_guide {
HintsFetcherFactory::HintsFetcherFactory(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
const GURL& optimization_guide_service_url,
- PrefService* pref_service,
- network::NetworkConnectionTracker* network_connection_tracker)
+ PrefService* pref_service)
: url_loader_factory_(url_loader_factory),
optimization_guide_service_url_(optimization_guide_service_url),
- pref_service_(pref_service),
- network_connection_tracker_(network_connection_tracker) {}
+ pref_service_(pref_service) {}
HintsFetcherFactory::~HintsFetcherFactory() = default;
-std::unique_ptr<HintsFetcher> HintsFetcherFactory::BuildInstance() {
+std::unique_ptr<HintsFetcher> HintsFetcherFactory::BuildInstance(
+ OptimizationGuideLogger* optimization_guide_logger) {
return std::make_unique<HintsFetcher>(
url_loader_factory_, optimization_guide_service_url_, pref_service_,
- network_connection_tracker_);
+ optimization_guide_logger);
}
void HintsFetcherFactory::OverrideOptimizationGuideServiceUrlForTesting(
diff --git a/chromium/components/optimization_guide/core/hints_fetcher_factory.h b/chromium/components/optimization_guide/core/hints_fetcher_factory.h
index 67eb35c4c97..f8c8a7ffd3c 100644
--- a/chromium/components/optimization_guide/core/hints_fetcher_factory.h
+++ b/chromium/components/optimization_guide/core/hints_fetcher_factory.h
@@ -11,10 +11,10 @@
#include "base/memory/scoped_refptr.h"
#include "url/gurl.h"
+class OptimizationGuideLogger;
class PrefService;
namespace network {
-class NetworkConnectionTracker;
class SharedURLLoaderFactory;
} // namespace network
@@ -29,15 +29,15 @@ class HintsFetcherFactory {
HintsFetcherFactory(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
const GURL& optimization_guide_service_url,
- PrefService* pref_service,
- network::NetworkConnectionTracker* network_connection_tracker);
+ PrefService* pref_service);
HintsFetcherFactory(const HintsFetcherFactory&) = delete;
HintsFetcherFactory& operator=(const HintsFetcherFactory&) = delete;
virtual ~HintsFetcherFactory();
// Creates a new instance of HintsFetcher. Virtualized for testing so that the
// testing code can override this to provide a mocked instance.
- virtual std::unique_ptr<HintsFetcher> BuildInstance();
+ virtual std::unique_ptr<HintsFetcher> BuildInstance(
+ OptimizationGuideLogger* optimization_guide_logger);
// Override the optimization guide hints server URL. Used for testing.
void OverrideOptimizationGuideServiceUrlForTesting(
@@ -53,10 +53,6 @@ class HintsFetcherFactory {
// A reference to the PrefService for this profile. Not owned.
raw_ptr<PrefService> pref_service_ = nullptr;
-
- // A reference to the object that listens for changes in network connection.
- // Not owned. Guaranteed to outlive |this|.
- raw_ptr<network::NetworkConnectionTracker> network_connection_tracker_;
};
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/hints_fetcher_unittest.cc b/chromium/components/optimization_guide/core/hints_fetcher_unittest.cc
index 58e9c1470e2..b7c2c5e647b 100644
--- a/chromium/components/optimization_guide/core/hints_fetcher_unittest.cc
+++ b/chromium/components/optimization_guide/core/hints_fetcher_unittest.cc
@@ -25,7 +25,6 @@
#include "net/base/url_util.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
-#include "services/network/test/test_network_connection_tracker.h"
#include "services/network/test/test_url_loader_factory.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
@@ -56,8 +55,7 @@ class HintsFetcherTest : public testing::Test,
hints_fetcher_ = std::make_unique<HintsFetcher>(
shared_url_loader_factory_, GURL(optimization_guide_service_url),
- pref_service_.get(),
- network::TestNetworkConnectionTracker::GetInstance());
+ pref_service_.get(), /*optimization_guide_logger=*/nullptr);
hints_fetcher_->SetTimeClockForTesting(task_environment_.GetMockClock());
}
@@ -72,19 +70,7 @@ class HintsFetcherTest : public testing::Test,
hints_fetched_ = true;
}
- bool hints_fetched() { return hints_fetched_; }
-
- void SetConnectionOffline() {
- network_tracker_ = network::TestNetworkConnectionTracker::GetInstance();
- network_tracker_->SetConnectionType(
- network::mojom::ConnectionType::CONNECTION_NONE);
- }
-
- void SetConnectionOnline() {
- network_tracker_ = network::TestNetworkConnectionTracker::GetInstance();
- network_tracker_->SetConnectionType(
- network::mojom::ConnectionType::CONNECTION_4G);
- }
+ bool hints_fetched() const { return hints_fetched_; }
// Updates the pref so that hints for each of the host in |hosts| are set to
// expire at |host_invalid_time|.
@@ -177,7 +163,6 @@ class HintsFetcherTest : public testing::Test,
std::unique_ptr<TestingPrefServiceSimple> pref_service_;
scoped_refptr<network::SharedURLLoaderFactory> shared_url_loader_factory_;
network::TestURLLoaderFactory test_url_loader_factory_;
- raw_ptr<network::TestNetworkConnectionTracker> network_tracker_;
std::string last_request_body_;
};
@@ -355,35 +340,6 @@ TEST_P(HintsFetcherTest, FetchReturnBadResponse) {
HintsFetcherRequestStatus::kResponseError, 1);
}
-TEST_P(HintsFetcherTest, FetchAttemptWhenNetworkOffline) {
- base::HistogramTester histogram_tester;
-
- SetConnectionOffline();
- std::string response_content;
- EXPECT_FALSE(FetchHints({"foo.com"}, {} /* urls */));
- EXPECT_FALSE(hints_fetched());
-
- // Make sure histograms are recorded correctly on bad response.
- histogram_tester.ExpectTotalCount(
- "OptimizationGuide.HintsFetcher.GetHintsRequest.FetchLatency", 0);
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.HintsFetcher.RequestStatus.BatchUpdateActiveTabs",
- HintsFetcherRequestStatus::kNetworkOffline, 1);
-
- SetConnectionOnline();
- EXPECT_TRUE(FetchHints({"foo.com"}, {} /* urls */));
- VerifyHasPendingFetchRequests();
- EXPECT_TRUE(SimulateResponse(response_content, net::HTTP_OK));
- EXPECT_TRUE(hints_fetched());
-
- histogram_tester.ExpectTotalCount(
- "OptimizationGuide.HintsFetcher.GetHintsRequest.FetchLatency", 1);
- histogram_tester.ExpectTotalCount(
- "OptimizationGuide.HintsFetcher.GetHintsRequest.FetchLatency."
- "BatchUpdateActiveTabs",
- 1);
-}
-
TEST_P(HintsFetcherTest, HintsFetchSuccessfulHostsRecorded) {
std::vector<std::string> hosts{"host1.com", "host2.com"};
std::string response_content;
@@ -606,6 +562,8 @@ TEST_P(HintsFetcherTest, HintsFetcherSuccessfullyFetchedHostsFull) {
}
TEST_P(HintsFetcherTest, MaxHostsForOptimizationGuideServiceHintsFetch) {
+ base::HistogramTester histogram_tester;
+
std::string response_content;
std::vector<std::string> all_hosts;
@@ -640,6 +598,12 @@ TEST_P(HintsFetcherTest, MaxHostsForOptimizationGuideServiceHintsFetch) {
EXPECT_TRUE(
WasHostCoveredByFetch("host" + base::NumberToString(i) + ".com"));
}
+
+ // extra1.com and extra2.com should have been considered "dropped".
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.HintsFetcher.GetHintsRequest.DroppedHosts."
+ "BatchUpdateActiveTabs",
+ 2, 1);
}
TEST_P(HintsFetcherTest, MaxUrlsForOptimizationGuideServiceHintsFetch) {
@@ -677,6 +641,12 @@ TEST_P(HintsFetcherTest, MaxUrlsForOptimizationGuideServiceHintsFetch) {
EXPECT_EQ(last_request.urls(i).url(),
"https://url" + base::NumberToString(i) + ".com/");
}
+
+ // notfetched.com and notfetched-2.com should have been considered "dropped".
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.HintsFetcher.GetHintsRequest.DroppedUrls."
+ "BatchUpdateActiveTabs",
+ 2, 1);
}
TEST_P(HintsFetcherTest, OnlyURLsToFetch) {
@@ -697,6 +667,11 @@ TEST_P(HintsFetcherTest, OnlyURLsToFetch) {
histogram_tester.ExpectUniqueSample(
"OptimizationGuide.HintsFetcher.RequestStatus.BatchUpdateActiveTabs",
static_cast<int>(HintsFetcherRequestStatus::kSuccess), 1);
+ // Nothing was dropped so this shouldn't be recorded.
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.HintsFetcher.GetHintsRequest.DroppedHosts", 0);
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.HintsFetcher.GetHintsRequest.DroppedUrls", 0);
}
TEST_P(HintsFetcherTest, NoHostsOrURLsToFetch) {
diff --git a/chromium/components/optimization_guide/core/hints_manager.cc b/chromium/components/optimization_guide/core/hints_manager.cc
index 7ae4f3a4fe0..3270962e174 100644
--- a/chromium/components/optimization_guide/core/hints_manager.cc
+++ b/chromium/components/optimization_guide/core/hints_manager.cc
@@ -17,6 +17,8 @@
#include "base/metrics/histogram_macros_local.h"
#include "base/notreached.h"
#include "base/rand_util.h"
+#include "base/strings/strcat.h"
+#include "base/strings/string_number_conversions.h"
#include "base/task/post_task.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/task_runner_util.h"
@@ -33,6 +35,7 @@
#include "components/optimization_guide/core/optimization_guide_constants.h"
#include "components/optimization_guide/core/optimization_guide_enums.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
+#include "components/optimization_guide/core/optimization_guide_logger.h"
#include "components/optimization_guide/core/optimization_guide_navigation_data.h"
#include "components/optimization_guide/core/optimization_guide_permissions_util.h"
#include "components/optimization_guide/core/optimization_guide_prefs.h"
@@ -50,7 +53,6 @@
#include "services/metrics/public/cpp/ukm_recorder.h"
#include "services/metrics/public/cpp/ukm_source.h"
#include "services/metrics/public/cpp/ukm_source_id.h"
-#include "services/network/public/cpp/network_connection_tracker.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
namespace optimization_guide {
@@ -204,23 +206,31 @@ bool ShouldIgnoreNewlyRegisteredOptimizationType(
class ScopedCanApplyOptimizationLogger {
public:
- ScopedCanApplyOptimizationLogger(proto::OptimizationType opt_type, GURL url)
+ ScopedCanApplyOptimizationLogger(
+ proto::OptimizationType opt_type,
+ GURL url,
+ OptimizationGuideLogger* optimization_guide_logger)
: decision_(OptimizationGuideDecision::kUnknown),
type_decision_(OptimizationTypeDecision::kUnknown),
opt_type_(opt_type),
has_metadata_(false),
- url_(url) {}
+ url_(url),
+ optimization_guide_logger_(optimization_guide_logger) {}
~ScopedCanApplyOptimizationLogger() {
if (!switches::IsDebugLogsEnabled())
return;
DCHECK_NE(type_decision_, OptimizationTypeDecision::kUnknown);
- DVLOG(0) << "OptimizationGuide: CanApplyOptimization: "
- << GetStringNameForOptimizationType(opt_type_)
- << "\nqueried on: " << url_ << "\nDecision: "
- << GetStringForOptimizationGuideDecision(decision_)
- << "\nTypeDecision: " << static_cast<int>(type_decision_)
- << "\nHas Metadata: " << has_metadata_;
+ OPTIMIZATION_GUIDE_LOG(
+ optimization_guide_logger_,
+ base::StrCat(
+ {"OptimizationGuide: CanApplyOptimization: ",
+ GetStringNameForOptimizationType(opt_type_),
+ "\nqueried on: ", url_.possibly_invalid_spec(),
+ "\nDecision: ", GetStringForOptimizationGuideDecision(decision_),
+ "\nTypeDecision: ",
+ base::NumberToString(static_cast<int>(type_decision_)),
+ "\nHas Metadata: ", (has_metadata_ ? "True" : "False")}));
}
void set_has_metadata() { has_metadata_ = true; }
@@ -238,6 +248,9 @@ class ScopedCanApplyOptimizationLogger {
proto::OptimizationType opt_type_;
bool has_metadata_;
GURL url_;
+
+ // Not owned. Guaranteed to outlive |this| scoped object.
+ raw_ptr<OptimizationGuideLogger> optimization_guide_logger_;
};
// Reads component file and parses it into a Configuration proto. Should not be
@@ -264,23 +277,33 @@ void MaybeLogGetHintRequestInfo(
const base::flat_set<proto::OptimizationType>&
registered_optimization_types,
const std::vector<GURL>& urls_to_fetch,
- const std::vector<std::string>& hosts_to_fetch) {
+ const std::vector<std::string>& hosts_to_fetch,
+ OptimizationGuideLogger* optimization_guide_logger) {
if (!switches::IsDebugLogsEnabled())
return;
- DVLOG(0) << "OptimizationGuide: Starting fetch for request context "
- << proto::RequestContext_Name(request_context);
- DVLOG(0) << "OptimizationGuide: Registered Optimization Types: ";
+ OPTIMIZATION_GUIDE_LOG(
+ optimization_guide_logger,
+ base::StrCat({"OptimizationGuide: Starting fetch for request context ",
+ proto::RequestContext_Name(request_context)}));
+ OPTIMIZATION_GUIDE_LOG(optimization_guide_logger,
+ "OptimizationGuide: Registered Optimization Types: ");
for (const auto& optimization_type : registered_optimization_types) {
- DVLOG(0) << "OptimizationGuide: Optimization Type: "
- << proto::OptimizationType_Name(optimization_type);
+ OPTIMIZATION_GUIDE_LOG(
+ optimization_guide_logger,
+ base::StrCat({"OptimizationGuide: Optimization Type: ",
+ proto::OptimizationType_Name(optimization_type)}));
}
- DVLOG(0) << "OptimizationGuide: URLs and Hosts: ";
+ OPTIMIZATION_GUIDE_LOG(optimization_guide_logger,
+ "OptimizationGuide: URLs and Hosts: ");
for (const auto& url : urls_to_fetch) {
- DVLOG(0) << "OptimizationGuide: URL: " << url;
+ OPTIMIZATION_GUIDE_LOG(optimization_guide_logger,
+ base::StrCat({"OptimizationGuide: URL: ",
+ url.possibly_invalid_spec()}));
}
for (const auto& host : hosts_to_fetch) {
- DVLOG(0) << "OptimizationGuide: Host: " << host;
+ OPTIMIZATION_GUIDE_LOG(optimization_guide_logger,
+ base::StrCat({"OptimizationGuide: Host: ", host}));
}
}
@@ -294,8 +317,8 @@ HintsManager::HintsManager(
TopHostProvider* top_host_provider,
TabUrlProvider* tab_url_provider,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
- network::NetworkConnectionTracker* network_connection_tracker,
- std::unique_ptr<PushNotificationManager> push_notification_manager)
+ std::unique_ptr<PushNotificationManager> push_notification_manager,
+ OptimizationGuideLogger* optimization_guide_logger)
: is_off_the_record_(is_off_the_record),
application_locale_(application_locale),
pref_service_(pref_service),
@@ -308,11 +331,11 @@ HintsManager::HintsManager(
hints_fetcher_factory_(std::make_unique<HintsFetcherFactory>(
url_loader_factory,
features::GetOptimizationGuideServiceGetHintsURL(),
- pref_service,
- network_connection_tracker)),
+ pref_service)),
top_host_provider_(top_host_provider),
tab_url_provider_(tab_url_provider),
push_notification_manager_(std::move(push_notification_manager)),
+ optimization_guide_logger_(optimization_guide_logger),
clock_(base::DefaultClock::GetInstance()),
background_task_runner_(base::ThreadPool::CreateSequencedTaskRunner(
{base::MayBlock(), base::TaskPriority::BEST_EFFORT})) {
@@ -687,14 +710,14 @@ void HintsManager::FetchHintsForActiveTabs() {
top_hosts.insert(top_hosts.begin(), url.host());
}
}
- MaybeLogGetHintRequestInfo(proto::CONTEXT_BATCH_UPDATE_ACTIVE_TABS,
- registered_optimization_types_,
- active_tab_urls_to_refresh, top_hosts);
+ MaybeLogGetHintRequestInfo(
+ proto::CONTEXT_BATCH_UPDATE_ACTIVE_TABS, registered_optimization_types_,
+ active_tab_urls_to_refresh, top_hosts, optimization_guide_logger_);
if (!active_tabs_batch_update_hints_fetcher_) {
DCHECK(hints_fetcher_factory_);
active_tabs_batch_update_hints_fetcher_ =
- hints_fetcher_factory_->BuildInstance();
+ hints_fetcher_factory_->BuildInstance(optimization_guide_logger_);
}
active_tabs_batch_update_hints_fetcher_->FetchOptimizationGuideServiceHints(
top_hosts, active_tab_urls_to_refresh, registered_optimization_types_,
@@ -710,8 +733,12 @@ void HintsManager::OnHintsForActiveTabsFetched(
const base::flat_set<GURL>& urls_fetched,
absl::optional<std::unique_ptr<proto::GetHintsResponse>>
get_hints_response) {
- if (!get_hints_response)
+ if (!get_hints_response) {
+ if (switches::IsDebugLogsEnabled()) {
+ DVLOG(0) << "OptimizationGuide: OnHintsForActiveTabsFetched failed";
+ }
return;
+ }
hint_cache_->UpdateFetchedHints(
std::move(*get_hints_response),
@@ -719,8 +746,11 @@ void HintsManager::OnHintsForActiveTabsFetched(
hosts_fetched, urls_fetched,
base::BindOnce(&HintsManager::OnFetchedActiveTabsHintsStored,
weak_ptr_factory_.GetWeakPtr()));
- if (switches::IsDebugLogsEnabled())
- DVLOG(0) << "OptimizationGuide: OnHintsForActiveTabsFetched complete";
+ if (switches::IsDebugLogsEnabled()) {
+ OPTIMIZATION_GUIDE_LOG(
+ optimization_guide_logger_,
+ "OptimizationGuide: OnHintsForActiveTabsFetched complete");
+ }
}
void HintsManager::OnPageNavigationHintsFetched(
@@ -735,6 +765,10 @@ void HintsManager::OnPageNavigationHintsFetched(
}
if (!get_hints_response.has_value() || !get_hints_response.value()) {
+ if (switches::IsDebugLogsEnabled()) {
+ DVLOG(0) << "OptimizationGuide: OnPageNavigationHintsFetched failed";
+ }
+
if (navigation_url) {
PrepareToInvokeRegisteredCallbacks(*navigation_url);
}
@@ -748,6 +782,10 @@ void HintsManager::OnPageNavigationHintsFetched(
base::BindOnce(&HintsManager::OnFetchedPageNavigationHintsStored,
weak_ptr_factory_.GetWeakPtr(), navigation_data_weak_ptr,
navigation_url, page_navigation_hosts_requested));
+
+ if (switches::IsDebugLogsEnabled()) {
+ DVLOG(0) << "OptimizationGuide: OnPageNavigationHintsFetched complete";
+ }
}
void HintsManager::OnFetchedActiveTabsHintsStored() {
@@ -844,7 +882,8 @@ void HintsManager::FetchHintsForURLs(const std::vector<GURL>& urls,
return;
MaybeLogGetHintRequestInfo(request_context, registered_optimization_types_,
- target_urls.vector(), target_hosts.vector());
+ target_urls.vector(), target_hosts.vector(),
+ optimization_guide_logger_);
std::pair<int32_t, HintsFetcher*> request_id_and_fetcher =
CreateAndTrackBatchUpdateHintsFetcher();
@@ -898,8 +937,10 @@ void HintsManager::RegisterOptimizationTypes(
registered_optimization_types_.insert(optimization_type);
if (switches::IsDebugLogsEnabled()) {
- DVLOG(0) << "OptimizationGuide: Registered new OptimizationType: "
- << proto::OptimizationType_Name(optimization_type);
+ OPTIMIZATION_GUIDE_LOG(
+ optimization_guide_logger_,
+ base::StrCat({"OptimizationGuide: Registered new OptimizationType: ",
+ proto::OptimizationType_Name(optimization_type)}));
}
absl::optional<double> value = previously_registered_opt_types->FindBoolKey(
@@ -1030,7 +1071,8 @@ void HintsManager::CanApplyOptimizationOnDemand(
}
MaybeLogGetHintRequestInfo(request_context, registered_optimization_types_,
- urls_to_fetch.vector(), hosts_to_fetch.vector());
+ urls_to_fetch.vector(), hosts_to_fetch.vector(),
+ optimization_guide_logger_);
// Fetch the data for the entries we don't have all information for.
std::pair<int32_t, HintsFetcher*> request_id_and_fetcher =
@@ -1058,6 +1100,10 @@ void HintsManager::OnBatchUpdateHintsFetched(
CleanUpBatchUpdateHintsFetcher(request_id);
if (!get_hints_response.has_value() || !get_hints_response.value()) {
+ if (switches::IsDebugLogsEnabled()) {
+ DVLOG(0) << "OptimizationGuide: OnBatchUpdateHintsFetched for "
+ << proto::RequestContext_Name(request_context) << " failed";
+ }
OnBatchUpdateHintsStored(urls_with_pending_callback, optimization_types,
callback);
return;
@@ -1074,8 +1120,11 @@ void HintsManager::OnBatchUpdateHintsFetched(
optimization_types, callback));
if (switches::IsDebugLogsEnabled()) {
- DVLOG(0) << "OptimizationGuide: OnBatchUpdateHintsFetched for "
- << proto::RequestContext_Name(request_context) << " complete";
+ OPTIMIZATION_GUIDE_LOG(
+ optimization_guide_logger_,
+ base::StrCat({"OptimizationGuide: OnBatchUpdateHintsFetched for ",
+ proto::RequestContext_Name(request_context),
+ " complete"}));
}
}
@@ -1100,7 +1149,7 @@ std::pair<int32_t, HintsFetcher*>
HintsManager::CreateAndTrackBatchUpdateHintsFetcher() {
DCHECK(hints_fetcher_factory_);
std::unique_ptr<HintsFetcher> hints_fetcher =
- hints_fetcher_factory_->BuildInstance();
+ hints_fetcher_factory_->BuildInstance(optimization_guide_logger_);
HintsFetcher* hints_fetcher_ptr = hints_fetcher.get();
batch_update_hints_fetchers_.Put(batch_update_hints_fetcher_request_id_++,
std::move(hints_fetcher));
@@ -1161,8 +1210,8 @@ OptimizationTypeDecision HintsManager::CanApplyOptimization(
OptimizationMetadata* optimization_metadata) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- ScopedCanApplyOptimizationLogger scoped_logger(optimization_type,
- navigation_url);
+ ScopedCanApplyOptimizationLogger scoped_logger(
+ optimization_type, navigation_url, optimization_guide_logger_);
// Clear out optimization metadata if provided.
if (optimization_metadata)
*optimization_metadata = {};
@@ -1432,14 +1481,15 @@ void HintsManager::MaybeFetchHintsForNavigation(
DCHECK(hints_fetcher_factory_);
auto it = page_navigation_hints_fetchers_.Put(
- url, hints_fetcher_factory_->BuildInstance());
+ url, hints_fetcher_factory_->BuildInstance(optimization_guide_logger_));
UMA_HISTOGRAM_COUNTS_100(
"OptimizationGuide.HintsManager.ConcurrentPageNavigationFetches",
page_navigation_hints_fetchers_.size());
MaybeLogGetHintRequestInfo(proto::CONTEXT_PAGE_NAVIGATION,
- registered_optimization_types_, urls, hosts);
+ registered_optimization_types_, urls, hosts,
+ optimization_guide_logger_);
bool fetch_attempted = it->second->FetchOptimizationGuideServiceHints(
hosts, urls, registered_optimization_types_,
proto::CONTEXT_PAGE_NAVIGATION, application_locale_,
diff --git a/chromium/components/optimization_guide/core/hints_manager.h b/chromium/components/optimization_guide/core/hints_manager.h
index 59b0782d191..f6e39c1cbb1 100644
--- a/chromium/components/optimization_guide/core/hints_manager.h
+++ b/chromium/components/optimization_guide/core/hints_manager.h
@@ -27,21 +27,21 @@
#include "components/optimization_guide/proto/hints.pb.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
+class OptimizationGuideLogger;
class OptimizationGuideNavigationData;
class OptimizationGuideTestAppInterfaceWrapper;
class PrefService;
namespace network {
class SharedURLLoaderFactory;
-class NetworkConnectionTracker;
} // namespace network
namespace optimization_guide {
class HintCache;
class HintsFetcherFactory;
class OptimizationFilter;
-class OptimizationMetadata;
class OptimizationGuideStore;
+class OptimizationMetadata;
enum class OptimizationTypeDecision;
class StoreUpdateData;
class TabUrlProvider;
@@ -58,8 +58,8 @@ class HintsManager : public OptimizationHintsComponentObserver,
TopHostProvider* top_host_provider,
TabUrlProvider* tab_url_provider,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
- network::NetworkConnectionTracker* network_connection_tracker,
- std::unique_ptr<PushNotificationManager> push_notification_manager);
+ std::unique_ptr<PushNotificationManager> push_notification_manager,
+ OptimizationGuideLogger* optimization_guide_logger);
~HintsManager() override;
@@ -473,6 +473,11 @@ class HintsManager : public OptimizationHintsComponentObserver,
// what to do through the implemented Delegate above.
std::unique_ptr<PushNotificationManager> push_notification_manager_;
+ // The logger that plumbs the debug logs to the optimization guide
+ // internals page. Not owned. Guaranteed to outlive |this|, since the logger
+ // and |this| are owned by the optimization guide keyed service.
+ raw_ptr<OptimizationGuideLogger> optimization_guide_logger_;
+
// The clock used to schedule fetching from the remote Optimization Guide
// Service.
raw_ptr<const base::Clock> clock_;
diff --git a/chromium/components/optimization_guide/core/hints_manager_unittest.cc b/chromium/components/optimization_guide/core/hints_manager_unittest.cc
index 5f4fd3b8383..2b2ddb9d398 100644
--- a/chromium/components/optimization_guide/core/hints_manager_unittest.cc
+++ b/chromium/components/optimization_guide/core/hints_manager_unittest.cc
@@ -15,7 +15,6 @@
#include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
-#include "components/data_reduction_proxy/core/common/data_reduction_proxy_pref_names.h"
#include "components/optimization_guide/core/bloom_filter.h"
#include "components/optimization_guide/core/hint_cache.h"
#include "components/optimization_guide/core/hints_component_util.h"
@@ -38,10 +37,8 @@
#include "components/variations/scoped_variations_ids_provider.h"
#include "services/metrics/public/cpp/ukm_builders.h"
#include "services/metrics/public/cpp/ukm_source.h"
-#include "services/network/public/cpp/network_connection_tracker.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
-#include "services/network/test/test_network_connection_tracker.h"
#include "services/network/test/test_url_loader_factory.h"
#include "testing/gtest/include/gtest/gtest.h"
@@ -190,12 +187,12 @@ class TestHintsFetcher : public HintsFetcher {
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
GURL optimization_guide_service_url,
PrefService* pref_service,
- network::NetworkConnectionTracker* network_connection_tracker,
- const std::vector<HintsFetcherEndState>& fetch_states)
+ const std::vector<HintsFetcherEndState>& fetch_states,
+ OptimizationGuideLogger* optimization_guide_logger)
: HintsFetcher(url_loader_factory,
optimization_guide_service_url,
pref_service,
- network_connection_tracker),
+ optimization_guide_logger),
fetch_states_(fetch_states) {
DCHECK(!fetch_states_.empty());
}
@@ -266,18 +263,17 @@ class TestHintsFetcherFactory : public HintsFetcherFactory {
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
GURL optimization_guide_service_url,
PrefService* pref_service,
- const std::vector<HintsFetcherEndState>& fetch_states,
- network::NetworkConnectionTracker* network_connection_tracker)
+ const std::vector<HintsFetcherEndState>& fetch_states)
: HintsFetcherFactory(url_loader_factory,
optimization_guide_service_url,
- pref_service,
- network_connection_tracker),
+ pref_service),
fetch_states_(fetch_states) {}
- std::unique_ptr<HintsFetcher> BuildInstance() override {
+ std::unique_ptr<HintsFetcher> BuildInstance(
+ OptimizationGuideLogger* optimization_guide_logger) override {
return std::make_unique<TestHintsFetcher>(
url_loader_factory_, optimization_guide_service_url_, pref_service_,
- network_connection_tracker_, fetch_states_);
+ fetch_states_, optimization_guide_logger);
}
private:
@@ -313,8 +309,6 @@ class HintsManagerTest : public ProtoDatabaseProviderTestBase {
pref_service_ =
std::make_unique<sync_preferences::TestingPrefServiceSyncable>();
prefs::RegisterProfilePrefs(pref_service_->registry());
- pref_service_->registry()->RegisterBooleanPref(
- data_reduction_proxy::prefs::kDataSaverEnabled, false);
unified_consent::UnifiedConsentService::RegisterPrefs(
pref_service_->registry());
@@ -332,8 +326,8 @@ class HintsManagerTest : public ProtoDatabaseProviderTestBase {
/*is_off_the_record=*/false, /*application_locale=*/"en-US",
pref_service(), hint_store_->AsWeakPtr(), top_host_provider,
tab_url_provider_.get(), url_loader_factory_,
- network::TestNetworkConnectionTracker::GetInstance(),
- /*push_notification_manager=*/nullptr);
+ /*push_notification_manager=*/nullptr,
+ /*optimization_guide_logger=*/nullptr);
hints_manager_->SetClockForTesting(task_environment_.GetMockClock());
// Run until hint cache is initialized and the HintsManager is ready to
@@ -414,7 +408,7 @@ class HintsManagerTest : public ProtoDatabaseProviderTestBase {
const std::vector<HintsFetcherEndState>& fetch_states) {
return std::make_unique<TestHintsFetcherFactory>(
url_loader_factory_, GURL("https://hintsserver.com"), pref_service(),
- fetch_states, network::TestNetworkConnectionTracker::GetInstance());
+ fetch_states);
}
void MoveClockForwardBy(base::TimeDelta time_delta) {
@@ -441,16 +435,6 @@ class HintsManagerTest : public ProtoDatabaseProviderTestBase {
std::move(callback));
}
- void SetConnectionOffline() {
- network::TestNetworkConnectionTracker::GetInstance()->SetConnectionType(
- network::mojom::ConnectionType::CONNECTION_NONE);
- }
-
- void SetConnectionOnline() {
- network::TestNetworkConnectionTracker::GetInstance()->SetConnectionType(
- network::mojom::ConnectionType::CONNECTION_4G);
- }
-
HintsManager* hints_manager() const { return hints_manager_.get(); }
int32_t num_batch_update_hints_fetches_initiated() const {
@@ -1402,8 +1386,6 @@ TEST_F(HintsManagerTest, CanApplyOptimizationAndPopulatesAnyMetadata) {
TEST_F(HintsManagerTest, CanApplyOptimizationNoMatchingPageHint) {
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
auto navigation_data =
CreateTestNavigationData(GURL("https://somedomain.org/nomatch"), {});
base::RunLoop run_loop;
@@ -2074,8 +2056,6 @@ TEST_F(HintsManagerFetchingTest,
/*is_allowlist=*/true, &config);
ProcessHints(config, "1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
auto navigation_data = CreateTestNavigationData(url_without_hints(),
{proto::LITE_PAGE_REDIRECT});
base::HistogramTester histogram_tester;
@@ -2096,8 +2076,6 @@ TEST_F(HintsManagerFetchingTest, HintsFetchedAtNavigationTime) {
hints_manager()->RegisterOptimizationTypes({proto::DEFER_ALL_SCRIPT});
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
auto navigation_data =
CreateTestNavigationData(url_without_hints(), {proto::DEFER_ALL_SCRIPT});
base::HistogramTester histogram_tester;
@@ -2120,8 +2098,6 @@ TEST_F(HintsManagerFetchingTest,
hints_manager()->RegisterOptimizationTypes({proto::DEFER_ALL_SCRIPT});
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
auto navigation_data =
CreateTestNavigationData(url_without_hints(), {proto::DEFER_ALL_SCRIPT});
hints_manager()->SetHintsFetcherFactoryForTesting(
@@ -2149,8 +2125,6 @@ TEST_F(HintsManagerFetchingTest,
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithURLHints}));
- // Set to online so fetch is activated.
- SetConnectionOnline();
auto navigation_data =
CreateTestNavigationData(url_with_hints(), {proto::DEFER_ALL_SCRIPT});
base::HistogramTester histogram_tester;
@@ -2188,8 +2162,6 @@ TEST_F(HintsManagerFetchingTest,
switches::kDisableCheckingUserPermissionsForTesting);
hints_manager()->RegisterOptimizationTypes({proto::DEFER_ALL_SCRIPT});
- // Set to online so fetch is activated.
- SetConnectionOnline();
auto navigation_data =
CreateTestNavigationData(example_url, {proto::DEFER_ALL_SCRIPT});
base::HistogramTester histogram_tester;
@@ -2221,9 +2193,6 @@ TEST_F(HintsManagerFetchingTest, URLHintsNotFetchedAtNavigationTime) {
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithURLHints}));
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
{
base::HistogramTester histogram_tester;
auto navigation_data =
@@ -2282,9 +2251,6 @@ TEST_F(HintsManagerFetchingTest, URLWithNoHintsNotRefetchedAtNavigationTime) {
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithHostHints}));
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
base::HistogramTester histogram_tester;
{
auto navigation_data = CreateTestNavigationData(url_without_hints(),
@@ -2330,8 +2296,6 @@ TEST_F(HintsManagerFetchingTest, CanApplyOptimizationCalledMidFetch) {
hints_manager()->RegisterOptimizationTypes({proto::DEFER_ALL_SCRIPT});
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
auto navigation_data =
CreateTestNavigationData(url_without_hints(), {proto::DEFER_ALL_SCRIPT});
CallOnNavigationStartOrRedirect(navigation_data.get(), base::DoNothing());
@@ -2355,8 +2319,6 @@ TEST_F(HintsManagerFetchingTest,
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithNoHints}));
- // Set to online so fetch is activated.
- SetConnectionOnline();
auto navigation_data =
CreateTestNavigationData(url_without_hints(), {proto::DEFER_ALL_SCRIPT});
CallOnNavigationStartOrRedirect(navigation_data.get(), base::DoNothing());
@@ -2381,8 +2343,6 @@ TEST_F(HintsManagerFetchingTest,
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory({HintsFetcherEndState::kFetchFailed}));
- // Set to online so fetch is activated.
- SetConnectionOnline();
auto navigation_data =
CreateTestNavigationData(url_without_hints(), {proto::DEFER_ALL_SCRIPT});
CallOnNavigationStartOrRedirect(navigation_data.get(), base::DoNothing());
@@ -2408,8 +2368,6 @@ TEST_F(HintsManagerFetchingTest,
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithURLHints}));
- // Set to online so fetch is activated.
- SetConnectionOnline();
auto navigation_data = CreateTestNavigationData(url_with_url_keyed_hint(),
{proto::DEFER_ALL_SCRIPT});
// Make sure URL-keyed hint is fetched and processed.
@@ -2437,9 +2395,6 @@ TEST_F(HintsManagerFetchingTest,
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
// Make sure both URL-Keyed and host-keyed hints are processed and cached.
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
@@ -2467,9 +2422,6 @@ TEST_F(HintsManagerFetchingTest,
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
// Make sure both URL-Keyed and host-keyed hints are processed and cached.
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
@@ -2497,9 +2449,6 @@ TEST_F(HintsManagerFetchingTest,
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithNoHints}));
@@ -2528,9 +2477,6 @@ TEST_F(HintsManagerFetchingTest,
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
// Attempt to fetch a hint but call CanApplyOptimization right away to
// simulate being mid-fetch.
auto navigation_data = CreateTestNavigationData(
@@ -2557,9 +2503,6 @@ TEST_F(HintsManagerFetchingTest,
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
// Attempt to fetch a hint but initiate the next navigation right away to
// simulate being mid-fetch.
auto navigation_data =
@@ -2608,8 +2551,6 @@ TEST_F(HintsManagerFetchingTest,
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithNoHints}));
@@ -2650,8 +2591,6 @@ TEST_F(HintsManagerFetchingTest,
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithURLHints}));
@@ -2691,9 +2630,6 @@ TEST_F(HintsManagerFetchingTest,
hints_manager()->RegisterOptimizationTypes({proto::COMPRESS_PUBLIC_IMAGES});
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithURLHints}));
@@ -2723,9 +2659,6 @@ TEST_F(HintsManagerFetchingTest,
hints_manager()->RegisterOptimizationTypes({proto::COMPRESS_PUBLIC_IMAGES});
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithURLHints}));
@@ -2763,9 +2696,6 @@ TEST_F(
hints_manager()->RegisterOptimizationTypes({proto::RESOURCE_LOADING});
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithURLHints}));
@@ -2795,9 +2725,6 @@ TEST_F(HintsManagerFetchingTest,
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory({HintsFetcherEndState::kFetchFailed}));
auto navigation_data = CreateTestNavigationData(
@@ -2826,9 +2753,6 @@ TEST_F(HintsManagerFetchingTest,
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithURLHints}));
@@ -2861,9 +2785,6 @@ TEST_F(HintsManagerFetchingTest,
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithURLHints}));
@@ -2895,9 +2816,6 @@ TEST_F(HintsManagerFetchingTest,
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithNoHints}));
@@ -2927,9 +2845,6 @@ TEST_F(HintsManagerFetchingTest,
InitializeWithDefaultConfig("1.0.0.0");
- // Set to offline so fetch is NOT activated.
- SetConnectionOffline();
-
GURL url_that_redirected("https://urlthatredirected.com");
auto navigation_data_redirect = CreateTestNavigationData(
url_that_redirected, {proto::COMPRESS_PUBLIC_IMAGES});
@@ -2958,9 +2873,6 @@ TEST_F(HintsManagerFetchingTest,
InitializeWithDefaultConfig("1.0.0.0");
- // Set to offline so fetch is NOT activated.
- SetConnectionOffline();
-
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithNoHints}));
@@ -3053,9 +2965,6 @@ TEST_F(HintsManagerFetchingTest,
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithURLHints}));
@@ -3099,9 +3008,6 @@ TEST_F(HintsManagerFetchingTest, NewOptTypeRegisteredClearsHintCache) {
GURL url("https://host.com/fetched_hint_host");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithHostHints}));
@@ -3129,10 +3035,6 @@ TEST_F(HintsManagerFetchingTest, NewOptTypeRegisteredClearsHintCache) {
base::RunLoop run_loop;
- // Set to offline so fetch is NOT activated, so the cache state is known and
- // empty.
- SetConnectionOffline();
-
base::HistogramTester histogram_tester;
navigation_data = CreateTestNavigationData(url, {proto::DEFER_ALL_SCRIPT});
@@ -3159,9 +3061,6 @@ TEST_F(HintsManagerFetchingTest,
hints_manager()->RegisterOptimizationTypes({proto::COMPRESS_PUBLIC_IMAGES});
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithURLHints}));
@@ -3196,9 +3095,6 @@ TEST_F(HintsManagerFetchingTest, BatchUpdateCalledMoreThanMaxConcurrent) {
hints_manager()->RegisterOptimizationTypes({proto::COMPRESS_PUBLIC_IMAGES});
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithURLHints}));
@@ -3244,9 +3140,6 @@ TEST_F(
{proto::NOSCRIPT, proto::COMPRESS_PUBLIC_IMAGES});
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory(
{HintsFetcherEndState::kFetchSuccessWithURLHints}));
@@ -3285,9 +3178,6 @@ TEST_F(HintsManagerFetchingTest,
{proto::NOSCRIPT, proto::COMPRESS_PUBLIC_IMAGES});
InitializeWithDefaultConfig("1.0.0.0");
- // Set to online so fetch is activated.
- SetConnectionOnline();
-
hints_manager()->SetHintsFetcherFactoryForTesting(
BuildTestHintsFetcherFactory({HintsFetcherEndState::kFetchFailed}));
std::unique_ptr<base::RunLoop> run_loop = std::make_unique<base::RunLoop>();
diff --git a/chromium/components/optimization_guide/core/local_page_entities_metadata_provider.cc b/chromium/components/optimization_guide/core/local_page_entities_metadata_provider.cc
new file mode 100644
index 00000000000..18a32f9a573
--- /dev/null
+++ b/chromium/components/optimization_guide/core/local_page_entities_metadata_provider.cc
@@ -0,0 +1,93 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/optimization_guide/core/local_page_entities_metadata_provider.h"
+
+#include "components/optimization_guide/core/entity_metadata.h"
+
+namespace optimization_guide {
+
+namespace {
+
+// The amount of data to build up in memory before converting to a sorted on-
+// disk file.
+constexpr size_t kDatabaseWriteBufferSizeBytes = 128 * 1024;
+
+} // namespace
+
+LocalPageEntitiesMetadataProvider::LocalPageEntitiesMetadataProvider() =
+ default;
+LocalPageEntitiesMetadataProvider::~LocalPageEntitiesMetadataProvider() =
+ default;
+
+void LocalPageEntitiesMetadataProvider::Initialize(
+ leveldb_proto::ProtoDatabaseProvider* database_provider,
+ const base::FilePath& database_dir,
+ scoped_refptr<base::SequencedTaskRunner> background_task_runner) {
+ DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
+
+ background_task_runner_ = std::move(background_task_runner);
+ database_ = database_provider->GetDB<proto::EntityMetadataStorage>(
+ leveldb_proto::ProtoDbType::PAGE_ENTITY_METADATA_STORE, database_dir,
+ background_task_runner_);
+
+ leveldb_env::Options options = leveldb_proto::CreateSimpleOptions();
+ options.write_buffer_size = kDatabaseWriteBufferSizeBytes;
+ database_->Init(
+ options,
+ base::BindOnce(&LocalPageEntitiesMetadataProvider::OnDatabaseInitialized,
+ weak_ptr_factory_.GetWeakPtr()));
+}
+
+void LocalPageEntitiesMetadataProvider::InitializeForTesting(
+ std::unique_ptr<leveldb_proto::ProtoDatabase<proto::EntityMetadataStorage>>
+ database,
+ scoped_refptr<base::SequencedTaskRunner> background_task_runner) {
+ database_ = std::move(database);
+ background_task_runner_ = std::move(background_task_runner);
+}
+
+void LocalPageEntitiesMetadataProvider::OnDatabaseInitialized(
+ leveldb_proto::Enums::InitStatus status) {
+ DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
+ if (status != leveldb_proto::Enums::InitStatus::kOK) {
+ database_.reset();
+ return;
+ }
+}
+
+void LocalPageEntitiesMetadataProvider::GetMetadataForEntityId(
+ const std::string& entity_id,
+ EntityMetadataRetrievedCallback callback) {
+ DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
+
+ if (!database_) {
+ std::move(callback).Run(absl::nullopt);
+ return;
+ }
+
+ database_->GetEntry(
+ entity_id, base::BindOnce(&LocalPageEntitiesMetadataProvider::OnGotEntry,
+ weak_ptr_factory_.GetWeakPtr(), entity_id,
+ std::move(callback)));
+}
+
+void LocalPageEntitiesMetadataProvider::OnGotEntry(
+ const std::string& entity_id,
+ EntityMetadataRetrievedCallback callback,
+ bool success,
+ std::unique_ptr<proto::EntityMetadataStorage> entry) {
+ if (!success || !entry) {
+ std::move(callback).Run(absl::nullopt);
+ return;
+ }
+
+ EntityMetadata md;
+ md.entity_id = entity_id;
+ md.human_readable_name = entry->entity_name();
+
+ std::move(callback).Run(md);
+}
+
+} // namespace optimization_guide \ No newline at end of file
diff --git a/chromium/components/optimization_guide/core/local_page_entities_metadata_provider.h b/chromium/components/optimization_guide/core/local_page_entities_metadata_provider.h
new file mode 100644
index 00000000000..11070c8961a
--- /dev/null
+++ b/chromium/components/optimization_guide/core/local_page_entities_metadata_provider.h
@@ -0,0 +1,67 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_LOCAL_PAGE_ENTITIES_METADATA_PROVIDER_H_
+#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_LOCAL_PAGE_ENTITIES_METADATA_PROVIDER_H_
+
+#include "base/callback.h"
+#include "base/memory/weak_ptr.h"
+#include "base/sequence_checker.h"
+#include "components/leveldb_proto/public/proto_database.h"
+#include "components/leveldb_proto/public/proto_database_provider.h"
+#include "components/optimization_guide/core/entity_metadata_provider.h"
+#include "components/optimization_guide/proto/page_entities_metadata.pb.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
+
+namespace optimization_guide {
+
+// Provides EntityMetadata given an entity id by looking up entries in a local
+// database on-disk.
+class LocalPageEntitiesMetadataProvider : public EntityMetadataProvider {
+ public:
+ LocalPageEntitiesMetadataProvider();
+ ~LocalPageEntitiesMetadataProvider() override;
+ LocalPageEntitiesMetadataProvider(const LocalPageEntitiesMetadataProvider&) =
+ delete;
+ LocalPageEntitiesMetadataProvider& operator=(
+ const LocalPageEntitiesMetadataProvider&) = delete;
+
+ // Initializes this class, setting |database_| and |background_task_runner_|.
+ void Initialize(
+ leveldb_proto::ProtoDatabaseProvider* database_provider,
+ const base::FilePath& database_dir,
+ scoped_refptr<base::SequencedTaskRunner> background_task_runner);
+
+ // Directly sets |database_| and |background_task_runner_| for tests.
+ void InitializeForTesting(
+ std::unique_ptr<
+ leveldb_proto::ProtoDatabase<proto::EntityMetadataStorage>> database,
+ scoped_refptr<base::SequencedTaskRunner> background_task_runner);
+
+ // EntityMetadataProvider:
+ void GetMetadataForEntityId(
+ const std::string& entity_id,
+ EntityMetadataRetrievedCallback callback) override;
+
+ private:
+ void OnDatabaseInitialized(leveldb_proto::Enums::InitStatus status);
+ void OnGotEntry(const std::string& entity_id,
+ EntityMetadataRetrievedCallback callback,
+ bool success,
+ std::unique_ptr<proto::EntityMetadataStorage> entry);
+
+ std::unique_ptr<leveldb_proto::ProtoDatabase<proto::EntityMetadataStorage>>
+ database_;
+
+ scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
+
+ SEQUENCE_CHECKER(sequence_checker_);
+
+ base::WeakPtrFactory<LocalPageEntitiesMetadataProvider> weak_ptr_factory_{
+ this};
+};
+
+} // namespace optimization_guide
+
+#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_LOCAL_PAGE_ENTITIES_METADATA_PROVIDER_H_ \ No newline at end of file
diff --git a/chromium/components/optimization_guide/core/local_page_entities_metadata_provider_unittest.cc b/chromium/components/optimization_guide/core/local_page_entities_metadata_provider_unittest.cc
new file mode 100644
index 00000000000..7a24cbe2430
--- /dev/null
+++ b/chromium/components/optimization_guide/core/local_page_entities_metadata_provider_unittest.cc
@@ -0,0 +1,134 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/optimization_guide/core/local_page_entities_metadata_provider.h"
+
+#include "base/test/task_environment.h"
+#include "components/leveldb_proto/testing/fake_db.h"
+#include "components/optimization_guide/core/optimization_guide_features.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace optimization_guide {
+
+class LocalPageEntitiesMetadataProviderTest : public testing::Test {
+ public:
+ LocalPageEntitiesMetadataProviderTest() = default;
+ ~LocalPageEntitiesMetadataProviderTest() override = default;
+
+ void SetUp() override {
+ auto db = std::make_unique<
+ leveldb_proto::test::FakeDB<proto::EntityMetadataStorage>>(&db_store_);
+ db_ = db.get();
+
+ provider_ = std::make_unique<LocalPageEntitiesMetadataProvider>();
+ provider_->InitializeForTesting(
+ std::move(db), task_environment_.GetMainThreadTaskRunner());
+ }
+
+ LocalPageEntitiesMetadataProvider* provider() { return provider_.get(); }
+
+ leveldb_proto::test::FakeDB<proto::EntityMetadataStorage>* db() {
+ return db_;
+ }
+
+ std::map<std::string, proto::EntityMetadataStorage>* store() {
+ return &db_store_;
+ }
+
+ private:
+ base::test::TaskEnvironment task_environment_;
+ std::unique_ptr<LocalPageEntitiesMetadataProvider> provider_;
+ leveldb_proto::test::FakeDB<proto::EntityMetadataStorage>* db_;
+ std::map<std::string, proto::EntityMetadataStorage> db_store_;
+};
+
+TEST_F(LocalPageEntitiesMetadataProviderTest, NonInitReturnsNullOpt) {
+ LocalPageEntitiesMetadataProvider provider;
+
+ absl::optional<EntityMetadata> md;
+ bool callback_ran = false;
+ provider.GetMetadataForEntityId(
+ "entity_id",
+ base::BindOnce(
+ [](bool* callback_ran_flag, absl::optional<EntityMetadata>* md_out,
+ const absl::optional<EntityMetadata>& md_in) {
+ *callback_ran_flag = true;
+ *md_out = md_in;
+ },
+ &callback_ran, &md));
+
+ ASSERT_TRUE(callback_ran);
+ EXPECT_EQ(absl::nullopt, md);
+}
+
+TEST_F(LocalPageEntitiesMetadataProviderTest, EmptyStoreReturnsNullOpt) {
+ absl::optional<EntityMetadata> md;
+ bool callback_ran = false;
+ provider()->GetMetadataForEntityId(
+ "entity_id",
+ base::BindOnce(
+ [](bool* callback_ran_flag, absl::optional<EntityMetadata>* md_out,
+ const absl::optional<EntityMetadata>& md_in) {
+ *callback_ran_flag = true;
+ *md_out = md_in;
+ },
+ &callback_ran, &md));
+
+ db()->GetCallback(/*success=*/true);
+
+ ASSERT_TRUE(callback_ran);
+ EXPECT_EQ(absl::nullopt, md);
+}
+
+TEST_F(LocalPageEntitiesMetadataProviderTest, PopulatedSuccess) {
+ proto::EntityMetadataStorage stored_proto;
+ stored_proto.set_entity_name("chip");
+ store()->emplace("chocolate", stored_proto);
+
+ EntityMetadata want_md;
+ want_md.entity_id = "chocolate";
+ want_md.human_readable_name = "chip";
+
+ absl::optional<EntityMetadata> md;
+ bool callback_ran = false;
+ provider()->GetMetadataForEntityId(
+ "chocolate",
+ base::BindOnce(
+ [](bool* callback_ran_flag, absl::optional<EntityMetadata>* md_out,
+ const absl::optional<EntityMetadata>& md_in) {
+ *callback_ran_flag = true;
+ *md_out = md_in;
+ },
+ &callback_ran, &md));
+
+ db()->GetCallback(/*success=*/true);
+
+ ASSERT_TRUE(callback_ran);
+ EXPECT_EQ(absl::make_optional(want_md), md);
+}
+
+TEST_F(LocalPageEntitiesMetadataProviderTest, PopulatedFailure) {
+ proto::EntityMetadataStorage stored_proto;
+ stored_proto.set_entity_name("chip");
+ store()->emplace("chocolate", stored_proto);
+
+ absl::optional<EntityMetadata> md;
+ bool callback_ran = false;
+ provider()->GetMetadataForEntityId(
+ "chocolate",
+ base::BindOnce(
+ [](bool* callback_ran_flag, absl::optional<EntityMetadata>* md_out,
+ const absl::optional<EntityMetadata>& md_in) {
+ *callback_ran_flag = true;
+ *md_out = md_in;
+ },
+ &callback_ran, &md));
+
+ db()->GetCallback(/*success=*/false);
+
+ ASSERT_TRUE(callback_ran);
+ EXPECT_EQ(absl::nullopt, md);
+}
+
+} // namespace optimization_guide \ No newline at end of file
diff --git a/chromium/components/optimization_guide/core/model_enums.h b/chromium/components/optimization_guide/core/model_enums.h
new file mode 100644
index 00000000000..1e6c495c753
--- /dev/null
+++ b/chromium/components/optimization_guide/core/model_enums.h
@@ -0,0 +1,57 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_ENUMS_H_
+#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_ENUMS_H_
+
+namespace optimization_guide {
+
+// The types of decisions that can be made for an optimization target.
+//
+// Keep in sync with OptimizationGuideOptimizationTargetDecision in enums.xml.
+enum class OptimizationTargetDecision {
+ kUnknown = 0,
+ // The page load does not match the optimization target.
+ kPageLoadDoesNotMatch = 1,
+ // The page load matches the optimization target.
+ kPageLoadMatches = 2,
+ // The model needed to make the target decision was not available on the
+ // client.
+ kModelNotAvailableOnClient = 3,
+ // The page load is part of a model prediction holdback where all decisions
+ // will return |OptimizationGuideDecision::kFalse| in an attempt to not taint
+ // the data for understanding the production recall of the model.
+ kModelPredictionHoldback = 4,
+ // The OptimizationGuideDecider was not initialized yet.
+ kDeciderNotInitialized = 5,
+
+ // Add new values above this line.
+ kMaxValue = kDeciderNotInitialized,
+};
+
+// The statuses for a prediction model in the prediction manager when requested
+// to be evaluated.
+//
+// Keep in sync with OptimizationGuidePredictionManagerModelStatus in enums.xml.
+enum class PredictionManagerModelStatus {
+ kUnknown = 0,
+ // The model is loaded and available for use.
+ kModelAvailable = 1,
+ // The store is initialized but does not contain a model for the optimization
+ // target.
+ kStoreAvailableNoModelForTarget = 2,
+ // The store is initialized and contains a model for the optimization target
+ // but it is not loaded in memory.
+ kStoreAvailableModelNotLoaded = 3,
+ // The store is not initialized and it is unknown if it contains a model for
+ // the optimization target.
+ kStoreUnavailableModelUnknown = 4,
+
+ // Add new values above this line.
+ kMaxValue = kStoreUnavailableModelUnknown,
+};
+
+} // namespace optimization_guide
+
+#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_OPTIMIZATION_GUIDE_ENUMS_H_
diff --git a/chromium/components/optimization_guide/core/model_executor.h b/chromium/components/optimization_guide/core/model_executor.h
index 542c604c9db..305b6a9e24b 100644
--- a/chromium/components/optimization_guide/core/model_executor.h
+++ b/chromium/components/optimization_guide/core/model_executor.h
@@ -17,7 +17,7 @@
namespace optimization_guide {
// This class handles the execution, loading, unloading, and associated metrics
-// of machine learning models in Optimization Guide on a background thread. This
+// of machine learning models in Optimization Guide on a specified thread. This
// class is meant to be used and owned by an instance of |ModelHandler|. A
// ModelExecutor must be passed to a ModelHandler's constructor, this design
// allows the implementer of a ModelExecutor to define how the model is built
@@ -25,21 +25,21 @@ namespace optimization_guide {
// base_model_executor_helpers.h in this directory for helpful derived classes.
//
// Lifetime: This class can be constructed on any thread but cannot do anything
-// useful until |InitializeAndMoveToBackgroundThread| is called. After that
+// useful until |InitializeAndMoveToExecutionThread| is called. After that
// method is called, all subsequent calls to this class must be made through the
-// |background_task_runner| that was passed to initialize. Furthermore, all
-// WeakPointers of this class must only be dereferenced on the background thread
-// as well. This in turn means that this class must be destroyed on the
-// background thread as well.
+// |execution_task_runner| that was passed to initialize. Furthermore, all
+// WeakPointers of this class must only be dereferenced on the
+// |execution_task_runner| thread as well. This in turn means that this class
+// must be destroyed on the |execution_task_runner| thread as well.
template <class OutputType, class... InputTypes>
class ModelExecutor {
public:
ModelExecutor() = default;
virtual ~ModelExecutor() = default;
- virtual void InitializeAndMoveToBackgroundThread(
+ virtual void InitializeAndMoveToExecutionThread(
proto::OptimizationTarget optimization_target,
- scoped_refptr<base::SequencedTaskRunner> background_task_runner,
+ scoped_refptr<base::SequencedTaskRunner> execution_task_runner,
scoped_refptr<base::SequencedTaskRunner> reply_task_runner) = 0;
virtual void UpdateModelFile(const base::FilePath& file_path) = 0;
@@ -53,21 +53,21 @@ class ModelExecutor {
using ExecutionCallback =
base::OnceCallback<void(const absl::optional<OutputType>&)>;
- virtual void SendForExecution(ExecutionCallback ui_callback_on_complete,
+ virtual void SendForExecution(ExecutionCallback callback_on_complete,
base::TimeTicks start_time,
InputTypes... args) = 0;
- // IMPORTANT: These WeakPointers must only be dereferenced on the background
- // thread.
- base::WeakPtr<ModelExecutor> GetBackgroundWeakPtr() {
- return background_weak_ptr_factory_.GetWeakPtr();
+ // IMPORTANT: These WeakPointers must only be dereferenced on the
+ // |execution_task_runner| thread.
+ base::WeakPtr<ModelExecutor> GetWeakPtrForExecutionThread() {
+ return weak_ptr_factory_.GetWeakPtr();
}
ModelExecutor(const ModelExecutor&) = delete;
ModelExecutor& operator=(const ModelExecutor&) = delete;
private:
- base::WeakPtrFactory<ModelExecutor> background_weak_ptr_factory_{this};
+ base::WeakPtrFactory<ModelExecutor> weak_ptr_factory_{this};
};
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/model_handler.h b/chromium/components/optimization_guide/core/model_handler.h
index 010734bab35..ff1c332a4a7 100644
--- a/chromium/components/optimization_guide/core/model_handler.h
+++ b/chromium/components/optimization_guide/core/model_handler.h
@@ -16,6 +16,7 @@
#include "base/threading/sequenced_task_runner_handle.h"
#include "base/time/time.h"
#include "components/optimization_guide/core/model_executor.h"
+#include "components/optimization_guide/core/model_util.h"
#include "components/optimization_guide/core/optimization_guide_model_provider.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/core/optimization_target_model_observer.h"
@@ -24,32 +25,33 @@
namespace optimization_guide {
-// This class owns and handles the execution of models on the UI thread. Derived
-// classes must provide an implementation of |ModelExecutor|
-// (see above) which is then owned by |this|. The passed executor will be called
-// and destroyed on a background thread, which is all handled by this class.
+// This class owns and handles the execution of models on the UI thread.
+// Derived classes must provide an implementation of |ModelExecutor|
+// which is then owned by |this|. The passed executor will be called
+// and destroyed on the thread specified by |model_executor_task_runner|,
+// which is all handled by this class.
template <class OutputType, class... InputTypes>
class ModelHandler : public OptimizationTargetModelObserver {
public:
- ModelHandler(OptimizationGuideModelProvider* model_provider,
- scoped_refptr<base::SequencedTaskRunner> background_task_runner,
- std::unique_ptr<ModelExecutor<OutputType, InputTypes...>>
- background_executor,
- proto::OptimizationTarget optimization_target,
- const absl::optional<proto::Any>& model_metadata)
+ ModelHandler(
+ OptimizationGuideModelProvider* model_provider,
+ scoped_refptr<base::SequencedTaskRunner> model_executor_task_runner,
+ std::unique_ptr<ModelExecutor<OutputType, InputTypes...>> model_executor,
+ proto::OptimizationTarget optimization_target,
+ const absl::optional<proto::Any>& model_metadata)
: model_provider_(model_provider),
optimization_target_(optimization_target),
- background_executor_(std::move(background_executor)),
- background_task_runner_(background_task_runner) {
+ model_executor_(std::move(model_executor)),
+ model_executor_task_runner_(model_executor_task_runner) {
DCHECK(model_provider_);
- DCHECK(background_executor_);
+ DCHECK(model_executor_);
DCHECK_NE(optimization_target_,
proto::OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN);
model_provider_->AddObserverForOptimizationTargetModel(
optimization_target_, model_metadata, this);
- background_executor_->InitializeAndMoveToBackgroundThread(
- optimization_target_, background_task_runner_,
+ model_executor_->InitializeAndMoveToExecutionThread(
+ optimization_target_, model_executor_task_runner_,
base::SequencedTaskRunnerHandle::Get());
}
~ModelHandler() override {
@@ -58,10 +60,10 @@ class ModelHandler : public OptimizationTargetModelObserver {
model_provider_->RemoveObserverForOptimizationTargetModel(
optimization_target_, this);
- // |background_executor_|'s WeakPtrs are used on the background thread, so
+ // |model_executor_|'s WeakPtrs are used on the model thread, so
// that is also where the class must be destroyed.
- background_task_runner_->DeleteSoon(FROM_HERE,
- std::move(background_executor_));
+ model_executor_task_runner_->DeleteSoon(FROM_HERE,
+ std::move(model_executor_));
}
ModelHandler(const ModelHandler&) = delete;
ModelHandler& operator=(const ModelHandler&) = delete;
@@ -79,32 +81,33 @@ class ModelHandler : public OptimizationTargetModelObserver {
ExecutionCallback on_complete_callback =
base::BindOnce(&ModelHandler::OnExecutionCompleted, std::move(callback),
optimization_target_, now);
- background_task_runner_->PostTask(
+ model_executor_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(
&ModelExecutor<OutputType, InputTypes...>::SendForExecution,
- background_executor_->GetBackgroundWeakPtr(),
+ model_executor_->GetWeakPtrForExecutionThread(),
std::move(on_complete_callback), now, input...));
}
void SetShouldUnloadModelOnComplete(bool should_auto_unload) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- background_task_runner_->PostTask(
+ model_executor_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(
&ModelExecutor<OutputType,
InputTypes...>::SetShouldUnloadModelOnComplete,
- background_executor_->GetBackgroundWeakPtr(), should_auto_unload));
+ model_executor_->GetWeakPtrForExecutionThread(),
+ should_auto_unload));
}
// Requests that the model executor unload the model from memory, if it is
// currently loaded.
void UnloadModel() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- background_task_runner_->PostTask(
+ model_executor_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&ModelExecutor<OutputType, InputTypes...>::UnloadModel,
- background_executor_->GetBackgroundWeakPtr()));
+ model_executor_->GetWeakPtrForExecutionThread()));
}
// OptimizationTargetModelObserver:
@@ -118,16 +121,16 @@ class ModelHandler : public OptimizationTargetModelObserver {
model_info_ = model_info;
model_available_ = true;
- background_task_runner_->PostTask(
+ model_executor_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(
&ModelExecutor<OutputType, InputTypes...>::UpdateModelFile,
- background_executor_->GetBackgroundWeakPtr(),
+ model_executor_->GetWeakPtrForExecutionThread(),
model_info.GetModelFilePath()));
// Run any observing callbacks after the model file is posted to the
- // background thread so that any model execution requests are posted to the
- // background thread after the model update.
+ // model executor thread so that any model execution requests are posted to
+ // the model executor thread after the model update.
on_model_updated_callbacks_.Notify();
}
@@ -168,7 +171,7 @@ class ModelHandler : public OptimizationTargetModelObserver {
}
private:
- // This is called by |background_executor_|. This method does not have to be
+ // This is called by |model_executor_|. This method does not have to be
// static, but because it is stateless we've made it static so that we don't
// have to have this class support WeakPointers.
static void OnExecutionCompleted(
@@ -198,15 +201,14 @@ class ModelHandler : public OptimizationTargetModelObserver {
const proto::OptimizationTarget optimization_target_;
- // The owned background executor.
- std::unique_ptr<ModelExecutor<OutputType, InputTypes...>>
- background_executor_;
+ // The owned model executor.
+ std::unique_ptr<ModelExecutor<OutputType, InputTypes...>> model_executor_;
- // The background task runner. Note that whenever a task is posted here, the
- // task takes a reference to the TaskRunner (in a cyclic dependency) so
+ // The model executor task runner. Note that whenever a task is posted here,
+ // the task takes a reference to the TaskRunner (in a cyclic dependency) so
// |base::Unretained| is not safe anywhere in this class or the
- // |background_executor_|.
- scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
+ // |model_executor_|.
+ scoped_refptr<base::SequencedTaskRunner> model_executor_task_runner_;
// Set in |OnModelUpdated|.
absl::optional<ModelInfo> model_info_ GUARDED_BY_CONTEXT(sequence_checker_);
diff --git a/chromium/components/optimization_guide/core/model_info.cc b/chromium/components/optimization_guide/core/model_info.cc
index 077b95dc9b2..db16275f297 100644
--- a/chromium/components/optimization_guide/core/model_info.cc
+++ b/chromium/components/optimization_guide/core/model_info.cc
@@ -6,7 +6,9 @@
#include "base/memory/ptr_util.h"
#include "base/notreached.h"
-#include "components/optimization_guide/core/optimization_guide_util.h"
+#include "base/strings/utf_string_conversions.h"
+#include "build/build_config.h"
+#include "components/optimization_guide/core/model_util.h"
namespace optimization_guide {
diff --git a/chromium/components/optimization_guide/core/model_util.cc b/chromium/components/optimization_guide/core/model_util.cc
new file mode 100644
index 00000000000..936d0e4c4cb
--- /dev/null
+++ b/chromium/components/optimization_guide/core/model_util.cc
@@ -0,0 +1,86 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/optimization_guide/core/model_util.h"
+
+#include "base/base64.h"
+#include "base/containers/flat_set.h"
+#include "base/notreached.h"
+#include "base/strings/utf_string_conversions.h"
+#include "build/build_config.h"
+#include "net/base/url_util.h"
+#include "url/url_canon.h"
+
+namespace optimization_guide {
+
+// These names are persisted to histograms, so don't change them.
+std::string GetStringNameForOptimizationTarget(
+ optimization_guide::proto::OptimizationTarget optimization_target) {
+ switch (optimization_target) {
+ case proto::OPTIMIZATION_TARGET_UNKNOWN:
+ return "Unknown";
+ case proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD:
+ return "PainfulPageLoad";
+ case proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION:
+ return "LanguageDetection";
+ case proto::OPTIMIZATION_TARGET_PAGE_TOPICS:
+ return "PageTopics";
+ case proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
+ return "SegmentationNewTab";
+ case proto::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
+ return "SegmentationShare";
+ case proto::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
+ return "SegmentationVoice";
+ case proto::OPTIMIZATION_TARGET_MODEL_VALIDATION:
+ return "ModelValidation";
+ case proto::OPTIMIZATION_TARGET_PAGE_ENTITIES:
+ return "PageEntities";
+ case proto::OPTIMIZATION_TARGET_NOTIFICATION_PERMISSION_PREDICTIONS:
+ return "NotificationPermissions";
+ case proto::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY:
+ return "SegmentationDummyFeature";
+ case proto::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID:
+ return "SegmentationChromeStartAndroid";
+ case proto::OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES:
+ return "SegmentationQueryTiles";
+ case proto::OPTIMIZATION_TARGET_PAGE_VISIBILITY:
+ return "PageVisibility";
+ case proto::OPTIMIZATION_TARGET_AUTOFILL_ASSISTANT:
+ return "AutofillAssistant";
+ case proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2:
+ return "PageTopicsV2";
+ case proto::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT:
+ return "SegmentationChromeLowUserEngagement";
+ // Whenever a new value is added, make sure to add it to the OptTarget
+ // variant list in
+ // //tools/metrics/histograms/metadata/optimization/histograms.xml.
+ }
+ NOTREACHED();
+ return std::string();
+}
+
+absl::optional<base::FilePath> StringToFilePath(const std::string& str_path) {
+ if (str_path.empty())
+ return absl::nullopt;
+
+#if BUILDFLAG(IS_WIN)
+ return base::FilePath(base::UTF8ToWide(str_path));
+#else
+ return base::FilePath(str_path);
+#endif
+}
+
+std::string FilePathToString(const base::FilePath& file_path) {
+#if BUILDFLAG(IS_WIN)
+ return base::WideToUTF8(file_path.value());
+#else
+ return file_path.value();
+#endif
+}
+
+base::FilePath GetBaseFileNameForModels() {
+ return base::FilePath(FILE_PATH_LITERAL("model.tflite"));
+}
+
+} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/model_util.h b/chromium/components/optimization_guide/core/model_util.h
new file mode 100644
index 00000000000..dbcf1d476db
--- /dev/null
+++ b/chromium/components/optimization_guide/core/model_util.h
@@ -0,0 +1,37 @@
+// Copyright 2021 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_UTIL_H_
+#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_UTIL_H_
+
+#include <string>
+
+#include "base/files/file_path.h"
+#include "components/optimization_guide/proto/models.pb.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
+
+namespace optimization_guide {
+
+// Returns the string than can be used to record histograms for the optimization
+// target. If adding a histogram to use the string or adding an optimization
+// target, update the OptimizationGuide.OptimizationTargets histogram suffixes
+// in histograms.xml.
+std::string GetStringNameForOptimizationTarget(
+ proto::OptimizationTarget optimization_target);
+
+// Returns the file path represented by the given string, handling platform
+// differences in the conversion. nullopt is only returned iff the passed string
+// is empty.
+absl::optional<base::FilePath> StringToFilePath(const std::string& str_path);
+
+// Returns a string representation of the given |file_path|, handling platform
+// differences in the conversion.
+std::string FilePathToString(const base::FilePath& file_path);
+
+// Returns the base file name to use for storing all prediction models.
+base::FilePath GetBaseFileNameForModels();
+
+} // namespace optimization_guide
+
+#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_UTIL_H_
diff --git a/chromium/components/optimization_guide/core/model_validator.cc b/chromium/components/optimization_guide/core/model_validator.cc
index b936f7470f5..bb09c0c8e80 100644
--- a/chromium/components/optimization_guide/core/model_validator.cc
+++ b/chromium/components/optimization_guide/core/model_validator.cc
@@ -53,18 +53,22 @@ ModelValidatorExecutor::ModelValidatorExecutor() = default;
ModelValidatorExecutor::~ModelValidatorExecutor() = default;
-absl::Status ModelValidatorExecutor::Preprocess(
+bool ModelValidatorExecutor::Preprocess(
const std::vector<TfLiteTensor*>& input_tensors,
const std::vector<float>& input) {
// Return error so that actual model execution does not happen.
- return absl::Status(absl::StatusCode::kUnimplemented,
- "Model execution not supported");
+ return false;
}
-float ModelValidatorExecutor::Postprocess(
+absl::optional<float> ModelValidatorExecutor::Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors) {
std::vector<float> data;
- tflite::task::core::PopulateVector<float>(output_tensors[0], &data);
+ absl::Status status =
+ tflite::task::core::PopulateVector<float>(output_tensors[0], &data);
+ if (!status.ok()) {
+ NOTREACHED();
+ return absl::nullopt;
+ }
return data[0];
}
diff --git a/chromium/components/optimization_guide/core/model_validator.h b/chromium/components/optimization_guide/core/model_validator.h
index 4e48f4f4a2f..fc25cf6bd94 100644
--- a/chromium/components/optimization_guide/core/model_validator.h
+++ b/chromium/components/optimization_guide/core/model_validator.h
@@ -53,9 +53,9 @@ class ModelValidatorExecutor
protected:
// BaseModelExecutor:
- absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
- const std::vector<float>& input) override;
- float Postprocess(
+ bool Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
+ const std::vector<float>& input) override;
+ absl::optional<float> Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors) override;
};
diff --git a/chromium/components/optimization_guide/core/model_validator_unittest.cc b/chromium/components/optimization_guide/core/model_validator_unittest.cc
index 48a66a623eb..8792f0a9798 100644
--- a/chromium/components/optimization_guide/core/model_validator_unittest.cc
+++ b/chromium/components/optimization_guide/core/model_validator_unittest.cc
@@ -15,6 +15,7 @@
#include "base/test/metrics/histogram_tester.h"
#include "base/test/task_environment.h"
#include "build/build_config.h"
+#include "components/optimization_guide/core/model_util.h"
#include "components/optimization_guide/core/model_validator.h"
#include "components/optimization_guide/core/optimization_guide_switches.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
@@ -125,8 +126,14 @@ TEST_F(ModelValidatorExecutorTest, ValidModel) {
proto::OptimizationTarget::OPTIMIZATION_TARGET_MODEL_VALIDATION),
ExecutionStatus::kErrorUnknown, 1);
+ histogram_tester().ExpectUniqueSample(
+ "OptimizationGuide.ModelExecutor.ModelLoadedSuccessfully." +
+ GetStringNameForOptimizationTarget(
+ proto::OptimizationTarget::OPTIMIZATION_TARGET_MODEL_VALIDATION),
+ true, 1);
+
histogram_tester().ExpectTotalCount(
- "OptimizationGuide.ModelExecutor.ModelLoadingDuration." +
+ "OptimizationGuide.ModelExecutor.ModelLoadingDuration2." +
GetStringNameForOptimizationTarget(
proto::OptimizationTarget::OPTIMIZATION_TARGET_MODEL_VALIDATION),
1);
@@ -147,8 +154,15 @@ TEST_F(ModelValidatorExecutorTest, DISABLED_InvalidModel) {
GetStringNameForOptimizationTarget(
proto::OptimizationTarget::OPTIMIZATION_TARGET_MODEL_VALIDATION),
ExecutionStatus::kErrorModelFileNotValid, 1);
+
+ histogram_tester().ExpectUniqueSample(
+ "OptimizationGuide.ModelExecutor.ModelLoadedSuccessfully." +
+ GetStringNameForOptimizationTarget(
+ proto::OptimizationTarget::OPTIMIZATION_TARGET_MODEL_VALIDATION),
+ false, 1);
+
histogram_tester().ExpectTotalCount(
- "OptimizationGuide.ModelExecutor.ModelLoadingDuration." +
+ "OptimizationGuide.ModelExecutor.ModelLoadingDuration2." +
GetStringNameForOptimizationTarget(
proto::OptimizationTarget::OPTIMIZATION_TARGET_MODEL_VALIDATION),
1);
diff --git a/chromium/components/optimization_guide/core/optimization_guide_constants.cc b/chromium/components/optimization_guide/core/optimization_guide_constants.cc
index a8102e2dcac..c6b1318da25 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_constants.cc
+++ b/chromium/components/optimization_guide/core/optimization_guide_constants.cc
@@ -27,4 +27,7 @@ const base::FilePath::CharType
kOptimizationGuidePredictionModelAndFeaturesStore[] =
FILE_PATH_LITERAL("optimization_guide_model_and_features_store");
+const base::FilePath::CharType kPageEntitiesMetadataStore[] =
+ FILE_PATH_LITERAL("page_content_annotations_page_entities_metadata_store");
+
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/optimization_guide_constants.h b/chromium/components/optimization_guide/core/optimization_guide_constants.h
index 8b94c027625..d095761b879 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_constants.h
+++ b/chromium/components/optimization_guide/core/optimization_guide_constants.h
@@ -33,6 +33,9 @@ extern const base::FilePath::CharType kOptimizationGuideHintStore[];
extern const base::FilePath::CharType
kOptimizationGuidePredictionModelAndFeaturesStore[];
+// The folder where the page entities metadata store will be stored on disk.
+extern const base::FilePath::CharType kPageEntitiesMetadataStore[];
+
} // namespace optimization_guide
#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_OPTIMIZATION_GUIDE_CONSTANTS_H_
diff --git a/chromium/components/optimization_guide/core/optimization_guide_enums.h b/chromium/components/optimization_guide/core/optimization_guide_enums.h
index 88d24bcf537..a7e817c2b5e 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_enums.h
+++ b/chromium/components/optimization_guide/core/optimization_guide_enums.h
@@ -49,29 +49,6 @@ enum class OptimizationTypeDecision {
kMaxValue = kHintFetchStartedButNotAvailableInTime,
};
-// The types of decisions that can be made for an optimization target.
-//
-// Keep in sync with OptimizationGuideOptimizationTargetDecision in enums.xml.
-enum class OptimizationTargetDecision {
- kUnknown,
- // The page load does not match the optimization target.
- kPageLoadDoesNotMatch,
- // The page load matches the optimization target.
- kPageLoadMatches,
- // The model needed to make the target decision was not available on the
- // client.
- kModelNotAvailableOnClient,
- // The page load is part of a model prediction holdback where all decisions
- // will return |OptimizationGuideDecision::kFalse| in an attempt to not taint
- // the data for understanding the production recall of the model.
- kModelPredictionHoldback,
- // The OptimizationGuideDecider was not initialized yet.
- kDeciderNotInitialized,
-
- // Add new values above this line.
- kMaxValue = kDeciderNotInitialized,
-};
-
// The statuses for racing a hints fetch with the current navigation based
// on the availability of hints for both the current host and URL.
//
@@ -101,28 +78,6 @@ enum class RaceNavigationFetchAttemptStatus {
kDeprecatedRaceNavigationFetchNotAttemptedTooManyConcurrentFetches,
};
-// The statuses for a prediction model in the prediction manager when requested
-// to be evaluated.
-//
-// Keep in sync with OptimizationGuidePredictionManagerModelStatus in enums.xml.
-enum class PredictionManagerModelStatus {
- kUnknown,
- // The model is loaded and available for use.
- kModelAvailable,
- // The store is initialized but does not contain a model for the optimization
- // target.
- kStoreAvailableNoModelForTarget,
- // The store is initialized and contains a model for the optimization target
- // but it is not loaded in memory.
- kStoreAvailableModelNotLoaded,
- // The store is not initialized and it is unknown if it contains a model for
- // the optimization target.
- kStoreUnavailableModelUnknown,
-
- // Add new values above this line.
- kMaxValue = kStoreUnavailableModelUnknown,
-};
-
// The statuses for a download file containing a prediction model when verified
// and processed.
//
diff --git a/chromium/components/optimization_guide/core/optimization_guide_features.cc b/chromium/components/optimization_guide/core/optimization_guide_features.cc
index 8de77e545fd..6c3a200263f 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_features.cc
+++ b/chromium/components/optimization_guide/core/optimization_guide_features.cc
@@ -25,15 +25,49 @@
namespace optimization_guide {
namespace features {
+namespace {
+
+// Returns whether |locale| is a supported locale for |feature|.
+//
+// This matches |locale| with the "supported_locales" feature param value in
+// |feature|, which is expected to be a comma-separated list of locales. A
+// feature param containing "en,es-ES,zh-TW" restricts the feature to English
+// language users from any locale and Spanish language users from the Spain
+// es-ES locale. A feature param containing "" is unrestricted by locale and any
+// user may load it.
+bool IsSupportedLocaleForFeature(const std::string locale,
+ const base::Feature& feature) {
+ if (!base::FeatureList::IsEnabled(feature)) {
+ return false;
+ }
+
+ std::string value =
+ base::GetFieldTrialParamValueByFeature(feature, "supported_locales");
+ std::vector<std::string> supported_locales = base::SplitString(
+ value, ",", base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY);
+ // An empty allowlist admits any locale.
+ if (supported_locales.empty()) {
+ return true;
+ }
+
+ // Otherwise, the locale or the
+ // primary language subtag must match an element of the allowlist.
+ std::string locale_language = l10n_util::GetLanguage(locale);
+ return base::Contains(supported_locales, locale) ||
+ base::Contains(supported_locales, locale_language);
+}
+
+} // namespace
+
// Enables the syncing of the Optimization Hints component, which provides
// hints for what optimizations can be applied on a page load.
const base::Feature kOptimizationHints {
"OptimizationHints",
-#if defined(OS_IOS)
+#if BUILDFLAG(IS_IOS)
base::FEATURE_DISABLED_BY_DEFAULT
-#else // !defined(OS_IOS)
+#else // !BUILDFLAG(IS_IOS)
base::FEATURE_ENABLED_BY_DEFAULT
-#endif // defined(OS_IOS)
+#endif // BUILDFLAG(IS_IOS)
};
// Feature flag that contains a feature param that specifies the field trials
@@ -47,11 +81,11 @@ const base::Feature kRemoteOptimizationGuideFetching{
const base::Feature kRemoteOptimizationGuideFetchingAnonymousDataConsent {
"OptimizationHintsFetchingAnonymousDataConsent",
-#if defined(OS_ANDROID)
+#if BUILDFLAG(IS_ANDROID)
base::FEATURE_ENABLED_BY_DEFAULT
-#else // !defined(OS_ANDROID)
+#else // !BUILDFLAG(IS_ANDROID)
base::FEATURE_DISABLED_BY_DEFAULT
-#endif // defined(OS_ANDROID)
+#endif // BUILDFLAG(IS_ANDROID)
};
// Enables performance info in the context menu and fetching from a remote
@@ -78,6 +112,17 @@ const base::Feature kOptimizationGuideModelDownloading {
const base::Feature kPageContentAnnotations{"PageContentAnnotations",
base::FEATURE_DISABLED_BY_DEFAULT};
+// Enables the page entities model to be annotated on every page load.
+const base::Feature kPageEntitiesPageContentAnnotations{
+ "PageEntitiesPageContentAnnotations", base::FEATURE_DISABLED_BY_DEFAULT};
+// Enables the page visibility model to be annotated on every page load.
+const base::Feature kPageVisibilityPageContentAnnotations{
+ "PageVisibilityPageContentAnnotations", base::FEATURE_DISABLED_BY_DEFAULT};
+
+// This feature flag enables resetting the entities model on shutdown.
+const base::Feature kPageEntitiesModelResetOnShutdown{
+ "PageEntitiesModelResetOnShutdown", base::FEATURE_DISABLED_BY_DEFAULT};
+
// Enables push notification of hints.
const base::Feature kPushNotifications{"OptimizationGuidePushNotifications",
base::FEATURE_DISABLED_BY_DEFAULT};
@@ -96,6 +141,12 @@ const base::Feature kPageTopicsBatchAnnotations{
const base::Feature kPageVisibilityBatchAnnotations{
"PageVisibilityBatchAnnotations", base::FEATURE_ENABLED_BY_DEFAULT};
+const base::Feature kUseLocalPageEntitiesMetadataProvider{
+ "UseLocalPageEntitiesMetadataProvider", base::FEATURE_DISABLED_BY_DEFAULT};
+
+const base::Feature kBatchAnnotationsValidation{
+ "BatchAnnotationsValidation", base::FEATURE_DISABLED_BY_DEFAULT};
+
// The default value here is a bit of a guess.
// TODO(crbug/1163244): This should be tuned once metrics are available.
base::TimeDelta PageTextExtractionOutstandingRequestsGracePeriod() {
@@ -264,14 +315,15 @@ base::TimeDelta StoredHostModelFeaturesFreshnessDuration() {
"max_store_duration_for_host_model_features_in_days", 7));
}
-base::TimeDelta StoredModelsInactiveDuration() {
+base::TimeDelta StoredModelsValidDuration() {
// TODO(crbug.com/1234054) This field should not be changed without VERY
- // careful consideration. Any model that is on device and expires will be
- // removed and triggered to refetch so any feature relying on the model could
- // have a period of time without a valid model.
+ // careful consideration. This is the default duration for models that do not
+ // specify retention, so changing this can cause models to be removed and
+ // refetch would only apply to newer models. Any feature relying on the model
+ // would have a period of time without a valid model, and would need to push a
+ // new version.
return base::Days(GetFieldTrialParamByFeatureAsInt(
- kOptimizationTargetPrediction, "inactive_duration_for_models_in_days",
- 30));
+ kOptimizationTargetPrediction, "valid_duration_for_models_in_days", 30));
}
base::TimeDelta URLKeyedHintValidCacheDuration() {
@@ -334,6 +386,11 @@ base::TimeDelta PredictionModelFetchRetryDelay() {
kOptimizationTargetPrediction, "fetch_retry_minutes", 2));
}
+base::TimeDelta PredictionModelFetchStartupDelay() {
+ return base::Milliseconds(GetFieldTrialParamByFeatureAsInt(
+ kOptimizationTargetPrediction, "fetch_startup_delay_ms", 2000));
+}
+
base::TimeDelta PredictionModelFetchInterval() {
return base::Hours(GetFieldTrialParamByFeatureAsInt(
kOptimizationTargetPrediction, "fetch_interval_hours", 24));
@@ -396,75 +453,20 @@ bool ShouldExtractRelatedSearches() {
return kContentAnnotationsExtractRelatedSearchesParam.Get();
}
-std::vector<optimization_guide::proto::OptimizationTarget>
-GetPageContentModelsToExecute(const std::string& locale) {
- if (!IsPageContentAnnotationEnabled())
- return {};
-
- // Use an updated parameter name that supports locale filtering. That way,
- // older clients that don't know how to interpret locale filtering ignore the
- // new parameter name and keep looking for the old one.
- std::string value = base::GetFieldTrialParamValueByFeature(
- kPageContentAnnotations, "models_to_execute_v2");
- if (value.empty()) {
- // If the updated parameter is empty, try getting the older parameter name
- // that doesn't support locale-specific models. That way, older parameter
- // configurations still work. We don't do a union because that's confusing.
- value = base::GetFieldTrialParamValueByFeature(kPageContentAnnotations,
- "models_to_execute");
- }
- if (value.empty()) {
- // If neither the newer or older parameter is set, run the page topics model
- // by default.
- return {optimization_guide::proto::OPTIMIZATION_TARGET_PAGE_TOPICS};
- }
-
- // The parameter value delimits models by commas, and per-model locale
- // restrictions by colon. For example:
- // FOO_MODEL:en:es-ES,BAR_MODEL,BAZ_MODEL:zh-TW
- // - FOO_MODEL is restricted to English language users from any locale, and
- // Spanish language users from the Spain es-ES locale.
- // - BAR_MODEL is unrestricted by locale, and any user may load it.
- // - BAZ_MODEL is restricted to zh-TW only, so zh-CN users won't load it.
- //
- // First split by comma to handle one model at a time.
- std::vector<std::string> model_target_strings = base::SplitString(
- value, ",", base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY);
-
- std::string locale_language = l10n_util::GetLanguage(locale);
+bool ShouldExecutePageEntitiesModelOnPageContent(const std::string& locale) {
+ return base::FeatureList::IsEnabled(kPageEntitiesPageContentAnnotations) &&
+ IsSupportedLocaleForFeature(locale,
+ kPageEntitiesPageContentAnnotations);
+}
- optimization_guide::InsertionOrderedSet<
- optimization_guide::proto::OptimizationTarget>
- model_targets;
- for (const auto& model_target_string : model_target_strings) {
- // Split by colon to extract the model name and allowlist, early continuing
- // for invalid values.
- std::vector<std::string> model_name_and_allowed_locales =
- base::SplitString(model_target_string, ":", base::TRIM_WHITESPACE,
- base::SPLIT_WANT_NONEMPTY);
- if (model_name_and_allowed_locales.empty())
- continue;
- std::string model_name = model_name_and_allowed_locales[0];
- std::vector<std::string> allowlist;
- for (size_t i = 1; i < model_name_and_allowed_locales.size(); ++i) {
- allowlist.push_back(model_name_and_allowed_locales[i]);
- }
-
- optimization_guide::proto::OptimizationTarget model_target;
- if (!optimization_guide::proto::OptimizationTarget_Parse(model_name,
- &model_target)) {
- continue;
- }
-
- // An empty allowlist admits any locale. Otherwise, the locale or the
- // primary language subtag must match an element of the allowlist.
- if (allowlist.empty() || base::Contains(allowlist, locale) ||
- base::Contains(allowlist, locale_language)) {
- model_targets.insert(model_target);
- }
- }
+bool ShouldResetPageEntitiesModelOnShutdown() {
+ return base::FeatureList::IsEnabled(kPageEntitiesModelResetOnShutdown);
+}
- return model_targets.vector();
+bool ShouldExecutePageVisibilityModelOnPageContent(const std::string& locale) {
+ return base::FeatureList::IsEnabled(kPageVisibilityPageContentAnnotations) &&
+ IsSupportedLocaleForFeature(locale,
+ kPageVisibilityPageContentAnnotations);
}
bool RemotePageEntitiesEnabled() {
@@ -501,7 +503,7 @@ bool ShouldMetadataValidationFetchHostKeyed() {
bool ShouldDeferStartupActiveTabsHintsFetch() {
return GetFieldTrialParamByFeatureAsBool(
kOptimizationHints, "defer_startup_active_tabs_hints_fetch",
-#if defined(OS_ANDROID)
+#if BUILDFLAG(IS_ANDROID)
true
#else
false
@@ -517,5 +519,37 @@ bool PageVisibilityBatchAnnotationsEnabled() {
return base::FeatureList::IsEnabled(kPageVisibilityBatchAnnotations);
}
+bool UseLocalPageEntitiesMetadataProvider() {
+ return base::FeatureList::IsEnabled(kUseLocalPageEntitiesMetadataProvider);
+}
+
+size_t AnnotateVisitBatchSize() {
+ return std::max(
+ 1, GetFieldTrialParamByFeatureAsInt(kPageContentAnnotations,
+ "annotate_visit_batch_size", 1));
+}
+
+bool BatchAnnotationsValidationEnabled() {
+ return base::FeatureList::IsEnabled(kBatchAnnotationsValidation);
+}
+
+base::TimeDelta BatchAnnotationValidationStartupDelay() {
+ return base::Seconds(
+ std::max(1, GetFieldTrialParamByFeatureAsInt(kBatchAnnotationsValidation,
+ "startup_delay", 30)));
+}
+
+size_t BatchAnnotationsValidationBatchSize() {
+ int batch_size = GetFieldTrialParamByFeatureAsInt(kBatchAnnotationsValidation,
+ "batch_size", 25);
+ return std::max(1, batch_size);
+}
+
+size_t MaxVisitAnnotationCacheSize() {
+ int batch_size = GetFieldTrialParamByFeatureAsInt(
+ kPageContentAnnotations, "max_visit_annotation_cache_size", 50);
+ return std::max(1, batch_size);
+}
+
} // namespace features
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/optimization_guide_features.h b/chromium/components/optimization_guide/core/optimization_guide_features.h
index 9789dbc0b42..e69bd80e439 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_features.h
+++ b/chromium/components/optimization_guide/core/optimization_guide_features.h
@@ -29,11 +29,15 @@ extern const base::Feature kContextMenuPerformanceInfoAndRemoteHintFetching;
extern const base::Feature kOptimizationTargetPrediction;
extern const base::Feature kOptimizationGuideModelDownloading;
extern const base::Feature kPageContentAnnotations;
+extern const base::Feature kPageEntitiesPageContentAnnotations;
+extern const base::Feature kPageVisibilityPageContentAnnotations;
extern const base::Feature kPageTextExtraction;
extern const base::Feature kPushNotifications;
extern const base::Feature kOptimizationGuideMetadataValidation;
extern const base::Feature kPageTopicsBatchAnnotations;
extern const base::Feature kPageVisibilityBatchAnnotations;
+extern const base::Feature kUseLocalPageEntitiesMetadataProvider;
+extern const base::Feature kBatchAnnotationsValidation;
// The grace period duration for how long to give outstanding page text dump
// requests to respond after DidFinishLoad.
@@ -133,7 +137,7 @@ base::TimeDelta StoredHostModelFeaturesFreshnessDuration();
// The maximum duration for which models can remain in the
// OptimizationGuideStore without being loaded.
-base::TimeDelta StoredModelsInactiveDuration();
+base::TimeDelta StoredModelsValidDuration();
// The amount of time URL-keyed hints within the hint cache will be
// allowed to be used and not be purged.
@@ -177,6 +181,10 @@ int PredictionModelFetchRandomMaxDelaySecs();
// models.
base::TimeDelta PredictionModelFetchRetryDelay();
+// Returns the time to wait after browser start before fetching prediciton
+// models.
+base::TimeDelta PredictionModelFetchStartupDelay();
+
// Returns the time to wait after a successful fetch of prediction models to
// refresh models.
base::TimeDelta PredictionModelFetchInterval();
@@ -214,15 +222,16 @@ size_t MaxContentAnnotationRequestsCached();
// as part of page content annotations.
bool ShouldExtractRelatedSearches();
-// Returns an ordered vector of models to execute on the page content for each
-// page load. It is guaranteed that an optimization target will only be present
-// at most once in the returned vector. However, it is not guaranteed that it
-// will only contain models that the current PageContentAnnotationsService
-// supports, so it is up to the caller to ensure that it can execute the
-// specified models. `locale` is used for implement client-side locale filtering
-// for models that only work for some locales.
-std::vector<optimization_guide::proto::OptimizationTarget>
-GetPageContentModelsToExecute(const std::string& locale);
+// Returns whether the page entities model should be executed on page content
+// for a user using |locale| as their browser language.
+bool ShouldExecutePageEntitiesModelOnPageContent(const std::string& locale);
+
+// Returns whether the page entities model should be reset on shutdown.
+bool ShouldResetPageEntitiesModelOnShutdown();
+
+// Returns whether the page visibility model should be executed on page content
+// for a user using |locale| as their browser language.
+bool ShouldExecutePageVisibilityModelOnPageContent(const std::string& locale);
// Returns whether page entities should be retrieved from the remote
// Optimization Guide service.
@@ -249,6 +258,27 @@ bool PageTopicsBatchAnnotationsEnabled();
// Returns if Page Visibility Batch Annotations are enabled.
bool PageVisibilityBatchAnnotationsEnabled();
+// Whether to use the leveldb-based page entities metadata provider.
+bool UseLocalPageEntitiesMetadataProvider();
+
+// The number of visits batch before running the page content annotation
+// models. A size of 1 is equivalent to annotating one page load at time
+// immediately after requested.
+size_t AnnotateVisitBatchSize();
+
+// Whether the batch annotation validation feature is enabled.
+bool BatchAnnotationsValidationEnabled();
+
+// The time period between browser start and running a running batch annotation
+// validation.
+base::TimeDelta BatchAnnotationValidationStartupDelay();
+
+// The size of batches to run for validation.
+size_t BatchAnnotationsValidationBatchSize();
+
+// The maximum size of the visit annotation cache.
+size_t MaxVisitAnnotationCacheSize();
+
} // namespace features
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/optimization_guide_features_unittest.cc b/chromium/components/optimization_guide/core/optimization_guide_features_unittest.cc
index 152c621bfe3..48433bff52b 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_features_unittest.cc
+++ b/chromium/components/optimization_guide/core/optimization_guide_features_unittest.cc
@@ -68,73 +68,72 @@ TEST(OptimizationGuideFeaturesTest, ValidPageContentRAPPORMetrics) {
EXPECT_EQ(.2, features::NoiseProbabilityForRAPPORMetrics());
}
-TEST(OptimizationGuideFeaturesTest, GetPageContentModelsToExecute) {
+TEST(OptimizationGuideFeaturesTest,
+ ShouldExecutePageEntitiesModelOnPageContentDisabled) {
base::test::ScopedFeatureList scoped_feature_list;
- scoped_feature_list.InitAndEnableFeatureWithParameters(
- features::kPageContentAnnotations,
- {{"models_to_execute_v2",
- "OPTIMIZATION_TARGET_PAGE_TOPICS,OPTIMIZATION_TARGET_PAGE_ENTITIES"}});
+ scoped_feature_list.InitAndDisableFeature(
+ features::kPageEntitiesPageContentAnnotations);
- auto models = features::GetPageContentModelsToExecute("en-US");
- ASSERT_EQ(2U, models.size());
- ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_TOPICS, models[0]);
- ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, models[1]);
+ EXPECT_FALSE(features::ShouldExecutePageEntitiesModelOnPageContent("en-US"));
}
TEST(OptimizationGuideFeaturesTest,
- GetPageContentModelsToExecuteOldParameterName) {
+ ShouldExecutePageEntitiesModelOnPageContentEmptyAllowlist) {
+ base::test::ScopedFeatureList scoped_feature_list;
+
+ scoped_feature_list.InitAndEnableFeature(
+ features::kPageEntitiesPageContentAnnotations);
+
+ EXPECT_TRUE(features::ShouldExecutePageEntitiesModelOnPageContent("en-US"));
+}
+
+TEST(OptimizationGuideFeaturesTest,
+ ShouldExecutePageEntitiesModelOnPageContentWithAllowlist) {
base::test::ScopedFeatureList scoped_feature_list;
scoped_feature_list.InitAndEnableFeatureWithParameters(
- features::kPageContentAnnotations,
- {{"models_to_execute",
- "OPTIMIZATION_TARGET_PAGE_TOPICS,OPTIMIZATION_TARGET_PAGE_ENTITIES"}});
+ features::kPageEntitiesPageContentAnnotations,
+ {{"supported_locales", "en,zh-TW"}});
- auto models = features::GetPageContentModelsToExecute("en-US");
- ASSERT_EQ(2U, models.size());
- ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_TOPICS, models[0]);
- ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, models[1]);
+ EXPECT_TRUE(features::ShouldExecutePageEntitiesModelOnPageContent("en-US"));
+ EXPECT_FALSE(features::ShouldExecutePageEntitiesModelOnPageContent(""));
+ EXPECT_FALSE(features::ShouldExecutePageEntitiesModelOnPageContent("zh-CN"));
}
-TEST(OptimizationGuideFeaturesTest, GetPageContentModelsToExecuteLocales) {
+TEST(OptimizationGuideFeaturesTest,
+ ShouldExecutePageVisibilityModelOnPageContentDisabled) {
+ base::test::ScopedFeatureList scoped_feature_list;
+
+ scoped_feature_list.InitAndDisableFeature(
+ features::kPageVisibilityPageContentAnnotations);
+
+ EXPECT_FALSE(
+ features::ShouldExecutePageVisibilityModelOnPageContent("en-US"));
+}
+
+TEST(OptimizationGuideFeaturesTest,
+ ShouldExecutePageVisibilityModelOnPageContentEmptyAllowlist) {
+ base::test::ScopedFeatureList scoped_feature_list;
+
+ scoped_feature_list.InitAndEnableFeature(
+ features::kPageVisibilityPageContentAnnotations);
+
+ EXPECT_TRUE(features::ShouldExecutePageVisibilityModelOnPageContent("en-US"));
+}
+
+TEST(OptimizationGuideFeaturesTest,
+ ShouldExecutePageVisibilityModelOnPageContentWithAllowlist) {
base::test::ScopedFeatureList scoped_feature_list;
scoped_feature_list.InitAndEnableFeatureWithParameters(
- features::kPageContentAnnotations,
- {{
- "models_to_execute_v2",
- // This string is meant to test language filtering, locale filtering,
- // and tolerance of whitespaces, as well as extra delimiters.
- "OPTIMIZATION_TARGET_PAGE_TOPICS:en:es-ES , OPTIMIZATION_TARGET_PAGE_"
- "ENTITIES,,OPTIMIZATION_TARGET_PAGE_VISIBILITY:zh-TW:",
- }});
-
- {
- auto models = features::GetPageContentModelsToExecute("en-US");
- ASSERT_EQ(2U, models.size());
- ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_TOPICS, models[0]);
- ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, models[1]);
- }
-
- {
- auto models = features::GetPageContentModelsToExecute("");
- ASSERT_EQ(1U, models.size());
- ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, models[0]);
- }
-
- {
- auto models = features::GetPageContentModelsToExecute("zh-CN");
- ASSERT_EQ(1U, models.size());
- ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, models[0]);
- }
-
- {
- auto models = features::GetPageContentModelsToExecute("zh-TW");
- ASSERT_EQ(2U, models.size());
- ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, models[0]);
- ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_VISIBILITY, models[1]);
- }
+ features::kPageVisibilityPageContentAnnotations,
+ {{"supported_locales", "en,zh-TW"}});
+
+ EXPECT_TRUE(features::ShouldExecutePageVisibilityModelOnPageContent("en-US"));
+ EXPECT_FALSE(features::ShouldExecutePageVisibilityModelOnPageContent(""));
+ EXPECT_FALSE(
+ features::ShouldExecutePageVisibilityModelOnPageContent("zh-CN"));
}
} // namespace
diff --git a/chromium/components/optimization_guide/core/optimization_guide_logger.cc b/chromium/components/optimization_guide/core/optimization_guide_logger.cc
new file mode 100644
index 00000000000..20cdbaf4d1a
--- /dev/null
+++ b/chromium/components/optimization_guide/core/optimization_guide_logger.cc
@@ -0,0 +1,32 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/optimization_guide/core/optimization_guide_logger.h"
+
+OptimizationGuideLogger::OptimizationGuideLogger() = default;
+
+OptimizationGuideLogger::~OptimizationGuideLogger() = default;
+
+void OptimizationGuideLogger::AddObserver(
+ OptimizationGuideLogger::Observer* observer) {
+ observers_.AddObserver(observer);
+}
+
+void OptimizationGuideLogger::RemoveObserver(
+ OptimizationGuideLogger::Observer* observer) {
+ observers_.RemoveObserver(observer);
+}
+
+void OptimizationGuideLogger::OnLogMessageAdded(base::Time event_time,
+ const std::string& source_file,
+ int source_line,
+ const std::string& message) {
+ DCHECK(!observers_.empty());
+ for (Observer& obs : observers_)
+ obs.OnLogMessageAdded(event_time, source_file, source_line, message);
+}
+
+bool OptimizationGuideLogger::ShouldEnableDebugLogs() const {
+ return !observers_.empty();
+}
diff --git a/chromium/components/optimization_guide/core/optimization_guide_logger.h b/chromium/components/optimization_guide/core/optimization_guide_logger.h
new file mode 100644
index 00000000000..d9a2f169ba1
--- /dev/null
+++ b/chromium/components/optimization_guide/core/optimization_guide_logger.h
@@ -0,0 +1,45 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_OPTIMIZATION_GUIDE_LOGGER_H_
+#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_OPTIMIZATION_GUIDE_LOGGER_H_
+
+#include <string>
+
+#include "base/observer_list.h"
+#include "base/observer_list_types.h"
+#include "base/time/time.h"
+
+// Interface to record the debug logs and send it to be shown in the
+// optimization guide internals page.
+class OptimizationGuideLogger {
+ public:
+ class Observer : public base::CheckedObserver {
+ public:
+ virtual void OnLogMessageAdded(base::Time event_time,
+ const std::string& source_file,
+ int source_line,
+ const std::string& message) = 0;
+ };
+ OptimizationGuideLogger();
+ ~OptimizationGuideLogger();
+
+ OptimizationGuideLogger(const OptimizationGuideLogger&) = delete;
+ OptimizationGuideLogger& operator=(const OptimizationGuideLogger&) = delete;
+
+ void AddObserver(OptimizationGuideLogger::Observer* observer);
+ void RemoveObserver(OptimizationGuideLogger::Observer* observer);
+ void OnLogMessageAdded(base::Time event_time,
+ const std::string& source_file,
+ int source_line,
+ const std::string& message);
+
+ // Whether debug logs should allowed to be recorded.
+ bool ShouldEnableDebugLogs() const;
+
+ private:
+ base::ObserverList<OptimizationGuideLogger::Observer> observers_;
+};
+
+#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_OPTIMIZATION_GUIDE_LOGGER_H_
diff --git a/chromium/components/optimization_guide/core/optimization_guide_permissions_util.cc b/chromium/components/optimization_guide/core/optimization_guide_permissions_util.cc
index d7a71e2785e..fec08802020 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_permissions_util.cc
+++ b/chromium/components/optimization_guide/core/optimization_guide_permissions_util.cc
@@ -7,21 +7,12 @@
#include <memory>
#include "base/feature_list.h"
-#include "components/data_reduction_proxy/core/browser/data_reduction_proxy_settings.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_switches.h"
#include "components/unified_consent/url_keyed_data_collection_consent_helper.h"
namespace {
-bool IsUserDataSaverEnabledAndAllowedToFetchFromRemoteService(
- bool is_off_the_record,
- PrefService* pref_service) {
- // Check if they are a data saver user.
- return data_reduction_proxy::DataReductionProxySettings::
- IsDataSaverEnabledByUser(is_off_the_record, pref_service);
-}
-
bool IsUserConsentedToAnonymousDataCollectionAndAllowedToFetchFromRemoteService(
PrefService* pref_service) {
if (!optimization_guide::features::
@@ -55,10 +46,6 @@ bool IsUserPermittedToFetchFromRemoteOptimizationGuide(
if (features::IsRemoteFetchingExplicitlyAllowedForPerformanceInfo())
return true;
- if (IsUserDataSaverEnabledAndAllowedToFetchFromRemoteService(
- is_off_the_record, pref_service))
- return true;
-
return IsUserConsentedToAnonymousDataCollectionAndAllowedToFetchFromRemoteService(
pref_service);
}
diff --git a/chromium/components/optimization_guide/core/optimization_guide_permissions_util_unittest.cc b/chromium/components/optimization_guide/core/optimization_guide_permissions_util_unittest.cc
index 9e879748e84..015955ccc26 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_permissions_util_unittest.cc
+++ b/chromium/components/optimization_guide/core/optimization_guide_permissions_util_unittest.cc
@@ -7,8 +7,6 @@
#include "base/command_line.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
-#include "components/data_reduction_proxy/core/browser/data_reduction_proxy_settings.h"
-#include "components/data_reduction_proxy/core/common/data_reduction_proxy_pref_names.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/sync_preferences/testing_pref_service_syncable.h"
#include "components/unified_consent/pref_names.h"
@@ -22,13 +20,6 @@ class OptimizationGuidePermissionsUtilTest : public testing::Test {
void SetUp() override {
unified_consent::UnifiedConsentService::RegisterPrefs(
pref_service_.registry());
- pref_service_.registry()->RegisterBooleanPref(
- data_reduction_proxy::prefs::kDataSaverEnabled, false);
- }
-
- void SetDataSaverEnabled(bool enabled) {
- data_reduction_proxy::DataReductionProxySettings::
- SetDataSaverEnabledForTesting(&pref_service_, enabled);
}
void SetUrlKeyedAnonymizedDataCollectionEnabled(bool enabled) {
@@ -45,53 +36,38 @@ class OptimizationGuidePermissionsUtilTest : public testing::Test {
};
TEST_F(OptimizationGuidePermissionsUtilTest,
- IsUserPermittedToFetchHintsNonDataSaverUser) {
+ IsUserPermittedToFetchHintsDefaultUser) {
base::test::ScopedFeatureList scoped_feature_list;
scoped_feature_list.InitAndEnableFeature(
{optimization_guide::features::kRemoteOptimizationGuideFetching});
- SetDataSaverEnabled(false);
EXPECT_FALSE(IsUserPermittedToFetchFromRemoteOptimizationGuide(
/*is_off_the_record=*/false, pref_service()));
}
-TEST_F(OptimizationGuidePermissionsUtilTest,
- IsUserPermittedToFetchHintsDataSaverUser) {
- base::test::ScopedFeatureList scoped_feature_list;
- scoped_feature_list.InitAndEnableFeature(
- {optimization_guide::features::kRemoteOptimizationGuideFetching});
- SetDataSaverEnabled(true);
-
- EXPECT_TRUE(IsUserPermittedToFetchFromRemoteOptimizationGuide(
- /*is_off_the_record=*/false, pref_service()));
-}
-
TEST_F(
OptimizationGuidePermissionsUtilTest,
- IsUserPermittedToFetchHintsNonDataSaverUserAnonymousDataCollectionEnabledFeatureEnabled) {
+ IsUserPermittedToFetchHintsDefaultUserAnonymousDataCollectionEnabledFeatureEnabled) {
base::test::ScopedFeatureList scoped_feature_list;
scoped_feature_list.InitWithFeatures(
{optimization_guide::features::kRemoteOptimizationGuideFetching,
optimization_guide::features::
kRemoteOptimizationGuideFetchingAnonymousDataConsent},
{});
- SetDataSaverEnabled(false);
SetUrlKeyedAnonymizedDataCollectionEnabled(true);
EXPECT_TRUE(IsUserPermittedToFetchFromRemoteOptimizationGuide(
/*is_off_the_record=*/false, pref_service()));
}
-TEST_F(
- OptimizationGuidePermissionsUtilTest,
- IsUserPermittedToFetchHintsNonDataSaverUserAnonymousDataCollectionDisabled) {
+TEST_F(OptimizationGuidePermissionsUtilTest,
+ IsUserPermittedToFetchHintsDefaultUserAnonymousDataCollectionDisabled) {
base::test::ScopedFeatureList scoped_feature_list;
scoped_feature_list.InitWithFeatures(
{optimization_guide::features::kRemoteOptimizationGuideFetching,
optimization_guide::features::
kRemoteOptimizationGuideFetchingAnonymousDataConsent},
{});
- SetDataSaverEnabled(false);
SetUrlKeyedAnonymizedDataCollectionEnabled(false);
EXPECT_FALSE(IsUserPermittedToFetchFromRemoteOptimizationGuide(
@@ -100,13 +76,12 @@ TEST_F(
TEST_F(
OptimizationGuidePermissionsUtilTest,
- IsUserPermittedToFetchHintsNonDataSaverUserAnonymousDataCollectionEnabledFeatureNotEnabled) {
+ IsUserPermittedToFetchHintsDefaultUserAnonymousDataCollectionEnabledFeatureNotEnabled) {
base::test::ScopedFeatureList scoped_feature_list;
scoped_feature_list.InitWithFeatures(
{optimization_guide::features::kRemoteOptimizationGuideFetching},
{optimization_guide::features::
kRemoteOptimizationGuideFetchingAnonymousDataConsent});
- SetDataSaverEnabled(false);
SetUrlKeyedAnonymizedDataCollectionEnabled(true);
EXPECT_FALSE(IsUserPermittedToFetchFromRemoteOptimizationGuide(
@@ -118,7 +93,6 @@ TEST_F(OptimizationGuidePermissionsUtilTest,
base::test::ScopedFeatureList scoped_feature_list;
scoped_feature_list.InitWithFeatures(
{}, {optimization_guide::features::kRemoteOptimizationGuideFetching});
- SetDataSaverEnabled(true);
SetUrlKeyedAnonymizedDataCollectionEnabled(true);
EXPECT_FALSE(IsUserPermittedToFetchFromRemoteOptimizationGuide(
@@ -133,7 +107,6 @@ TEST_F(OptimizationGuidePermissionsUtilTest,
optimization_guide::features::
kContextMenuPerformanceInfoAndRemoteHintFetching},
{});
- SetDataSaverEnabled(false);
SetUrlKeyedAnonymizedDataCollectionEnabled(false);
EXPECT_TRUE(IsUserPermittedToFetchFromRemoteOptimizationGuide(
@@ -150,7 +123,6 @@ TEST_F(OptimizationGuidePermissionsUtilTest,
optimization_guide::features::
kContextMenuPerformanceInfoAndRemoteHintFetching},
{});
- SetDataSaverEnabled(true);
SetUrlKeyedAnonymizedDataCollectionEnabled(true);
EXPECT_FALSE(IsUserPermittedToFetchFromRemoteOptimizationGuide(
diff --git a/chromium/components/optimization_guide/core/optimization_guide_store.cc b/chromium/components/optimization_guide/core/optimization_guide_store.cc
index 97a50f7812a..7b99acaa5ad 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_store.cc
+++ b/chromium/components/optimization_guide/core/optimization_guide_store.cc
@@ -3,10 +3,13 @@
// found in the LICENSE file.
#include "components/optimization_guide/core/optimization_guide_store.h"
+#include <memory>
+#include <string>
#include "base/bind.h"
#include "base/callback_helpers.h"
#include "base/files/file_util.h"
+#include "base/logging.h"
#include "base/metrics/histogram_functions.h"
#include "base/metrics/histogram_macros.h"
#include "base/sequence_checker.h"
@@ -17,9 +20,11 @@
#include "components/leveldb_proto/public/proto_database_provider.h"
#include "components/leveldb_proto/public/shared_proto_database_client_list.h"
#include "components/optimization_guide/core/memory_hint.h"
+#include "components/optimization_guide/core/model_util.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/proto/hint_cache.pb.h"
+#include "third_party/abseil-cpp/absl/types/optional.h"
namespace optimization_guide {
@@ -62,9 +67,7 @@ enum class OptimizationGuideHintCacheLevelDBStoreLoadMetadataResult {
// recorded when it goes out of scope and its destructor is called.
class ScopedLoadMetadataResultRecorder {
public:
- ScopedLoadMetadataResultRecorder()
- : result_(OptimizationGuideHintCacheLevelDBStoreLoadMetadataResult::
- kSuccess) {}
+ ScopedLoadMetadataResultRecorder() = default;
~ScopedLoadMetadataResultRecorder() {
UMA_HISTOGRAM_ENUMERATION(
"OptimizationGuide.HintCacheLevelDBStore.LoadMetadataResult", result_);
@@ -76,7 +79,8 @@ class ScopedLoadMetadataResultRecorder {
}
private:
- OptimizationGuideHintCacheLevelDBStoreLoadMetadataResult result_;
+ OptimizationGuideHintCacheLevelDBStoreLoadMetadataResult result_ =
+ OptimizationGuideHintCacheLevelDBStoreLoadMetadataResult::kSuccess;
};
void RecordStatusChange(OptimizationGuideStore::Status status) {
@@ -289,20 +293,6 @@ void OptimizationGuideStore::PurgeExpiredFetchedHints() {
weak_ptr_factory_.GetWeakPtr()));
}
-void OptimizationGuideStore::PurgeExpiredHostModelFeatures() {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
-
- if (!IsAvailable())
- return;
-
- // Load all the host model features to check their expiry times.
- database_->LoadKeysAndEntriesWithFilter(
- base::BindRepeating(&DatabasePrefixFilter,
- GetHostModelFeaturesEntryKeyPrefix()),
- base::BindOnce(&OptimizationGuideStore::OnLoadEntriesToPurgeExpired,
- weak_ptr_factory_.GetWeakPtr()));
-}
-
void OptimizationGuideStore::PurgeInactiveModels() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
@@ -453,14 +443,6 @@ base::Time OptimizationGuideStore::GetFetchedHintsUpdateTime() const {
return fetched_update_time_;
}
-base::Time OptimizationGuideStore::GetHostModelFeaturesUpdateTime() const {
- // If the store is not available, the metadata entries have not been loaded
- // so there are no host model features.
- if (!IsAvailable())
- return base::Time();
- return host_model_features_update_time_;
-}
-
// static
const char OptimizationGuideStore::kStoreSchemaVersion[] = "1";
@@ -530,14 +512,6 @@ OptimizationGuideStore::GetOptimizationTargetFromPredictionModelEntryKey(
return static_cast<proto::OptimizationTarget>(optimization_target_number);
}
-// static
-OptimizationGuideStore::EntryKeyPrefix
-OptimizationGuideStore::GetHostModelFeaturesEntryKeyPrefix() {
- return base::NumberToString(static_cast<int>(
- OptimizationGuideStore::StoreEntryType::kHostModelFeatures)) +
- kKeySectionDelimiter;
-}
-
void OptimizationGuideStore::UpdateStatus(Status new_status) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
@@ -780,24 +754,6 @@ void OptimizationGuideStore::OnLoadMetadata(
fetched_update_time_ = base::Time();
}
- auto host_model_features_entry = metadata_entries->find(
- GetMetadataTypeEntryKey(MetadataType::kHostModelFeatures));
- bool host_model_features_metadata_loaded = false;
- host_model_features_update_time_ = base::Time();
- if (host_model_features_entry != metadata_entries->end()) {
- DCHECK(host_model_features_entry->second.has_update_time_secs());
- host_model_features_update_time_ = base::Time::FromDeltaSinceWindowsEpoch(
- base::Seconds(host_model_features_entry->second.update_time_secs()));
- host_model_features_metadata_loaded = true;
- }
- // TODO(crbug/1001194): Metrics should be separated so that stores maintaining
- // different information types only record metrics for the types of entries
- // they store.
- UMA_HISTOGRAM_BOOLEAN(
- "OptimizationGuide.PredictionModelStore."
- "HostModelFeaturesLoadMetadataResult",
- host_model_features_metadata_loaded);
-
UpdateStatus(Status::kAvailable);
MaybeLoadEntryKeys(std::move(callback));
}
@@ -963,14 +919,30 @@ void OptimizationGuideStore::OnLoadModelsToBeUpdated(
bool had_entries_to_update_or_remove =
!update_vector->empty() || !remove_vector->empty();
for (const auto& entry : *entries) {
- bool should_delete_download_file = had_entries_to_update_or_remove;
+ absl::optional<std::string> delete_download_file;
+ if (had_entries_to_update_or_remove &&
+ entry.second.has_prediction_model() &&
+ !entry.second.prediction_model().model().download_url().empty()) {
+ delete_download_file =
+ entry.second.prediction_model().model().download_url();
+ }
+
// Only check expiry if we weren't explicitly passed in entries to update or
// remove.
if (!had_entries_to_update_or_remove) {
+ if (entry.second.keep_beyond_valid_duration()) {
+ continue;
+ }
if (entry.second.has_expiry_time_secs()) {
if (entry.second.expiry_time_secs() <= now_since_epoch) {
+ // Update the entry to remove the model.
+ if (entry.second.has_prediction_model() &&
+ !entry.second.prediction_model().model().download_url().empty()) {
+ delete_download_file =
+ entry.second.prediction_model().model().download_url();
+ }
+
remove_vector->push_back(entry.first);
- should_delete_download_file = true;
proto::OptimizationTarget optimization_target =
GetOptimizationTargetFromPredictionModelEntryKey(entry.first);
base::UmaHistogramBoolean(
@@ -984,20 +956,17 @@ void OptimizationGuideStore::OnLoadModelsToBeUpdated(
update_vector->push_back(entry);
update_vector->back().second.set_expiry_time_secs(
now_since_epoch +
- features::StoredModelsInactiveDuration().InSeconds());
+ features::StoredModelsValidDuration().InSeconds());
}
}
// Delete files (the model itself and any additional files) that are
// provided by the model in its directory.
- if (should_delete_download_file && entry.second.has_prediction_model() &&
- !entry.second.prediction_model().model().download_url().empty()) {
- // |GetFilePathFromPredictionModel| only returns nullopt when
- // |model().download_url()| is empty.
+ if (delete_download_file) {
+ // |StringToFilePath| only returns nullopt when
+ // |delete_download_file| is empty.
base::FilePath model_file_path =
- StringToFilePath(
- entry.second.prediction_model().model().download_url())
- .value();
+ StringToFilePath(*delete_download_file).value();
base::FilePath path_to_delete;
// Backwards compatibility: Once upon a time (<M93), model files were
@@ -1043,9 +1012,8 @@ bool OptimizationGuideStore::FindPredictionModelEntryKey(
*out_prediction_model_entry_key =
GetPredictionModelEntryKeyPrefix() +
base::NumberToString(static_cast<int>(optimization_target));
- if (entry_keys_->find(*out_prediction_model_entry_key) != entry_keys_->end())
- return true;
- return false;
+ return entry_keys_->find(*out_prediction_model_entry_key) !=
+ entry_keys_->end();
}
bool OptimizationGuideStore::RemovePredictionModelFromEntryKey(
@@ -1107,6 +1075,17 @@ void OptimizationGuideStore::OnLoadPredictionModel(
std::move(callback).Run(std::move(loaded_prediction_model));
return;
}
+ // Also ensure that nothing is returned if the model is expired.
+ int64_t now_since_epoch =
+ base::Time::Now().ToDeltaSinceWindowsEpoch().InSeconds();
+ if (!entry->keep_beyond_valid_duration() &&
+ entry->expiry_time_secs() <= now_since_epoch) {
+ // Leave the actual model deletions to |PurgeInactiveModels| and return
+ // early.
+ std::unique_ptr<proto::PredictionModel> loaded_prediction_model(nullptr);
+ std::move(callback).Run(std::move(loaded_prediction_model));
+ return;
+ }
std::unique_ptr<proto::PredictionModel> loaded_prediction_model(
entry->release_prediction_model());
@@ -1166,167 +1145,4 @@ void OptimizationGuideStore::OnModelFilePathVerified(
std::move(callback).Run(nullptr);
}
-std::unique_ptr<StoreUpdateData>
-OptimizationGuideStore::CreateUpdateDataForHostModelFeatures(
- base::Time host_model_features_update_time,
- base::Time expiry_time) const {
- // Create and returns a StoreUpdateData object. This object has host model
- // features from the GetModelsResponse moved into and organizes them in a
- // format usable by the store. The object will be stored with
- // UpdateHostModelFeatures().
- return StoreUpdateData::CreateHostModelFeaturesStoreUpdateData(
- host_model_features_update_time, expiry_time);
-}
-
-void OptimizationGuideStore::UpdateHostModelFeatures(
- std::unique_ptr<StoreUpdateData> host_model_features_update_data,
- base::OnceClosure callback) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- DCHECK(host_model_features_update_data->update_time());
-
- if (!IsAvailable()) {
- std::move(callback).Run();
- return;
- }
-
- host_model_features_update_time_ =
- *host_model_features_update_data->update_time();
-
- entry_keys_.reset();
-
- // This will remove the host model features metadata entry and insert all the
- // entries currently in |host_model_features_update_data|.
- database_->UpdateEntriesWithRemoveFilter(
- host_model_features_update_data->TakeUpdateEntries(),
- base::BindRepeating(
- &DatabasePrefixFilter,
- GetMetadataTypeEntryKey(MetadataType::kHostModelFeatures)),
- base::BindOnce(&OptimizationGuideStore::OnUpdateStore,
- weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
-}
-
-bool OptimizationGuideStore::FindHostModelFeaturesEntryKey(
- const std::string& host,
- OptimizationGuideStore::EntryKey* out_host_model_features_entry_key) const {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
-
- if (!entry_keys_)
- return false;
-
- return FindEntryKeyForHostWithPrefix(host, out_host_model_features_entry_key,
- GetHostModelFeaturesEntryKeyPrefix());
-}
-
-void OptimizationGuideStore::LoadAllHostModelFeatures(
- AllHostModelFeaturesLoadedCallback callback) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
-
- if (!IsAvailable()) {
- std::move(callback).Run({});
- return;
- }
-
- // Load all the host model features within the store.
- database_->LoadEntriesWithFilter(
- base::BindRepeating(&DatabasePrefixFilter,
- GetHostModelFeaturesEntryKeyPrefix()),
- base::BindOnce(&OptimizationGuideStore::OnLoadAllHostModelFeatures,
- weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
-}
-
-void OptimizationGuideStore::LoadHostModelFeatures(
- const EntryKey& host_model_features_entry_key,
- HostModelFeaturesLoadedCallback callback) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
-
- if (!IsAvailable()) {
- std::move(callback).Run({});
- return;
- }
-
- // Load all the host model features within the store.
- database_->GetEntry(
- host_model_features_entry_key,
- base::BindOnce(&OptimizationGuideStore::OnLoadHostModelFeatures,
- weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
-}
-
-void OptimizationGuideStore::OnLoadHostModelFeatures(
- HostModelFeaturesLoadedCallback callback,
- bool success,
- std::unique_ptr<proto::StoreEntry> entry) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
-
- // If either the request failed or the store was set to unavailable after the
- // request was started, then the loaded host model features should not
- // be considered valid. Reset the entry so that nothing is returned to the
- // requester.
- if (!success || !IsAvailable()) {
- entry.reset();
- }
- if (!entry || !entry->has_host_model_features()) {
- std::move(callback).Run(nullptr);
- return;
- }
-
- std::unique_ptr<proto::HostModelFeatures> loaded_host_model_features(
- entry->release_host_model_features());
- std::move(callback).Run(std::move(loaded_host_model_features));
-}
-
-void OptimizationGuideStore::OnLoadAllHostModelFeatures(
- AllHostModelFeaturesLoadedCallback callback,
- bool success,
- std::unique_ptr<std::vector<proto::StoreEntry>> entries) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
-
- // If either the request failed or the store was set to unavailable after the
- // request was started, then the loaded host model features should not
- // be considered valid. Reset the entry so that nothing is returned to the
- // requester.
- if (!success || !IsAvailable()) {
- entries.reset();
- }
-
- if (!entries || entries->size() == 0) {
- std::unique_ptr<std::vector<proto::HostModelFeatures>>
- loaded_host_model_features(nullptr);
- std::move(callback).Run(std::move(loaded_host_model_features));
- return;
- }
-
- std::unique_ptr<std::vector<proto::HostModelFeatures>>
- loaded_host_model_features =
- std::make_unique<std::vector<proto::HostModelFeatures>>();
- for (auto& entry : *entries.get()) {
- if (!entry.has_host_model_features())
- continue;
- loaded_host_model_features->emplace_back(entry.host_model_features());
- }
- std::move(callback).Run(std::move(loaded_host_model_features));
-}
-
-void OptimizationGuideStore::ClearHostModelFeaturesFromDatabase() {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
-
- base::UmaHistogramBoolean(
- "OptimizationGuide.ClearHostModelFeatures.StoreAvailable", IsAvailable());
- if (!IsAvailable())
- return;
-
- auto entries_to_save = std::make_unique<EntryVector>();
-
- entry_keys_.reset();
-
- // Removes all |kHostModelFeatures| store entries. OnUpdateStore will handle
- // updating status and re-filling entry_keys with the entries still in the
- // store.
- database_->UpdateEntriesWithRemoveFilter(
- std::move(entries_to_save), // this should be empty.
- base::BindRepeating(&DatabasePrefixFilter,
- GetHostModelFeaturesEntryKeyPrefix()),
- base::BindOnce(&OptimizationGuideStore::OnUpdateStore,
- weak_ptr_factory_.GetWeakPtr(), base::DoNothing()));
-}
-
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/optimization_guide_store.h b/chromium/components/optimization_guide/core/optimization_guide_store.h
index 2b7a51f85af..38b70f16814 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_store.h
+++ b/chromium/components/optimization_guide/core/optimization_guide_store.h
@@ -40,10 +40,6 @@ class OptimizationGuideStore {
base::OnceCallback<void(const std::string&, std::unique_ptr<MemoryHint>)>;
using PredictionModelLoadedCallback =
base::OnceCallback<void(std::unique_ptr<proto::PredictionModel>)>;
- using HostModelFeaturesLoadedCallback =
- base::OnceCallback<void(std::unique_ptr<proto::HostModelFeatures>)>;
- using AllHostModelFeaturesLoadedCallback = base::OnceCallback<void(
- std::unique_ptr<std::vector<proto::HostModelFeatures>>)>;
using EntryKey = std::string;
using StoreEntryProtoDatabase =
leveldb_proto::ProtoDatabase<proto::StoreEntry>;
@@ -80,16 +76,14 @@ class OptimizationGuideStore {
// cannot be changed, but new types can be added to the end.
// StoreEntryType should remain synchronized with the
// HintCacheStoreEntryType in enums.xml.
- // Also ensure to add to the OptimizationGuide.StoreEntryTypes histogram
- // suffixes if adding a new one.
enum class StoreEntryType {
kEmpty = 0,
kMetadata = 1,
kComponentHint = 2,
kFetchedHint = 3,
kPredictionModel = 4,
- kHostModelFeatures = 5,
- kMaxValue = kHostModelFeatures,
+ kDeprecatedHostModelFeatures = 5, // deprecated.
+ kMaxValue = kDeprecatedHostModelFeatures,
};
OptimizationGuideStore(
@@ -172,13 +166,9 @@ class OptimizationGuideStore {
// removed.
void PurgeExpiredFetchedHints();
- // Removes all host model features that have expired from the store.
- // |entry_keys_| is updated after the expired host model features are
- // removed.
- void PurgeExpiredHostModelFeatures();
-
// Removes all models that have not been loaded in the max inactive duration
// configured. |entry_keys| is updated after the inactive models are removed.
+ // Respects models' |keep_beyond_valid_duration| setting.
void PurgeInactiveModels();
// Creates and returns a StoreUpdateData object for Prediction Models. This
@@ -196,9 +186,9 @@ class OptimizationGuideStore {
std::unique_ptr<StoreUpdateData> prediction_models_update_data,
base::OnceClosure callback);
- // Finds the entry key for the prediction model if it is known to the store.
- // Returns true if an entry key is found and |out_prediction_model_entry_key|
- // is populated with the matching key.
+ // Finds the entry key for the prediction model if it is still valid in the
+ // store. Returns true if an entry key is valid and
+ // |out_prediction_model_entry_key| is populated with any matching key.
// Virtualized for testing.
virtual bool FindPredictionModelEntryKey(
proto::OptimizationTarget optimization_target,
@@ -218,60 +208,11 @@ class OptimizationGuideStore {
// false otherwise.
bool RemovePredictionModelFromEntryKey(const EntryKey& entry_key);
- // Creates and returns a StoreUpdateData object for host model features. This
- // object is used to collect a batch of host model features in a format that
- // is usable to update the store on a background thread. This is always
- // created when host model features have been successfully fetched from the
- // remote Optimization Guide Service so the store can update old host model
- // features.
- std::unique_ptr<StoreUpdateData> CreateUpdateDataForHostModelFeatures(
- base::Time host_model_features_update_time,
- base::Time expiry_time) const;
-
- // Updates the host model features contained in the store. The callback is run
- // asynchronously after the database stores the host model features.
- // Virtualized for testing.
- virtual void UpdateHostModelFeatures(
- std::unique_ptr<StoreUpdateData> host_model_features_update_data,
- base::OnceClosure callback);
-
- // Finds the entry key for the host model features for |host| if it is known
- // to the store. Returns true if an entry key is found and
- // |out_host_model_features_entry_key| is populated with the matching key.
- bool FindHostModelFeaturesEntryKey(
- const std::string& host,
- OptimizationGuideStore::EntryKey* out_host_model_features_entry_key)
- const;
-
- // Loads the host model features specified by |host_model_features_entry_key|.
- // After the load finishes, the host model features data is passed to
- // |callback|. In the case where the host model features cannot be loaded, the
- // callback is run with a nullptr. Depending on the load result, the callback
- // may be synchronous or asynchronous.
- void LoadHostModelFeatures(const EntryKey& host_model_features_entry_key,
- HostModelFeaturesLoadedCallback callback);
-
- // Loads all the host model features known to the store. After the load
- // finishes, the host model features data is passed back to |callback|. In the
- // case where the host model features cannot be loaded, the callback is run
- // with a nullptr. Depending on the load result, the callback may be
- // synchronous or asynchronous.
- // Virtualized for testing.
- virtual void LoadAllHostModelFeatures(
- AllHostModelFeaturesLoadedCallback callback);
-
- // Returns the time that the host model features in the store can be updated.
- // If |this| is not available, base::Time() is returned.
- base::Time GetHostModelFeaturesUpdateTime() const;
-
// Removes fetched hints whose keys are in |hint_keys| and runs |on_success|
// if successful, otherwise the callback is not run.
void RemoveFetchedHintsByKey(base::OnceClosure on_success,
const base::flat_set<std::string>& hint_keys);
- // Clears all host model features from the database and resets the entry keys.
- void ClearHostModelFeaturesFromDatabase();
-
// Returns true if the current status is Status::kAvailable.
bool IsAvailable() const;
@@ -304,7 +245,7 @@ class OptimizationGuideStore {
kSchema = 1,
kComponent = 2,
kFetched = 3,
- kHostModelFeatures = 4,
+ kDeprecatedHostModelFeatures = 4, // deprecated.
};
// Current schema version of the hint cache store. When this is changed,
@@ -331,9 +272,6 @@ class OptimizationGuideStore {
// Returns prefix of the key of every prediction model entry: "4_".
static EntryKeyPrefix GetPredictionModelEntryKeyPrefix();
- // Returns prefix of the key of every host model features entry: "5_".
- static EntryKeyPrefix GetHostModelFeaturesEntryKeyPrefix();
-
// Returns the OptimizationTarget from |prediction_model_entry_key|.
static proto::OptimizationTarget
GetOptimizationTargetFromPredictionModelEntryKey(
@@ -466,29 +404,6 @@ class OptimizationGuideStore {
PredictionModelLoadedCallback callback,
bool success);
- // Callback that runs after a host model features entry is loaded from the
- // database. If there's currently an in-flight update, then the data could be
- // invalidated, so loaded host model features data is discarded. Otherwise,
- // the host model features are released into the callback, allowing the caller
- // to own the host model features without copying it. Regardless of the
- // success or failure of retrieving the key, the callback always runs (it
- // simply runs with a nullptr on failure).
- void OnLoadHostModelFeatures(HostModelFeaturesLoadedCallback callback,
- bool success,
- std::unique_ptr<proto::StoreEntry> entry);
-
- // Callback that runs after all the host model features entries are loaded
- // from the database. If there's currently an in-flight update, then the data
- // could be invalidated, so loaded host model features data is discarded.
- // Otherwise, the host model features are released into the callback, allowing
- // the caller to own the host model features without copying it. Regardless of
- // the success or failure of retrieving the key, the callback always runs (it
- // simply runs with a nullptr on failure).
- void OnLoadAllHostModelFeatures(
- AllHostModelFeaturesLoadedCallback callback,
- bool success,
- std::unique_ptr<std::vector<proto::StoreEntry>> entry);
-
// Proto database used by the store.
std::unique_ptr<StoreEntryProtoDatabase> database_;
diff --git a/chromium/components/optimization_guide/core/optimization_guide_store_unittest.cc b/chromium/components/optimization_guide/core/optimization_guide_store_unittest.cc
index bf60d8a865d..6baa370ec39 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_store_unittest.cc
+++ b/chromium/components/optimization_guide/core/optimization_guide_store_unittest.cc
@@ -15,6 +15,7 @@
#include "base/test/metrics/histogram_tester.h"
#include "base/test/task_environment.h"
#include "components/leveldb_proto/testing/fake_db.h"
+#include "components/optimization_guide/core/model_util.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/core/store_update_data.h"
@@ -60,8 +61,8 @@ std::unique_ptr<proto::PredictionModel> CreatePredictionModel() {
model_info->set_version(1);
model_info->set_optimization_target(
proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
+ model_info->add_supported_model_engine_versions(
+ proto::ModelEngineVersion::MODEL_ENGINE_VERSION_DECISION_TREE);
return prediction_model;
}
@@ -87,8 +88,6 @@ class OptimizationGuideStoreTest : public testing::Test {
MetadataSchemaState state,
absl::optional<size_t> component_hint_count = absl::optional<size_t>(),
absl::optional<base::Time> fetched_hints_update =
- absl::optional<base::Time>(),
- absl::optional<base::Time> host_model_features_update =
absl::optional<base::Time>()) {
db_store_.clear();
@@ -135,13 +134,6 @@ class OptimizationGuideStoreTest : public testing::Test {
.set_update_time_secs(
fetched_hints_update->ToDeltaSinceWindowsEpoch().InSeconds());
}
- if (host_model_features_update) {
- db_store_[OptimizationGuideStore::GetMetadataTypeEntryKey(
- OptimizationGuideStore::MetadataType::kHostModelFeatures)]
- .set_update_time_secs(
- host_model_features_update->ToDeltaSinceWindowsEpoch()
- .InSeconds());
- }
}
// Moves the specified number of component hints into the update data.
@@ -176,38 +168,27 @@ class OptimizationGuideStoreTest : public testing::Test {
StoreUpdateData* update_data,
optimization_guide::proto::OptimizationTarget optimization_target,
absl::optional<base::FilePath> model_file_path = absl::nullopt,
- base::flat_set<base::FilePath> additional_file_paths = {}) {
+ absl::optional<proto::ModelInfo> info = absl::nullopt,
+ absl::optional<base::Time> expiry_time = absl::nullopt) {
std::unique_ptr<optimization_guide::proto::PredictionModel>
prediction_model = CreatePredictionModel();
+ if (info)
+ prediction_model->mutable_model_info()->MergeFrom(*info);
prediction_model->mutable_model_info()->set_optimization_target(
optimization_target);
+ if (expiry_time) {
+ auto diff = expiry_time.value() - base::Time::Now();
+ prediction_model->mutable_model_info()
+ ->mutable_valid_duration()
+ ->set_seconds(diff.InSeconds());
+ }
if (model_file_path) {
prediction_model->mutable_model()->set_download_url(
FilePathToString(*model_file_path));
}
- for (const base::FilePath& additional_file : additional_file_paths) {
- prediction_model->mutable_model_info()
- ->add_additional_files()
- ->set_file_path(FilePathToString(additional_file));
- }
update_data->CopyPredictionModelIntoUpdateData(*prediction_model);
}
- // Moves |host_model_features_count| into |update_data|.
- void SeedHostModelFeaturesUpdateData(StoreUpdateData* update_data,
- size_t host_model_features_count) {
- for (size_t i = 0; i < host_model_features_count; i++) {
- std::string host = GetHost(i);
- proto::HostModelFeatures host_model_features;
- proto::ModelFeature* model_feature =
- host_model_features.add_model_features();
- model_feature->set_feature_name("host_feat1");
- model_feature->set_double_value(2.0);
- host_model_features.set_host(host);
- update_data->CopyHostModelFeaturesIntoUpdateData(host_model_features);
- }
- }
-
void CreateDatabase() {
// Reset everything.
db_ = nullptr;
@@ -304,35 +285,12 @@ class OptimizationGuideStoreTest : public testing::Test {
}
}
- void UpdateHostModelFeatures(
- std::unique_ptr<StoreUpdateData> host_model_features_data,
- bool update_success = true,
- bool load_host_model_features_entry_keys_success = true) {
- EXPECT_CALL(*this, OnUpdateStore());
- guide_store()->UpdateHostModelFeatures(
- std::move(host_model_features_data),
- base::BindOnce(&OptimizationGuideStoreTest::OnUpdateStore,
- base::Unretained(this)));
- // OnUpdateStore callback
- db()->UpdateCallback(update_success);
- if (update_success) {
- // OnLoadEntryKeys callback
- db()->LoadCallback(load_host_model_features_entry_keys_success);
- }
- }
-
void ClearFetchedHintsFromDatabase() {
guide_store()->ClearFetchedHintsFromDatabase();
db()->UpdateCallback(true);
db()->LoadCallback(true);
}
- void ClearHostModelFeaturesFromDatabase() {
- guide_store()->ClearHostModelFeaturesFromDatabase();
- db()->UpdateCallback(true);
- db()->LoadCallback(true);
- }
-
void PurgeExpiredFetchedHints() {
guide_store()->PurgeExpiredFetchedHints();
@@ -344,17 +302,6 @@ class OptimizationGuideStoreTest : public testing::Test {
db()->LoadCallback(true);
}
- void PurgeExpiredHostModelFeatures() {
- guide_store()->PurgeExpiredHostModelFeatures();
-
- // OnLoadExpiredEntriesToPurge
- db()->LoadCallback(true);
- // OnUpdateStore
- db()->UpdateCallback(true);
- // OnLoadEntryKeys callback
- db()->LoadCallback(true);
- }
-
void PurgeInactiveModels() {
guide_store()->PurgeInactiveModels();
@@ -388,23 +335,6 @@ class OptimizationGuideStoreTest : public testing::Test {
}
}
- // Verifies that the host model features metadata has the expected next update
- // time.
- void ExpectHostModelFeaturesMetadata(base::Time update_time) const {
- const auto& metadata_entry =
- db_store_.find(OptimizationGuideStore::GetMetadataTypeEntryKey(
- OptimizationGuideStore::MetadataType::kHostModelFeatures));
- if (metadata_entry != db_store_.end()) {
- // The next update time should have same time up to the second as the
- // metadata entry is stored in seconds.
- EXPECT_TRUE(base::Time::FromDeltaSinceWindowsEpoch(base::Seconds(
- metadata_entry->second.update_time_secs())) -
- update_time <
- base::Seconds(1));
- } else {
- FAIL() << "No host model features metadata found";
- }
- }
// Verifies that the component metadata has the expected version and all
// expected component hints are present.
void ExpectComponentHintsPresent(const std::string& version,
@@ -458,14 +388,6 @@ class OptimizationGuideStoreTest : public testing::Test {
MemoryHint* last_loaded_hint() { return last_loaded_hint_.get(); }
- proto::HostModelFeatures* last_loaded_host_model_features() {
- return last_loaded_host_model_features_.get();
- }
-
- std::vector<proto::HostModelFeatures>* last_loaded_all_host_model_features() {
- return last_loaded_all_host_model_features_.get();
- }
-
proto::PredictionModel* last_loaded_prediction_model() {
return last_loaded_prediction_model_.get();
}
@@ -476,18 +398,6 @@ class OptimizationGuideStoreTest : public testing::Test {
last_loaded_hint_ = std::move(loaded_hint);
}
- void OnHostModelFeaturesLoaded(
- std::unique_ptr<proto::HostModelFeatures> loaded_host_model_features) {
- last_loaded_host_model_features_ = std::move(loaded_host_model_features);
- }
-
- void OnAllHostModelFeaturesLoaded(
- std::unique_ptr<std::vector<proto::HostModelFeatures>>
- loaded_all_host_model_features) {
- last_loaded_all_host_model_features_ =
- std::move(loaded_all_host_model_features);
- }
-
void OnPredictionModelLoaded(
std::unique_ptr<proto::PredictionModel> loaded_prediction_model) {
last_loaded_prediction_model_ = std::move(loaded_prediction_model);
@@ -507,9 +417,6 @@ class OptimizationGuideStoreTest : public testing::Test {
OptimizationGuideStore::EntryKey last_loaded_entry_key_;
std::unique_ptr<MemoryHint> last_loaded_hint_;
- std::unique_ptr<proto::HostModelFeatures> last_loaded_host_model_features_;
- std::unique_ptr<std::vector<proto::HostModelFeatures>>
- last_loaded_all_host_model_features_;
std::unique_ptr<proto::PredictionModel> last_loaded_prediction_model_;
};
@@ -737,10 +644,6 @@ TEST_F(OptimizationGuideStoreTest,
histogram_tester.ExpectBucketCount(
"OptimizationGuide.HintCacheLevelDBStore.LoadMetadataResult",
0 /* kSuccess */, 1);
- histogram_tester.ExpectBucketCount(
- "OptimizationGuide.PredictionModelStore."
- "HostModelFeaturesLoadMetadataResult",
- false, 1);
histogram_tester.ExpectBucketCount(
"OptimizationGuide.HintCacheLevelDBStore.Status", 0 /* kUninitialized */,
@@ -839,11 +742,6 @@ TEST_F(OptimizationGuideStoreTest, InitializeSucceededWithValidSchemaEntry) {
6 /* kComponentAndFetchedMetadataMissing*/, 1);
histogram_tester.ExpectBucketCount(
- "OptimizationGuide.PredictionModelStore."
- "HostModelFeaturesLoadMetadataResult",
- false, 1);
-
- histogram_tester.ExpectBucketCount(
"OptimizationGuide.HintCacheLevelDBStore.Status", 0 /* kUninitialized */,
1);
histogram_tester.ExpectBucketCount(
@@ -901,9 +799,6 @@ TEST_F(OptimizationGuideStoreTest, InitializeSucceededWithPurgeExistingData) {
EXPECT_TRUE(IsMetadataSchemaEntryKeyPresent());
- histogram_tester.ExpectTotalCount(
- "OptimizationGuide.HintCacheLevelDBStore.LoadMetadataResult", 0);
-
histogram_tester.ExpectBucketCount(
"OptimizationGuide.HintCacheLevelDBStore.Status", 0 /* kUninitialized */,
1);
@@ -923,15 +818,14 @@ TEST_F(OptimizationGuideStoreTest,
MetadataSchemaState schema_state = MetadataSchemaState::kValid;
size_t component_hint_count = 10;
SeedInitialData(schema_state, component_hint_count,
- base::Time().Now(), /* fetch_update_time */
- base::Time().Now() /* host_model_features_update_time */);
+ base::Time().Now() /* fetch_update_time */);
CreateDatabase();
InitializeStore(schema_state);
// The store should contain the schema metadata entry, the component metadata
// entry, and all of the initial component hints.
EXPECT_EQ(GetDBStoreEntryCount(),
- static_cast<size_t>(component_hint_count + 4));
+ static_cast<size_t>(component_hint_count + 3));
EXPECT_EQ(GetStoreEntryKeyCount(), component_hint_count);
EXPECT_TRUE(IsMetadataSchemaEntryKeyPresent());
@@ -942,11 +836,6 @@ TEST_F(OptimizationGuideStoreTest,
0 /* kSuccess */, 1);
histogram_tester.ExpectBucketCount(
- "OptimizationGuide.PredictionModelStore."
- "HostModelFeaturesLoadMetadataResult",
- true, 1);
-
- histogram_tester.ExpectBucketCount(
"OptimizationGuide.HintCacheLevelDBStore.Status", 0 /* kUninitialized */,
1);
histogram_tester.ExpectBucketCount(
@@ -988,11 +877,6 @@ TEST_F(OptimizationGuideStoreTest,
6 /* kComponentAndFetchedMetadataMissing*/, 0);
histogram_tester.ExpectBucketCount(
- "OptimizationGuide.PredictionModelStore."
- "HostModelFeaturesLoadMetadataResult",
- false, 1);
-
- histogram_tester.ExpectBucketCount(
"OptimizationGuide.HintCacheLevelDBStore.Status", 0 /* kUninitialized */,
1);
histogram_tester.ExpectBucketCount(
@@ -1027,11 +911,6 @@ TEST_F(OptimizationGuideStoreTest,
4 /* kComponentMetadataMissing*/, 1);
histogram_tester.ExpectBucketCount(
- "OptimizationGuide.PredictionModelStore."
- "HostModelFeaturesLoadMetadataResult",
- false, 1);
-
- histogram_tester.ExpectBucketCount(
"OptimizationGuide.HintCacheLevelDBStore.Status", 0 /* kUninitialized */,
1);
histogram_tester.ExpectBucketCount(
@@ -1945,7 +1824,7 @@ TEST_F(OptimizationGuideStoreTest, FindPredictionModelEntryKey) {
std::unique_ptr<StoreUpdateData> update_data =
guide_store()->CreateUpdateDataForPredictionModels(
update_time +
- optimization_guide::features::StoredModelsInactiveDuration());
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data);
SeedPredictionModelUpdateData(update_data.get(),
proto::OPTIMIZATION_TARGET_UNKNOWN);
@@ -1971,7 +1850,7 @@ TEST_F(OptimizationGuideStoreTest,
std::unique_ptr<StoreUpdateData> update_data =
guide_store()->CreateUpdateDataForPredictionModels(
update_time +
- optimization_guide::features::StoredModelsInactiveDuration());
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data);
SeedPredictionModelUpdateData(update_data.get(),
proto::OPTIMIZATION_TARGET_UNKNOWN);
@@ -1994,7 +1873,7 @@ TEST_F(OptimizationGuideStoreTest, FindAndRemovePredictionModelEntryKey) {
std::unique_ptr<StoreUpdateData> update_data =
guide_store()->CreateUpdateDataForPredictionModels(
update_time +
- optimization_guide::features::StoredModelsInactiveDuration());
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data);
SeedPredictionModelUpdateData(update_data.get(),
proto::OPTIMIZATION_TARGET_UNKNOWN);
@@ -2034,7 +1913,7 @@ TEST_F(OptimizationGuideStoreTest, LoadPredictionModel) {
std::unique_ptr<StoreUpdateData> update_data =
guide_store()->CreateUpdateDataForPredictionModels(
update_time +
- optimization_guide::features::StoredModelsInactiveDuration());
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data);
SeedPredictionModelUpdateData(update_data.get(),
proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
@@ -2076,6 +1955,46 @@ TEST_F(OptimizationGuideStoreTest, LoadPredictionModelOnUnavailableStore) {
EXPECT_FALSE(last_loaded_prediction_model());
}
+TEST_F(OptimizationGuideStoreTest, LoadPredictionModelOnExpiredModel) {
+ base::HistogramTester histogram_tester;
+ size_t initial_hint_count = 10;
+ MetadataSchemaState schema_state = MetadataSchemaState::kValid;
+ SeedInitialData(schema_state, initial_hint_count);
+ CreateDatabase();
+ InitializeStore(schema_state);
+
+ // Add an update with models that are "inactive".
+ base::Time update_time = base::Time().Now();
+ std::unique_ptr<StoreUpdateData> update_data =
+ guide_store()->CreateUpdateDataForPredictionModels(
+ update_time -
+ optimization_guide::features::StoredModelsValidDuration());
+ ASSERT_TRUE(update_data);
+ SeedPredictionModelUpdateData(
+ update_data.get(), proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD,
+ /*model_file_path=*/absl::nullopt,
+ /*info=*/{},
+ update_time - optimization_guide::features::StoredModelsValidDuration());
+ UpdatePredictionModels(std::move(update_data));
+ // Since models haven't been purged yet it will initially show up as valid.
+ OptimizationGuideStore::EntryKey entry_key;
+ EXPECT_TRUE(guide_store()->FindPredictionModelEntryKey(
+ proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &entry_key));
+ guide_store()->LoadPredictionModel(
+ entry_key,
+ base::BindOnce(&OptimizationGuideStoreTest::OnPredictionModelLoaded,
+ base::Unretained(this)));
+ // OnPredictionModelLoaded callback
+ db()->GetCallback(true);
+ // After load completes, the key will still exist until after purge.
+ EXPECT_TRUE(guide_store()->FindPredictionModelEntryKey(
+ proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &entry_key));
+
+ // Verify that the OnPredictionModelLoaded callback runs when the store is
+ // unavailable and that the prediction model was correctly not available.
+ EXPECT_FALSE(last_loaded_prediction_model());
+}
+
TEST_F(OptimizationGuideStoreTest, LoadPredictionModelWithUpdateInFlight) {
base::HistogramTester histogram_tester;
MetadataSchemaState schema_state = MetadataSchemaState::kValid;
@@ -2087,7 +2006,7 @@ TEST_F(OptimizationGuideStoreTest, LoadPredictionModelWithUpdateInFlight) {
std::unique_ptr<StoreUpdateData> update_data =
guide_store()->CreateUpdateDataForPredictionModels(
update_time +
- optimization_guide::features::StoredModelsInactiveDuration());
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data);
SeedPredictionModelUpdateData(update_data.get(),
proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
@@ -2117,7 +2036,7 @@ TEST_F(OptimizationGuideStoreTest,
std::unique_ptr<StoreUpdateData> update_data =
guide_store()->CreateUpdateDataForPredictionModels(
update_time +
- optimization_guide::features::StoredModelsInactiveDuration());
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data);
SeedPredictionModelUpdateData(update_data.get(),
proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD,
@@ -2164,7 +2083,7 @@ TEST_F(OptimizationGuideStoreTest,
std::unique_ptr<StoreUpdateData> update_data =
guide_store()->CreateUpdateDataForPredictionModels(
update_time +
- optimization_guide::features::StoredModelsInactiveDuration());
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data);
base::FilePath model_file_path = temp_dir().AppendASCII("model.tflite");
@@ -2173,9 +2092,12 @@ TEST_F(OptimizationGuideStoreTest,
base::FilePath additional_file_path = temp_dir().AppendASCII("doesntexist");
- SeedPredictionModelUpdateData(
- update_data.get(), proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD,
- temp_dir().AppendASCII("doesntexist"), {additional_file_path});
+ proto::ModelInfo info;
+ info.add_additional_files()->set_file_path(
+ FilePathToString(additional_file_path));
+ SeedPredictionModelUpdateData(update_data.get(),
+ proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD,
+ temp_dir().AppendASCII("doesntexist"), info);
UpdatePredictionModels(std::move(update_data));
OptimizationGuideStore::EntryKey entry_key;
@@ -2221,7 +2143,7 @@ TEST_F(OptimizationGuideStoreTest,
std::unique_ptr<StoreUpdateData> update_data =
guide_store()->CreateUpdateDataForPredictionModels(
update_time +
- optimization_guide::features::StoredModelsInactiveDuration());
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data);
base::FilePath model_file_path = temp_dir().AppendASCII("model.tflite");
@@ -2232,10 +2154,12 @@ TEST_F(OptimizationGuideStoreTest,
temp_dir().AppendASCII("additional_file.txt");
ASSERT_EQ(static_cast<int32_t>(3),
base::WriteFile(additional_file_path, "ah!", 3));
-
+ proto::ModelInfo info;
+ info.add_additional_files()->set_file_path(
+ FilePathToString(additional_file_path));
SeedPredictionModelUpdateData(update_data.get(),
proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD,
- model_file_path, {additional_file_path});
+ model_file_path, info);
UpdatePredictionModels(std::move(update_data));
OptimizationGuideStore::EntryKey entry_key;
@@ -2277,7 +2201,7 @@ TEST_F(OptimizationGuideStoreTest,
std::unique_ptr<StoreUpdateData> update_data =
guide_store()->CreateUpdateDataForPredictionModels(
update_time +
- optimization_guide::features::StoredModelsInactiveDuration());
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data);
base::FilePath file_path = temp_dir().AppendASCII("file");
ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(file_path, "boo", 3));
@@ -2324,7 +2248,7 @@ TEST_F(OptimizationGuideStoreTest, UpdatePredictionModelsDeletesOldFile) {
std::unique_ptr<StoreUpdateData> update_data =
guide_store()->CreateUpdateDataForPredictionModels(
update_time +
- optimization_guide::features::StoredModelsInactiveDuration());
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data);
base::FilePath old_file_path = old_dir.AppendASCII("model.tflite");
ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(old_file_path, "boo", 3));
@@ -2341,7 +2265,7 @@ TEST_F(OptimizationGuideStoreTest, UpdatePredictionModelsDeletesOldFile) {
std::unique_ptr<StoreUpdateData> update_data2 =
guide_store()->CreateUpdateDataForPredictionModels(
update_time +
- optimization_guide::features::StoredModelsInactiveDuration());
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data2);
base::FilePath new_file_path = new_dir.AppendASCII("model.tflite");
ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(new_file_path, "boo", 3));
@@ -2380,7 +2304,7 @@ TEST_F(OptimizationGuideStoreTest,
std::unique_ptr<StoreUpdateData> update_data =
guide_store()->CreateUpdateDataForPredictionModels(
update_time +
- optimization_guide::features::StoredModelsInactiveDuration());
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data);
base::FilePath old_file_path = old_dir.AppendASCII("model_v1.tflite");
ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(old_file_path, "boo", 3));
@@ -2397,7 +2321,7 @@ TEST_F(OptimizationGuideStoreTest,
std::unique_ptr<StoreUpdateData> update_data2 =
guide_store()->CreateUpdateDataForPredictionModels(
update_time +
- optimization_guide::features::StoredModelsInactiveDuration());
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data2);
base::FilePath new_file_path = new_dir.Append(GetBaseFileNameForModels());
ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(new_file_path, "boo", 3));
@@ -2429,7 +2353,7 @@ TEST_F(OptimizationGuideStoreTest, RemovePredictionModelEntryKeyDeletesFile) {
std::unique_ptr<StoreUpdateData> update_data =
guide_store()->CreateUpdateDataForPredictionModels(
update_time +
- optimization_guide::features::StoredModelsInactiveDuration());
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data);
base::FilePath file_path = temp_dir().AppendASCII("file");
ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(file_path, "boo", 3));
@@ -2460,203 +2384,79 @@ TEST_F(OptimizationGuideStoreTest, RemovePredictionModelEntryKeyDeletesFile) {
EXPECT_FALSE(base::PathExists(file_path));
}
-TEST_F(OptimizationGuideStoreTest, HostModelFeaturesMetadataStored) {
- MetadataSchemaState schema_state = MetadataSchemaState::kValid;
- base::Time update_time = base::Time().Now();
- SeedInitialData(schema_state, 10, update_time,
- base::Time().Now() /* host_model_features_update_time */);
- CreateDatabase();
- InitializeStore(schema_state);
-
- ExpectHostModelFeaturesMetadata(update_time);
-}
-
-TEST_F(OptimizationGuideStoreTest, FindEntryKeyForHostModelFeatures) {
- MetadataSchemaState schema_state = MetadataSchemaState::kValid;
- size_t update_host_model_features_count = 5;
- base::Time update_time = base::Time().Now();
- SeedInitialData(schema_state, 0,
- base::Time().Now() /* host_model_features_update_time */);
- CreateDatabase();
- InitializeStore(schema_state);
-
- std::unique_ptr<StoreUpdateData> update_data =
- guide_store()->CreateUpdateDataForHostModelFeatures(
- update_time, update_time +
- optimization_guide::features::
- StoredHostModelFeaturesFreshnessDuration());
- ASSERT_TRUE(update_data);
- SeedHostModelFeaturesUpdateData(update_data.get(),
- update_host_model_features_count);
- UpdateHostModelFeatures(std::move(update_data));
-
- for (size_t i = 0; i < update_host_model_features_count; ++i) {
- std::string host = GetHost(i);
- OptimizationGuideStore::EntryKey entry_key;
- bool success =
- guide_store()->FindHostModelFeaturesEntryKey(host, &entry_key);
- EXPECT_EQ(success, i < update_host_model_features_count);
- }
-}
-
-TEST_F(OptimizationGuideStoreTest, LoadHostModelFeaturesForHost) {
+TEST_F(OptimizationGuideStoreTest, PurgeInactiveModels) {
base::HistogramTester histogram_tester;
- size_t update_host_model_features_count = 5;
- MetadataSchemaState schema_state = MetadataSchemaState::kValid;
- base::Time update_time = base::Time().Now();
- SeedInitialData(schema_state, 0, base::Time().Now());
- CreateDatabase();
- InitializeStore(schema_state);
-
- std::unique_ptr<StoreUpdateData> update_data =
- guide_store()->CreateUpdateDataForHostModelFeatures(
- update_time, update_time +
- optimization_guide::features::
- StoredHostModelFeaturesFreshnessDuration());
- ASSERT_TRUE(update_data);
- SeedHostModelFeaturesUpdateData(update_data.get(),
- update_host_model_features_count);
- UpdateHostModelFeatures(std::move(update_data));
-
- for (size_t i = 0; i < update_host_model_features_count; ++i) {
- std::string host = GetHost(i);
- OptimizationGuideStore::EntryKey entry_key;
- bool success =
- guide_store()->FindHostModelFeaturesEntryKey(host, &entry_key);
- EXPECT_TRUE(success);
-
- guide_store()->LoadHostModelFeatures(
- entry_key,
- base::BindOnce(&OptimizationGuideStoreTest::OnHostModelFeaturesLoaded,
- base::Unretained(this)));
- // OnPredictionModelLoaded callback
- db()->GetCallback(true);
-
- if (!last_loaded_host_model_features()) {
- FAIL() << "Loaded host model features was null for entry key: "
- << entry_key;
- }
-
- EXPECT_EQ(last_loaded_host_model_features()->host(), host);
- }
-}
-
-TEST_F(OptimizationGuideStoreTest, LoadAllHostModelFeatures) {
- base::HistogramTester histogram_tester;
- size_t update_host_model_features_count = 5;
MetadataSchemaState schema_state = MetadataSchemaState::kValid;
- base::Time update_time = base::Time().Now();
- SeedInitialData(schema_state, 0, base::Time().Now());
+ SeedInitialData(schema_state, 0);
CreateDatabase();
InitializeStore(schema_state);
- std::unique_ptr<StoreUpdateData> update_data =
- guide_store()->CreateUpdateDataForHostModelFeatures(
- update_time, update_time +
- optimization_guide::features::
- StoredHostModelFeaturesFreshnessDuration());
- ASSERT_TRUE(update_data);
- SeedHostModelFeaturesUpdateData(update_data.get(),
- update_host_model_features_count);
- UpdateHostModelFeatures(std::move(update_data));
- guide_store()->LoadAllHostModelFeatures(
- base::BindOnce(&OptimizationGuideStoreTest::OnAllHostModelFeaturesLoaded,
- base::Unretained(this)));
-
- // OnAllHostModelFeaturesLoaded callback
- db()->LoadCallback(true);
-
- std::vector<proto::HostModelFeatures>* all_host_model_features =
- last_loaded_all_host_model_features();
- EXPECT_TRUE(all_host_model_features);
- EXPECT_EQ(update_host_model_features_count, all_host_model_features->size());
-
- // Build a list of the hosts that are stored in the store.
- base::flat_set<std::string> hosts = {};
- for (size_t i = 0; i < update_host_model_features_count; i++)
- hosts.insert(GetHost(i));
-
- // Make sure all of the hosts of the host model features are returned.
- for (const auto& host_model_features : *all_host_model_features)
- EXPECT_NE(hosts.find(host_model_features.host()), hosts.end());
-}
-
-TEST_F(OptimizationGuideStoreTest, ClearHostModelFeatures) {
- base::HistogramTester histogram_tester;
- size_t update_host_model_features_count = 5;
- MetadataSchemaState schema_state = MetadataSchemaState::kValid;
+ // Add an update with models that are "inactive".
base::Time update_time = base::Time().Now();
- SeedInitialData(schema_state, 0, base::Time().Now());
- CreateDatabase();
- InitializeStore(schema_state);
-
std::unique_ptr<StoreUpdateData> update_data =
- guide_store()->CreateUpdateDataForHostModelFeatures(
- update_time, update_time +
- optimization_guide::features::
- StoredHostModelFeaturesFreshnessDuration());
+ guide_store()->CreateUpdateDataForPredictionModels(
+ update_time -
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data);
- SeedHostModelFeaturesUpdateData(update_data.get(),
- update_host_model_features_count);
- UpdateHostModelFeatures(std::move(update_data));
-
- for (size_t i = 0; i < update_host_model_features_count; ++i) {
- std::string host = GetHost(i);
- OptimizationGuideStore::EntryKey entry_key;
- EXPECT_TRUE(guide_store()->FindHostModelFeaturesEntryKey(host, &entry_key));
- }
+ base::FilePath old_file_path = temp_dir().AppendASCII("model_v1.tflite");
+ ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(old_file_path, "boo", 3));
+ SeedPredictionModelUpdateData(
+ update_data.get(), proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD,
+ old_file_path,
+ /*info=*/{},
+ update_time - optimization_guide::features::StoredModelsValidDuration());
+ UpdatePredictionModels(std::move(update_data));
- // Remove the host model features from the OptimizationGuideStore.
- ClearHostModelFeaturesFromDatabase();
- histogram_tester.ExpectBucketCount(
- "OptimizationGuide.ClearHostModelFeatures.StoreAvailable", true, 1);
+ // Add an update with models that are "active".
+ std::unique_ptr<StoreUpdateData> update_data2 =
+ guide_store()->CreateUpdateDataForPredictionModels(
+ update_time +
+ optimization_guide::features::StoredModelsValidDuration());
+ ASSERT_TRUE(update_data2);
+ SeedPredictionModelUpdateData(update_data2.get(),
+ proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION);
+ UpdatePredictionModels(std::move(update_data2));
- for (size_t i = 0; i < update_host_model_features_count; ++i) {
- std::string host = GetHost(i);
- OptimizationGuideStore::EntryKey entry_key;
- EXPECT_FALSE(
- guide_store()->FindHostModelFeaturesEntryKey(host, &entry_key));
- }
-}
+ // Make sure both models are in the store.
+ OptimizationGuideStore::EntryKey entry_key;
+ bool success = guide_store()->FindPredictionModelEntryKey(
+ proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &entry_key);
+ ASSERT_TRUE(success);
+ success = guide_store()->FindPredictionModelEntryKey(
+ proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, &entry_key);
+ ASSERT_TRUE(success);
-TEST_F(OptimizationGuideStoreTest, PurgeExpiredHostModelFeatures) {
- base::HistogramTester histogram_tester;
- size_t update_host_model_features_count = 5;
- MetadataSchemaState schema_state = MetadataSchemaState::kValid;
- base::Time update_time = base::Time().Now();
- SeedInitialData(schema_state, 0, base::Time().Now());
- CreateDatabase();
- InitializeStore(schema_state);
+ PurgeInactiveModels();
+ RunUntilIdle();
+ // The expired model should be removed.
+ EXPECT_FALSE(guide_store()->FindPredictionModelEntryKey(
+ proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &entry_key));
+ EXPECT_FALSE(base::PathExists(old_file_path));
- std::unique_ptr<StoreUpdateData> update_data =
- guide_store()->CreateUpdateDataForHostModelFeatures(
- update_time, update_time -
- optimization_guide::features::
- StoredHostModelFeaturesFreshnessDuration());
- ASSERT_TRUE(update_data);
- SeedHostModelFeaturesUpdateData(update_data.get(),
- update_host_model_features_count);
- UpdateHostModelFeatures(std::move(update_data));
+ // Should not purge models that are still active.
+ EXPECT_TRUE(guide_store()->FindPredictionModelEntryKey(
+ proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, &entry_key));
- for (size_t i = 0; i < update_host_model_features_count; ++i) {
- std::string host = GetHost(i);
- OptimizationGuideStore::EntryKey entry_key;
- EXPECT_TRUE(guide_store()->FindHostModelFeaturesEntryKey(host, &entry_key));
- }
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.PredictionModelExpired.PainfulPageLoad", true, 1);
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.PredictionModelExpired.LanguageDetection", 0);
+}
- // Remove expired host model features from the opt. guide store.
- PurgeExpiredHostModelFeatures();
+struct ValidityTestCase {
+ std::string test_name;
+ bool keep_beyond_valid_duration;
+ bool initially_expired;
+ bool expect_kept;
+};
- for (size_t i = 0; i < update_host_model_features_count; ++i) {
- std::string host = GetHost(i);
- OptimizationGuideStore::EntryKey entry_key;
- EXPECT_FALSE(
- guide_store()->FindHostModelFeaturesEntryKey(host, &entry_key));
- }
-}
+class OptimizationGuideStoreValidityTest
+ : public OptimizationGuideStoreTest,
+ public ::testing::WithParamInterface<ValidityTestCase> {};
-TEST_F(OptimizationGuideStoreTest, PurgeInactiveModels) {
+TEST_P(OptimizationGuideStoreValidityTest, PurgeInactiveModels) {
+ const ValidityTestCase& test_case = GetParam();
base::HistogramTester histogram_tester;
MetadataSchemaState schema_state = MetadataSchemaState::kValid;
@@ -2664,22 +2464,30 @@ TEST_F(OptimizationGuideStoreTest, PurgeInactiveModels) {
CreateDatabase();
InitializeStore(schema_state);
- // Add an update with models that are "inactive".
+ // Add an update with one model according to ValidityTestCase settings.
base::Time update_time = base::Time().Now();
+ if (test_case.initially_expired) {
+ update_time -= optimization_guide::features::StoredModelsValidDuration();
+ } else {
+ update_time += optimization_guide::features::StoredModelsValidDuration();
+ }
std::unique_ptr<StoreUpdateData> update_data =
- guide_store()->CreateUpdateDataForPredictionModels(
- update_time -
- optimization_guide::features::StoredModelsInactiveDuration());
+ guide_store()->CreateUpdateDataForPredictionModels(update_time);
ASSERT_TRUE(update_data);
+ proto::ModelInfo info;
+ info.set_keep_beyond_valid_duration(test_case.keep_beyond_valid_duration);
+ base::FilePath old_file_path = temp_dir().AppendASCII("model_v1.tflite");
+ ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(old_file_path, "boo", 3));
SeedPredictionModelUpdateData(update_data.get(),
- proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
+ proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD,
+ old_file_path, info, update_time);
UpdatePredictionModels(std::move(update_data));
- // Add an update with models that are "active".
+ // Add an update with models that are "active" and should be unaffected.
std::unique_ptr<StoreUpdateData> update_data2 =
guide_store()->CreateUpdateDataForPredictionModels(
- update_time +
- optimization_guide::features::StoredModelsInactiveDuration());
+ base::Time().Now() +
+ optimization_guide::features::StoredModelsValidDuration());
ASSERT_TRUE(update_data2);
SeedPredictionModelUpdateData(update_data2.get(),
proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION);
@@ -2690,22 +2498,67 @@ TEST_F(OptimizationGuideStoreTest, PurgeInactiveModels) {
bool success = guide_store()->FindPredictionModelEntryKey(
proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &entry_key);
ASSERT_TRUE(success);
+ EXPECT_TRUE(base::PathExists(old_file_path));
+
success = guide_store()->FindPredictionModelEntryKey(
proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, &entry_key);
ASSERT_TRUE(success);
PurgeInactiveModels();
-
- EXPECT_FALSE(guide_store()->FindPredictionModelEntryKey(
- proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &entry_key));
- // Should not purge models that are still active.
+ RunUntilIdle();
+ // Verify that the model file, entry key and histogram match expectations for
+ // PageLoad.
+ EXPECT_EQ(test_case.expect_kept,
+ guide_store()->FindPredictionModelEntryKey(
+ proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &entry_key));
+ EXPECT_EQ(test_case.expect_kept, base::PathExists(old_file_path));
+
+ if (test_case.expect_kept) {
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.PredictionModelExpired.PainfulPageLoad", 0);
+ } else {
+ histogram_tester.ExpectTotalCount(
+ "OptimizationGuide.PredictionModelExpired.PainfulPageLoad", 1);
+ }
+ // Verify that the other model is not deleted.
EXPECT_TRUE(guide_store()->FindPredictionModelEntryKey(
proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, &entry_key));
-
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PredictionModelExpired.PainfulPageLoad", true, 1);
histogram_tester.ExpectTotalCount(
"OptimizationGuide.PredictionModelExpired.LanguageDetection", 0);
}
+INSTANTIATE_TEST_SUITE_P(
+ OptimizationGuideStoreValidityTests,
+ OptimizationGuideStoreValidityTest,
+ testing::ValuesIn<ValidityTestCase>({
+ {
+ "KeepDespiteInvalidModel",
+ /*keep_beyond_valid_duration=*/true,
+ /*initially_expired=*/true,
+ /*expect_kept=*/true,
+ },
+ {
+ "KeepAndInitiallyValid",
+ /*keep_beyond_valid_duration=*/true,
+ /*initially_expired=*/false,
+ /*expect_kept=*/true,
+ },
+ {
+ "DeleteAndInitiallyValid",
+ /*keep_beyond_valid_duration=*/false,
+ /*initially_expired=*/false,
+ /*expect_kept=*/true,
+ },
+ // Only in this case should the model be removed.
+ {
+ "DeleteAndInvalidModel",
+ /*keep_beyond_valid_duration=*/false,
+ /*initially_expired=*/true,
+ /*expect_kept=*/false,
+ },
+ }),
+ [](const testing::TestParamInfo<
+ OptimizationGuideStoreValidityTest::ParamType>& info) {
+ return info.param.test_name;
+ });
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/optimization_guide_switches.cc b/chromium/components/optimization_guide/core/optimization_guide_switches.cc
index 8c55155bf31..e2026ddc1cf 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_switches.cc
+++ b/chromium/components/optimization_guide/core/optimization_guide_switches.cc
@@ -27,12 +27,6 @@ const char kHintsProtoOverride[] = "optimization_guide_hints_override";
// hosts.
const char kFetchHintsOverride[] = "optimization-guide-fetch-hints-override";
-// Overrides scheduling and time delays for fetching prediction models and host
-// model features. This causes a prediction model and host model features fetch
-// immediately on start up.
-const char kFetchModelsAndHostModelFeaturesOverrideTimer[] =
- "optimization-guide-fetch-models-and-features-override";
-
// Overrides the hints fetch scheduling and delay, causing a hints fetch
// immediately on start up using the TopHostProvider. This is meant for testing.
const char kFetchHintsOverrideTimer[] =
@@ -87,6 +81,14 @@ const char kModelOverride[] = "optimization-guide-model-override";
// Triggers validation of the model. Used for manual testing.
const char kModelValidate[] = "optimization-guide-model-validate";
+// Prevents any models from being executing when in annotating a batch
+// of visits. This is used for testing only.
+const char kStopHistoryVisitBatchAnnotateForTesting[] =
+ "stop-history-visit-batch-annotate";
+
+const char kPageContentAnnotationsLoggingEnabled[] =
+ "enable-page-content-annotations-logging";
+
bool IsHintComponentProcessingDisabled() {
return base::CommandLine::ForCurrentProcess()->HasSwitch(kHintsProtoOverride);
}
@@ -134,11 +136,6 @@ bool ShouldOverrideFetchHintsTimer() {
kFetchHintsOverrideTimer);
}
-bool ShouldOverrideFetchModelsAndFeaturesTimer() {
- return base::CommandLine::ForCurrentProcess()->HasSwitch(
- kFetchModelsAndHostModelFeaturesOverrideTimer);
-}
-
std::unique_ptr<optimization_guide::proto::Configuration>
ParseComponentConfigFromCommandLine() {
base::CommandLine* cmd_line = base::CommandLine::ForCurrentProcess();
@@ -191,7 +188,7 @@ bool ShouldValidateModel() {
}
absl::optional<std::string> GetModelOverride() {
-#if defined(OS_WIN)
+#if BUILDFLAG(IS_WIN)
// TODO(crbug/1227996): The parsing below is not supported on Windows because
// ':' is used as a delimiter, but this must be used in the absolute file path
// on Windows.
@@ -206,5 +203,17 @@ absl::optional<std::string> GetModelOverride() {
#endif
}
+bool StopHistoryVisitBatchAnnotateForTesting() {
+ base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
+ if (command_line->HasSwitch(kStopHistoryVisitBatchAnnotateForTesting))
+ return true;
+ return false;
+}
+
+bool ShouldLogPageContentAnnotationsInput() {
+ return base::CommandLine::ForCurrentProcess()->HasSwitch(
+ kPageContentAnnotationsLoggingEnabled);
+}
+
} // namespace switches
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/optimization_guide_switches.h b/chromium/components/optimization_guide/core/optimization_guide_switches.h
index 188d06c574b..c7e80ceeabb 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_switches.h
+++ b/chromium/components/optimization_guide/core/optimization_guide_switches.h
@@ -22,7 +22,6 @@ namespace switches {
extern const char kHintsProtoOverride[];
extern const char kFetchHintsOverride[];
extern const char kFetchHintsOverrideTimer[];
-extern const char kFetchModelsAndHostModelFeaturesOverrideTimer[];
extern const char kOptimizationGuideServiceGetHintsURL[];
extern const char kOptimizationGuideServiceGetModelsURL[];
extern const char kOptimizationGuideServiceAPIKey[];
@@ -34,6 +33,8 @@ extern const char kDisableModelDownloadVerificationForTesting[];
extern const char kModelOverride[];
extern const char kDebugLoggingEnabled[];
extern const char kModelValidate[];
+extern const char kStopHistoryVisitBatchAnnotateForTesting[];
+extern const char kPageContentAnnotationsLoggingEnabled[];
// Returns whether the hint component should be processed.
// Available hint components are only processed if a proto override isn't being
@@ -58,10 +59,6 @@ ParseHintsFetchOverrideFromCommandLine();
// Whether the hints fetcher timer should be overridden.
bool ShouldOverrideFetchHintsTimer();
-// Whether the prediction model and host model features fetcher timer should be
-// overridden.
-bool ShouldOverrideFetchModelsAndFeaturesTimer();
-
// Attempts to parse a base64 encoded Optimization Guide Configuration proto
// from the command line. If no proto is given or if it is encoded incorrectly,
// nullptr is returned.
@@ -92,6 +89,13 @@ absl::optional<std::string> GetModelOverride();
// Returns true if debug logs are enabled for the optimization guide.
bool IsDebugLogsEnabled();
+// Whether to prevent annotations from happening when in a batch. For testing
+// purposes only.
+bool StopHistoryVisitBatchAnnotateForTesting();
+
+// Returns true if page content annotations input should be logged.
+bool ShouldLogPageContentAnnotationsInput();
+
} // namespace switches
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/optimization_guide_switches_unittest.cc b/chromium/components/optimization_guide/core/optimization_guide_switches_unittest.cc
index ab401708455..267633cb400 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_switches_unittest.cc
+++ b/chromium/components/optimization_guide/core/optimization_guide_switches_unittest.cc
@@ -14,7 +14,7 @@
namespace optimization_guide {
namespace switches {
-#if !defined(OS_WIN)
+#if !BUILDFLAG(IS_WIN)
TEST(OptimizationGuideSwitchesTest, ParseHintsFetchOverrideFromCommandLine) {
base::CommandLine::ForCurrentProcess()->AppendSwitchASCII(kFetchHintsOverride,
diff --git a/chromium/components/optimization_guide/core/optimization_guide_test_util.cc b/chromium/components/optimization_guide/core/optimization_guide_test_util.cc
index f1e88b67197..e349570a811 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_test_util.cc
+++ b/chromium/components/optimization_guide/core/optimization_guide_test_util.cc
@@ -9,7 +9,7 @@
namespace optimization_guide {
-#if defined(OS_WIN)
+#if BUILDFLAG(IS_WIN)
const char kTestAbsoluteFilePath[] = "C:\\absolute/file/path";
const char kTestRelativeFilePath[] = "relative/file/path";
#else
diff --git a/chromium/components/optimization_guide/core/optimization_guide_util.cc b/chromium/components/optimization_guide/core/optimization_guide_util.cc
index 392f806ac5a..15acd17ad35 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_util.cc
+++ b/chromium/components/optimization_guide/core/optimization_guide_util.cc
@@ -19,50 +19,6 @@
namespace optimization_guide {
-// These names are persisted to histograms, so don't change them.
-std::string GetStringNameForOptimizationTarget(
- optimization_guide::proto::OptimizationTarget optimization_target) {
- switch (optimization_target) {
- case proto::OPTIMIZATION_TARGET_UNKNOWN:
- return "Unknown";
- case proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD:
- return "PainfulPageLoad";
- case proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION:
- return "LanguageDetection";
- case proto::OPTIMIZATION_TARGET_PAGE_TOPICS:
- return "PageTopics";
- case proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB:
- return "SegmentationNewTab";
- case proto::OPTIMIZATION_TARGET_SEGMENTATION_SHARE:
- return "SegmentationShare";
- case proto::OPTIMIZATION_TARGET_SEGMENTATION_VOICE:
- return "SegmentationVoice";
- case proto::OPTIMIZATION_TARGET_MODEL_VALIDATION:
- return "ModelValidation";
- case proto::OPTIMIZATION_TARGET_PAGE_ENTITIES:
- return "PageEntities";
- case proto::OPTIMIZATION_TARGET_NOTIFICATION_PERMISSION_PREDICTIONS:
- return "NotificationPermissions";
- case proto::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY:
- return "SegmentationDummyFeature";
- case proto::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID:
- return "SegmentationChromeStartAndroid";
- case proto::OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES:
- return "SegmentationQueryTiles";
- case proto::OPTIMIZATION_TARGET_PAGE_VISIBILITY:
- return "PageVisibility";
- case proto::OPTIMIZATION_TARGET_AUTOFILL_ASSISTANT:
- return "AutofillAssistant";
- case proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2:
- return "PageTopicsV2";
- // Whenever a new value is added, make sure to add it to the OptTarget
- // variant list in
- // //tools/metrics/histograms/metadata/optimization/histograms.xml.
- }
- NOTREACHED();
- return std::string();
-}
-
bool IsHostValidToFetchFromRemoteOptimizationGuide(const std::string& host) {
if (net::HostStringIsLocalhost(host))
return false;
@@ -108,29 +64,6 @@ GetActiveFieldTrialsAllowedForFetch() {
return filtered_active_field_trials;
}
-absl::optional<base::FilePath> StringToFilePath(const std::string& str_path) {
- if (str_path.empty())
- return absl::nullopt;
-
-#if defined(OS_WIN)
- return base::FilePath(base::UTF8ToWide(str_path));
-#else
- return base::FilePath(str_path);
-#endif
-}
-
-std::string FilePathToString(const base::FilePath& file_path) {
-#if defined(OS_WIN)
- return base::WideToUTF8(file_path.value());
-#else
- return file_path.value();
-#endif
-}
-
-base::FilePath GetBaseFileNameForModels() {
- return base::FilePath(FILE_PATH_LITERAL("model.tflite"));
-}
-
std::string GetStringForOptimizationGuideDecision(
OptimizationGuideDecision decision) {
switch (decision) {
@@ -149,7 +82,7 @@ absl::optional<
std::pair<std::string, absl::optional<optimization_guide::proto::Any>>>
GetModelOverrideForOptimizationTarget(
optimization_guide::proto::OptimizationTarget optimization_target) {
-#if defined(OS_WIN)
+#if BUILDFLAG(IS_WIN)
// TODO(crbug/1227996): The parsing below is not supported on Windows because
// ':' is used as a delimiter, but this must be used in the absolute file path
// on Windows.
diff --git a/chromium/components/optimization_guide/core/optimization_guide_util.h b/chromium/components/optimization_guide/core/optimization_guide_util.h
index f425dde56a1..90b70f34bc1 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_util.h
+++ b/chromium/components/optimization_guide/core/optimization_guide_util.h
@@ -7,24 +7,28 @@
#include <string>
-#include "base/files/file_path.h"
#include "base/strings/string_split.h"
+#include "base/time/time.h"
#include "components/optimization_guide/core/optimization_guide_enums.h"
#include "components/optimization_guide/proto/common_types.pb.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
+#define OPTIMIZATION_GUIDE_LOG(optimization_guide_logger, message) \
+ do { \
+ if (optimization_guide_logger && \
+ optimization_guide_logger->ShouldEnableDebugLogs()) { \
+ optimization_guide_logger->OnLogMessageAdded( \
+ base::Time::Now(), __FILE__, __LINE__, message); \
+ } \
+ if (optimization_guide::switches::IsDebugLogsEnabled()) \
+ DVLOG(0) << message; \
+ } while (0)
+
namespace optimization_guide {
enum class OptimizationGuideDecision;
-// Returns the string than can be used to record histograms for the optimization
-// target. If adding a histogram to use the string or adding an optimization
-// target, update the OptimizationGuide.OptimizationTargets histogram suffixes
-// in histograms.xml.
-std::string GetStringNameForOptimizationTarget(
- proto::OptimizationTarget optimization_target);
-
// Returns false if the host is an IP address, localhosts, or an invalid
// host that is not supported by the remote optimization guide.
bool IsHostValidToFetchFromRemoteOptimizationGuide(const std::string& host);
@@ -34,18 +38,6 @@ bool IsHostValidToFetchFromRemoteOptimizationGuide(const std::string& host);
google::protobuf::RepeatedPtrField<proto::FieldTrial>
GetActiveFieldTrialsAllowedForFetch();
-// Returns the file path represented by the given string, handling platform
-// differences in the conversion. nullopt is only returned iff the passed string
-// is empty.
-absl::optional<base::FilePath> StringToFilePath(const std::string& str_path);
-
-// Returns a string representation of the given |file_path|, handling platform
-// differences in the conversion.
-std::string FilePathToString(const base::FilePath& file_path);
-
-// Returns the base file name to use for storing all prediction models.
-base::FilePath GetBaseFileNameForModels();
-
// Validates that the metadata stored in |any_metadata_| is of the same type
// and is parseable as |T|. Will return metadata if all checks pass.
template <class T,
diff --git a/chromium/components/optimization_guide/core/optimization_guide_util_unittest.cc b/chromium/components/optimization_guide/core/optimization_guide_util_unittest.cc
index 1d400833462..f3a7fee11ae 100644
--- a/chromium/components/optimization_guide/core/optimization_guide_util_unittest.cc
+++ b/chromium/components/optimization_guide/core/optimization_guide_util_unittest.cc
@@ -63,7 +63,7 @@ TEST(OptimizationGuideUtilTest, ParsedAnyMetadataTest) {
EXPECT_TRUE(parsed_subresource.preconnect_only());
}
-#if !defined(OS_WIN)
+#if !BUILDFLAG(IS_WIN)
TEST(OptimizationGuideUtilTest,
GetModelOverrideForOptimizationTargetSwitchNotSet) {
diff --git a/chromium/components/optimization_guide/core/optimization_hints_component_update_listener.cc b/chromium/components/optimization_guide/core/optimization_hints_component_update_listener.cc
index 4ea36ecd4cb..019c0d464ef 100644
--- a/chromium/components/optimization_guide/core/optimization_hints_component_update_listener.cc
+++ b/chromium/components/optimization_guide/core/optimization_hints_component_update_listener.cc
@@ -5,6 +5,7 @@
#include "components/optimization_guide/core/optimization_hints_component_update_listener.h"
#include "base/metrics/histogram_functions.h"
+#include "base/no_destructor.h"
namespace optimization_guide {
diff --git a/chromium/components/optimization_guide/core/page_content_annotation_job.cc b/chromium/components/optimization_guide/core/page_content_annotation_job.cc
index 4ac77c0091b..45104edb920 100644
--- a/chromium/components/optimization_guide/core/page_content_annotation_job.cc
+++ b/chromium/components/optimization_guide/core/page_content_annotation_job.cc
@@ -5,6 +5,7 @@
#include "components/optimization_guide/core/page_content_annotation_job.h"
#include "base/check_op.h"
+#include "base/metrics/histogram_functions.h"
namespace optimization_guide {
@@ -14,11 +15,35 @@ PageContentAnnotationJob::PageContentAnnotationJob(
AnnotationType type)
: on_complete_callback_(std::move(on_complete_callback)),
type_(type),
- inputs_(inputs.begin(), inputs.end()) {
+ inputs_(inputs.begin(), inputs.end()),
+ job_creation_time_(base::TimeTicks::Now()) {
DCHECK(!inputs_.empty());
}
-PageContentAnnotationJob::~PageContentAnnotationJob() = default;
+PageContentAnnotationJob::~PageContentAnnotationJob() {
+ if (!job_execution_start_time_)
+ return;
+
+ base::TimeDelta job_scheduling_wait_time =
+ *job_execution_start_time_ - job_creation_time_;
+ base::TimeDelta job_exec_time =
+ base::TimeTicks::Now() - *job_execution_start_time_;
+
+ base::UmaHistogramMediumTimes(
+ "OptimizationGuide.PageContentAnnotations.JobExecutionTime." +
+ AnnotationTypeToString(type()),
+ job_exec_time);
+
+ base::UmaHistogramMediumTimes(
+ "OptimizationGuide.PageContentAnnotations.JobScheduleTime." +
+ AnnotationTypeToString(type()),
+ job_scheduling_wait_time);
+
+ base::UmaHistogramBoolean(
+ "OptimizationGuide.PageContentAnnotations.BatchSuccess." +
+ AnnotationTypeToString(type()),
+ HadAnySuccess());
+}
void PageContentAnnotationJob::FillWithNullOutputs() {
while (auto input = GetNextInput()) {
@@ -59,6 +84,10 @@ size_t PageContentAnnotationJob::CountOfRemainingNonNullInputs() const {
}
absl::optional<std::string> PageContentAnnotationJob::GetNextInput() {
+ if (!job_execution_start_time_) {
+ job_execution_start_time_ = base::TimeTicks::Now();
+ }
+
if (inputs_.empty()) {
return absl::nullopt;
}
@@ -72,4 +101,20 @@ void PageContentAnnotationJob::PostNewResult(
results_.push_back(result);
}
+bool PageContentAnnotationJob::HadAnySuccess() const {
+ for (const BatchAnnotationResult& result : results_) {
+ if (result.type() == AnnotationType::kPageTopics && result.topics()) {
+ return true;
+ }
+ if (result.type() == AnnotationType::kPageEntities && result.entities()) {
+ return true;
+ }
+ if (result.type() == AnnotationType::kContentVisibility &&
+ result.visibility_score()) {
+ return true;
+ }
+ }
+ return false;
+}
+
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/page_content_annotation_job.h b/chromium/components/optimization_guide/core/page_content_annotation_job.h
index 989736df1e1..9ba68090c00 100644
--- a/chromium/components/optimization_guide/core/page_content_annotation_job.h
+++ b/chromium/components/optimization_guide/core/page_content_annotation_job.h
@@ -10,6 +10,7 @@
#include <vector>
#include "base/callback.h"
+#include "base/time/time.h"
#include "components/optimization_guide/core/page_content_annotations_common.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
@@ -21,8 +22,6 @@ namespace optimization_guide {
// container that matches the I/O of a single call to the PCA Service.
class PageContentAnnotationJob {
public:
- using WeightedCategories = std::vector<WeightedString>;
-
PageContentAnnotationJob(BatchAnnotationCallback on_complete_callback,
const std::vector<std::string>& inputs,
AnnotationType type);
@@ -46,6 +45,10 @@ class PageContentAnnotationJob {
// Posts a new result after an execution has completed.
void PostNewResult(const BatchAnnotationResult& result);
+ // Returns true if any element of |results_| was a successful execution. We
+ // expect that if one result is successful, many more will be as well.
+ bool HadAnySuccess() const;
+
AnnotationType type() const { return type_; }
PageContentAnnotationJob(const PageContentAnnotationJob&) = delete;
@@ -65,6 +68,12 @@ class PageContentAnnotationJob {
// Filled by |PostNewResult| with the complete annotations, specified by
// |type_|.
std::vector<BatchAnnotationResult> results_;
+
+ // The time the job was constructed.
+ const base::TimeTicks job_creation_time_;
+
+ // Set when |GetNextInput| is called for the first time.
+ absl::optional<base::TimeTicks> job_execution_start_time_;
};
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/page_content_annotation_job_executor.cc b/chromium/components/optimization_guide/core/page_content_annotation_job_executor.cc
index 38f617cef94..779524c94d2 100644
--- a/chromium/components/optimization_guide/core/page_content_annotation_job_executor.cc
+++ b/chromium/components/optimization_guide/core/page_content_annotation_job_executor.cc
@@ -73,6 +73,8 @@ void PageContentAnnotationJobExecutor::OnJobExecutionComplete(
std::unique_ptr<PageContentAnnotationJob> job) {
job->OnComplete();
// Intentionally reset |job| here to make lifetime clearer and less bug-prone.
+ // Note that the job dtor also records some timing metrics which is better to
+ // do now rather than after the following callback.
job.reset();
std::move(on_job_complete_callback_from_caller).Run();
diff --git a/chromium/components/optimization_guide/core/page_content_annotation_job_executor_unittest.cc b/chromium/components/optimization_guide/core/page_content_annotation_job_executor_unittest.cc
index d16ffb6859f..5fe0a6a48a6 100644
--- a/chromium/components/optimization_guide/core/page_content_annotation_job_executor_unittest.cc
+++ b/chromium/components/optimization_guide/core/page_content_annotation_job_executor_unittest.cc
@@ -15,7 +15,7 @@
namespace optimization_guide {
namespace {
-const std::vector<WeightedString> kOutput{WeightedString("output", 1.0)};
+const std::vector<WeightedIdentifier> kOutput{WeightedIdentifier(1337, 1.0)};
}
class TestJobExecutor : public PageContentAnnotationJobExecutor {
diff --git a/chromium/components/optimization_guide/core/page_content_annotations_common.cc b/chromium/components/optimization_guide/core/page_content_annotations_common.cc
index 08bd362e404..62f637d7443 100644
--- a/chromium/components/optimization_guide/core/page_content_annotations_common.cc
+++ b/chromium/components/optimization_guide/core/page_content_annotations_common.cc
@@ -13,6 +13,9 @@
namespace optimization_guide {
+// Each of these string values is used in UMA histograms so please update the
+// variants there when any changes are made.
+// //tools/metrics/histograms/metadata/optimization/histograms.xml
std::string AnnotationTypeToString(AnnotationType type) {
switch (type) {
case AnnotationType::kUnknown:
@@ -26,26 +29,25 @@ std::string AnnotationTypeToString(AnnotationType type) {
}
}
-WeightedString::WeightedString(const std::string& value, double weight)
+WeightedIdentifier::WeightedIdentifier(int32_t value, double weight)
: value_(value), weight_(weight) {
DCHECK_GE(weight_, 0.0);
DCHECK_LE(weight_, 1.0);
}
-WeightedString::WeightedString(const WeightedString&) = default;
-WeightedString::~WeightedString() = default;
+WeightedIdentifier::WeightedIdentifier(const WeightedIdentifier&) = default;
+WeightedIdentifier::~WeightedIdentifier() = default;
-bool WeightedString::operator==(const WeightedString& other) const {
+bool WeightedIdentifier::operator==(const WeightedIdentifier& other) const {
constexpr double kWeightTolerance = 1e-6;
return this->value_ == other.value_ &&
abs(this->weight_ - other.weight_) <= kWeightTolerance;
}
-std::string WeightedString::ToString() const {
- return base::StringPrintf("WeightedString{\"%s\",%f}", value().c_str(),
- weight());
+std::string WeightedIdentifier::ToString() const {
+ return base::StringPrintf("WeightedIdentifier{%d,%f}", value(), weight());
}
-std::ostream& operator<<(std::ostream& stream, const WeightedString& ws) {
+std::ostream& operator<<(std::ostream& stream, const WeightedIdentifier& ws) {
stream << ws.ToString();
return stream;
}
@@ -58,11 +60,11 @@ BatchAnnotationResult::~BatchAnnotationResult() = default;
std::string BatchAnnotationResult::ToString() const {
std::string output = "nullopt";
if (topics_) {
- std::vector<std::string> all_weighted_strings;
- for (const WeightedString& ws : *topics_) {
- all_weighted_strings.push_back(ws.ToString());
+ std::vector<std::string> all_weighted_ids;
+ for (const WeightedIdentifier& wi : *topics_) {
+ all_weighted_ids.push_back(wi.ToString());
}
- output = "{" + base::JoinString(all_weighted_strings, ",") + "}";
+ output = "{" + base::JoinString(all_weighted_ids, ",") + "}";
} else if (entities_) {
std::vector<std::string> all_entities;
for (const ScoredEntityMetadata& md : *entities_) {
@@ -89,7 +91,7 @@ std::ostream& operator<<(std::ostream& stream,
// static
BatchAnnotationResult BatchAnnotationResult::CreatePageTopicsResult(
const std::string& input,
- absl::optional<std::vector<WeightedString>> topics) {
+ absl::optional<std::vector<WeightedIdentifier>> topics) {
BatchAnnotationResult result;
result.input_ = input;
result.topics_ = topics;
@@ -98,7 +100,7 @@ BatchAnnotationResult BatchAnnotationResult::CreatePageTopicsResult(
// Always sort the result (if present) by the given score.
if (result.topics_) {
std::sort(result.topics_->begin(), result.topics_->end(),
- [](const WeightedString& a, const WeightedString& b) {
+ [](const WeightedIdentifier& a, const WeightedIdentifier& b) {
return a.weight() < b.weight();
});
}
diff --git a/chromium/components/optimization_guide/core/page_content_annotations_common.h b/chromium/components/optimization_guide/core/page_content_annotations_common.h
index 5679064b8d0..4f77386485c 100644
--- a/chromium/components/optimization_guide/core/page_content_annotations_common.h
+++ b/chromium/components/optimization_guide/core/page_content_annotations_common.h
@@ -15,6 +15,10 @@
namespace optimization_guide {
// The type of annotation that is being done on the given input.
+//
+// Each of these is used in UMA histograms so please update the variants there
+// when any changes are made.
+// //tools/metrics/histograms/metadata/optimization/histograms.xml
enum class AnnotationType {
kUnknown,
@@ -33,25 +37,25 @@ enum class AnnotationType {
std::string AnnotationTypeToString(AnnotationType type);
-// A weighted string value.
-class WeightedString {
+// A weighted ID value.
+class WeightedIdentifier {
public:
- WeightedString(const std::string& value, double weight);
- WeightedString(const WeightedString&);
- ~WeightedString();
+ WeightedIdentifier(int32_t value, double weight);
+ WeightedIdentifier(const WeightedIdentifier&);
+ ~WeightedIdentifier();
- std::string value() const { return value_; }
+ int32_t value() const { return value_; }
double weight() const { return weight_; }
std::string ToString() const;
- bool operator==(const WeightedString& other) const;
+ bool operator==(const WeightedIdentifier& other) const;
friend std::ostream& operator<<(std::ostream& stream,
- const WeightedString& ws);
+ const WeightedIdentifier& ws);
private:
- std::string value_;
+ int32_t value_;
// In the range of [0.0, 1.0].
double weight_ = 0;
@@ -63,7 +67,7 @@ class BatchAnnotationResult {
// Creates a result for a page topics annotation.
static BatchAnnotationResult CreatePageTopicsResult(
const std::string& input,
- absl::optional<std::vector<WeightedString>> topics);
+ absl::optional<std::vector<WeightedIdentifier>> topics);
// Creates a result for a page entities annotation.
static BatchAnnotationResult CreatePageEntitiesResult(
@@ -84,7 +88,9 @@ class BatchAnnotationResult {
std::string input() const { return input_; }
AnnotationType type() const { return type_; }
- absl::optional<std::vector<WeightedString>> topics() const { return topics_; }
+ absl::optional<std::vector<WeightedIdentifier>> topics() const {
+ return topics_;
+ }
absl::optional<std::vector<ScoredEntityMetadata>> entities() const {
return entities_;
}
@@ -105,7 +111,7 @@ class BatchAnnotationResult {
// Output for page topics annotations, set only if the |type_| matches and the
// execution was successful.
- absl::optional<std::vector<WeightedString>> topics_;
+ absl::optional<std::vector<WeightedIdentifier>> topics_;
// Output for page entities annotations, set only if the |type_| matches and
// the execution was successful.
diff --git a/chromium/components/optimization_guide/core/page_entities_model_executor.h b/chromium/components/optimization_guide/core/page_entities_model_executor.h
index 919e7b76f76..5ea55c6e243 100644
--- a/chromium/components/optimization_guide/core/page_entities_model_executor.h
+++ b/chromium/components/optimization_guide/core/page_entities_model_executor.h
@@ -15,7 +15,7 @@
namespace optimization_guide {
-// TODO(crbug/1249632): Remove this entirely.
+// TODO(crbug/1278828): Remove this entirely.
class HumanReadablePageEntitiesModelExecutor {
public:
virtual ~HumanReadablePageEntitiesModelExecutor() = default;
diff --git a/chromium/components/optimization_guide/core/page_entities_model_executor_impl.cc b/chromium/components/optimization_guide/core/page_entities_model_executor_impl.cc
new file mode 100644
index 00000000000..d3a34832f6b
--- /dev/null
+++ b/chromium/components/optimization_guide/core/page_entities_model_executor_impl.cc
@@ -0,0 +1,230 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/optimization_guide/core/page_entities_model_executor_impl.h"
+
+#include "base/metrics/histogram_functions.h"
+#include "base/threading/sequenced_task_runner_handle.h"
+#include "base/timer/elapsed_timer.h"
+#include "components/optimization_guide/core/entity_annotator_native_library.h"
+#include "components/optimization_guide/core/optimization_guide_features.h"
+#include "components/optimization_guide/core/optimization_guide_model_provider.h"
+#include "components/optimization_guide/proto/page_entities_model_metadata.pb.h"
+
+namespace optimization_guide {
+
+namespace {
+
+const char kPageEntitiesModelMetadataTypeUrl[] =
+ "type.googleapis.com/"
+ "google.internal.chrome.optimizationguide.v1.PageEntitiesModelMetadata";
+
+} // namespace
+
+EntityAnnotatorHolder::EntityAnnotatorHolder(
+ scoped_refptr<base::SequencedTaskRunner> background_task_runner,
+ scoped_refptr<base::SequencedTaskRunner> reply_task_runner)
+ : background_task_runner_(background_task_runner),
+ reply_task_runner_(reply_task_runner) {}
+
+EntityAnnotatorHolder::~EntityAnnotatorHolder() {
+ DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
+
+ if (features::ShouldResetPageEntitiesModelOnShutdown()) {
+ ResetEntityAnnotator();
+ }
+}
+
+void EntityAnnotatorHolder::
+ InitializeEntityAnnotatorNativeLibraryOnBackgroundThread(
+ base::OnceCallback<void(int32_t)> init_callback) {
+ DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
+
+ DCHECK(!entity_annotator_native_library_);
+ if (entity_annotator_native_library_) {
+ // We should only be initialized once but in case someone does something
+ // wrong in a non-debug build, we invoke the callback anyway.
+ reply_task_runner_->PostTask(
+ FROM_HERE,
+ base::BindOnce(
+ std::move(init_callback),
+ entity_annotator_native_library_->GetMaxSupportedFeatureFlag()));
+ return;
+ }
+
+ entity_annotator_native_library_ = EntityAnnotatorNativeLibrary::Create();
+ if (!entity_annotator_native_library_) {
+ reply_task_runner_->PostTask(FROM_HERE,
+ base::BindOnce(std::move(init_callback), -1));
+ return;
+ }
+
+ int32_t max_supported_feature_flag =
+ entity_annotator_native_library_->GetMaxSupportedFeatureFlag();
+ reply_task_runner_->PostTask(
+ FROM_HERE,
+ base::BindOnce(std::move(init_callback), max_supported_feature_flag));
+}
+
+void EntityAnnotatorHolder::ResetEntityAnnotator() {
+ DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
+
+ if (entity_annotator_) {
+ DCHECK(entity_annotator_native_library_);
+ entity_annotator_native_library_->DeleteEntityAnnotator(entity_annotator_);
+
+ entity_annotator_ = nullptr;
+ }
+}
+
+void EntityAnnotatorHolder::CreateAndSetEntityAnnotatorOnBackgroundThread(
+ const ModelInfo& model_info) {
+ DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
+
+ if (!entity_annotator_native_library_) {
+ return;
+ }
+
+ ResetEntityAnnotator();
+
+ entity_annotator_ =
+ entity_annotator_native_library_->CreateEntityAnnotator(model_info);
+}
+
+void EntityAnnotatorHolder::AnnotateEntitiesMetadataModelOnBackgroundThread(
+ const std::string& text,
+ PageEntitiesMetadataModelExecutedCallback callback) {
+ DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
+ base::ElapsedThreadTimer annotate_timer;
+
+ absl::optional<std::vector<ScoredEntityMetadata>> scored_md;
+ if (entity_annotator_) {
+ DCHECK(entity_annotator_native_library_);
+ base::TimeTicks start_time = base::TimeTicks::Now();
+ scored_md =
+ entity_annotator_native_library_->AnnotateText(entity_annotator_, text);
+ // The max of the below histograms is 1 hour because we want to understand
+ // tail behavior and catch long running model executions.
+ base::UmaHistogramLongTimes(
+ "OptimizationGuide.PageContentAnnotationsService.ModelExecutionLatency."
+ "PageEntities",
+ base::TimeTicks::Now() - start_time);
+
+ base::UmaHistogramLongTimes(
+ "OptimizationGuide.PageContentAnnotationsService."
+ "ModelThreadExecutionLatency.PageEntities",
+ annotate_timer.Elapsed());
+ }
+ reply_task_runner_->PostTask(FROM_HERE,
+ base::BindOnce(std::move(callback), scored_md));
+}
+
+void EntityAnnotatorHolder::GetMetadataForEntityIdOnBackgroundThread(
+ const std::string& entity_id,
+ PageEntitiesModelExecutor::PageEntitiesModelEntityMetadataRetrievedCallback
+ callback) {
+ DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
+
+ absl::optional<EntityMetadata> entity_metadata;
+ if (entity_annotator_) {
+ DCHECK(entity_annotator_native_library_);
+ entity_metadata =
+ entity_annotator_native_library_->GetEntityMetadataForEntityId(
+ entity_annotator_, entity_id);
+ }
+ reply_task_runner_->PostTask(
+ FROM_HERE,
+ base::BindOnce(std::move(callback), std::move(entity_metadata)));
+}
+
+base::WeakPtr<EntityAnnotatorHolder>
+EntityAnnotatorHolder::GetBackgroundWeakPtr() {
+ return background_weak_ptr_factory_.GetWeakPtr();
+}
+
+PageEntitiesModelExecutorImpl::PageEntitiesModelExecutorImpl(
+ OptimizationGuideModelProvider* optimization_guide_model_provider,
+ scoped_refptr<base::SequencedTaskRunner> background_task_runner)
+ : background_task_runner_(background_task_runner),
+ entity_annotator_holder_(std::make_unique<EntityAnnotatorHolder>(
+ background_task_runner_,
+ base::SequencedTaskRunnerHandle::Get())) {
+ background_task_runner_->PostTask(
+ FROM_HERE,
+ base::BindOnce(
+ &EntityAnnotatorHolder::
+ InitializeEntityAnnotatorNativeLibraryOnBackgroundThread,
+ entity_annotator_holder_->GetBackgroundWeakPtr(),
+ base::BindOnce(&PageEntitiesModelExecutorImpl::
+ OnEntityAnnotatorLibraryInitialized,
+ weak_ptr_factory_.GetWeakPtr(),
+ optimization_guide_model_provider)));
+}
+
+void PageEntitiesModelExecutorImpl::OnEntityAnnotatorLibraryInitialized(
+ OptimizationGuideModelProvider* optimization_guide_model_provider,
+ int32_t max_model_format_feature_flag) {
+ if (max_model_format_feature_flag <= 0) {
+ return;
+ }
+
+ proto::Any any_metadata;
+ any_metadata.set_type_url(kPageEntitiesModelMetadataTypeUrl);
+ proto::PageEntitiesModelMetadata model_metadata;
+ model_metadata.set_max_model_format_feature_flag(
+ max_model_format_feature_flag);
+ model_metadata.SerializeToString(any_metadata.mutable_value());
+ optimization_guide_model_provider->AddObserverForOptimizationTargetModel(
+ proto::OptimizationTarget::OPTIMIZATION_TARGET_PAGE_ENTITIES,
+ any_metadata, this);
+}
+
+PageEntitiesModelExecutorImpl::~PageEntitiesModelExecutorImpl() {
+ // |entity_annotator_holder_|'s WeakPtrs are used on the background thread,
+ // so that is also where the class must be destroyed.
+ background_task_runner_->DeleteSoon(FROM_HERE,
+ std::move(entity_annotator_holder_));
+}
+
+void PageEntitiesModelExecutorImpl::OnModelUpdated(
+ proto::OptimizationTarget optimization_target,
+ const ModelInfo& model_info) {
+ if (optimization_target != proto::OPTIMIZATION_TARGET_PAGE_ENTITIES)
+ return;
+
+ background_task_runner_->PostTask(
+ FROM_HERE,
+ base::BindOnce(
+ &EntityAnnotatorHolder::CreateAndSetEntityAnnotatorOnBackgroundThread,
+ entity_annotator_holder_->GetBackgroundWeakPtr(), model_info));
+}
+
+void PageEntitiesModelExecutorImpl::HumanReadableExecuteModelWithInput(
+ const std::string& text,
+ PageEntitiesMetadataModelExecutedCallback callback) {
+ if (text.empty()) {
+ std::move(callback).Run(absl::nullopt);
+ return;
+ }
+
+ background_task_runner_->PostTask(
+ FROM_HERE,
+ base::BindOnce(&EntityAnnotatorHolder::
+ AnnotateEntitiesMetadataModelOnBackgroundThread,
+ entity_annotator_holder_->GetBackgroundWeakPtr(), text,
+ std::move(callback)));
+}
+
+void PageEntitiesModelExecutorImpl::GetMetadataForEntityId(
+ const std::string& entity_id,
+ PageEntitiesModelEntityMetadataRetrievedCallback callback) {
+ background_task_runner_->PostTask(
+ FROM_HERE,
+ base::BindOnce(
+ &EntityAnnotatorHolder::GetMetadataForEntityIdOnBackgroundThread,
+ entity_annotator_holder_->GetBackgroundWeakPtr(), entity_id,
+ std::move(callback)));
+}
+
+} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/page_entities_model_executor_impl.h b/chromium/components/optimization_guide/core/page_entities_model_executor_impl.h
new file mode 100644
index 00000000000..d39944fc261
--- /dev/null
+++ b/chromium/components/optimization_guide/core/page_entities_model_executor_impl.h
@@ -0,0 +1,115 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_PAGE_ENTITIES_MODEL_EXECUTOR_IMPL_H_
+#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_PAGE_ENTITIES_MODEL_EXECUTOR_IMPL_H_
+
+#include "base/task/sequenced_task_runner.h"
+#include "base/task/task_traits.h"
+#include "base/task/thread_pool.h"
+#include "components/optimization_guide/core/entity_metadata.h"
+#include "components/optimization_guide/core/optimization_target_model_observer.h"
+#include "components/optimization_guide/core/page_entities_model_executor.h"
+
+namespace optimization_guide {
+
+class EntityAnnotatorNativeLibrary;
+class OptimizationGuideModelProvider;
+
+// An object used to hold an entity annotator on a background thread.
+class EntityAnnotatorHolder {
+ public:
+ EntityAnnotatorHolder(
+ scoped_refptr<base::SequencedTaskRunner> background_task_runner,
+ scoped_refptr<base::SequencedTaskRunner> reply_task_runner);
+ ~EntityAnnotatorHolder();
+
+ // Initializes the native library on a background thread. Will invoke
+ // |init_callback| on |reply_task_runner_| with the max version supported for
+ // the entity annotator on success. Otherwise, -1.
+ void InitializeEntityAnnotatorNativeLibraryOnBackgroundThread(
+ base::OnceCallback<void(int32_t)> init_callback);
+
+ // Creates an entity annotator on the background thread and sets it to
+ // |entity_annotator_|. Should be invoked on |background_task_runner_|.
+ void CreateAndSetEntityAnnotatorOnBackgroundThread(
+ const ModelInfo& model_info);
+
+ // Requests for |entity_annotator_| to execute its model for |text| and map
+ // the entities back to their metadata. Should be invoked on
+ // |background_task_runner_|.
+ using PageEntitiesMetadataModelExecutedCallback = base::OnceCallback<void(
+ const absl::optional<std::vector<ScoredEntityMetadata>>&)>;
+ void AnnotateEntitiesMetadataModelOnBackgroundThread(
+ const std::string& text,
+ PageEntitiesMetadataModelExecutedCallback callback);
+
+ // Returns entity metadata from |entity_annotator_| for |entity_id|.
+ // Should be invoked on |background_task_runner_|.
+ void GetMetadataForEntityIdOnBackgroundThread(
+ const std::string& entity_id,
+ PageEntitiesModelExecutor::
+ PageEntitiesModelEntityMetadataRetrievedCallback callback);
+
+ // Gets the weak ptr to |this| on the background thread.
+ base::WeakPtr<EntityAnnotatorHolder> GetBackgroundWeakPtr();
+
+ private:
+ void ResetEntityAnnotator();
+
+ scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
+ scoped_refptr<base::SequencedTaskRunner> reply_task_runner_;
+
+ std::unique_ptr<EntityAnnotatorNativeLibrary>
+ entity_annotator_native_library_;
+ void* entity_annotator_ = nullptr;
+
+ base::WeakPtrFactory<EntityAnnotatorHolder> background_weak_ptr_factory_{
+ this};
+};
+
+// Manages the loading and execution of the page entities model.
+class PageEntitiesModelExecutorImpl : public OptimizationTargetModelObserver,
+ public PageEntitiesModelExecutor {
+ public:
+ PageEntitiesModelExecutorImpl(
+ OptimizationGuideModelProvider* model_provider,
+ scoped_refptr<base::SequencedTaskRunner> background_task_runner =
+ base::ThreadPool::CreateSequencedTaskRunner(
+ {base::MayBlock(), base::TaskPriority::BEST_EFFORT}));
+ ~PageEntitiesModelExecutorImpl() override;
+ PageEntitiesModelExecutorImpl(const PageEntitiesModelExecutorImpl&) = delete;
+ PageEntitiesModelExecutorImpl& operator=(
+ const PageEntitiesModelExecutorImpl&) = delete;
+
+ // PageEntitiesModelExecutor:
+ void GetMetadataForEntityId(
+ const std::string& entity_id,
+ PageEntitiesModelEntityMetadataRetrievedCallback callback) override;
+ void HumanReadableExecuteModelWithInput(
+ const std::string& text,
+ PageEntitiesMetadataModelExecutedCallback callback) override;
+
+ // OptimizationTargetModelObserver:
+ void OnModelUpdated(proto::OptimizationTarget optimization_target,
+ const ModelInfo& model_info) override;
+
+ private:
+ // Invoked on the UI thread when entity annotator library has been
+ // initialized.
+ void OnEntityAnnotatorLibraryInitialized(
+ OptimizationGuideModelProvider* model_provider,
+ int32_t max_model_format_feature_flag);
+
+ scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
+
+ // The holder used to hold the annotator used to annotate entities.
+ std::unique_ptr<EntityAnnotatorHolder> entity_annotator_holder_;
+
+ base::WeakPtrFactory<PageEntitiesModelExecutorImpl> weak_ptr_factory_{this};
+};
+
+} // namespace optimization_guide
+
+#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_PAGE_ENTITIES_MODEL_EXECUTOR_IMPL_H_
diff --git a/chromium/components/optimization_guide/core/page_entities_model_executor_impl_unittest.cc b/chromium/components/optimization_guide/core/page_entities_model_executor_impl_unittest.cc
new file mode 100644
index 00000000000..e1954454829
--- /dev/null
+++ b/chromium/components/optimization_guide/core/page_entities_model_executor_impl_unittest.cc
@@ -0,0 +1,268 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/optimization_guide/core/page_entities_model_executor_impl.h"
+
+#include "base/observer_list.h"
+#include "base/path_service.h"
+#include "base/run_loop.h"
+#include "base/test/task_environment.h"
+#include "components/optimization_guide/core/model_util.h"
+#include "components/optimization_guide/core/optimization_guide_util.h"
+#include "components/optimization_guide/core/test_model_info_builder.h"
+#include "components/optimization_guide/core/test_optimization_guide_model_provider.h"
+#include "components/optimization_guide/proto/page_entities_model_metadata.pb.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace optimization_guide {
+namespace {
+
+using ::testing::ElementsAre;
+
+class ModelObserverTracker : public TestOptimizationGuideModelProvider {
+ public:
+ void AddObserverForOptimizationTargetModel(
+ proto::OptimizationTarget target,
+ const absl::optional<proto::Any>& model_metadata,
+ OptimizationTargetModelObserver* observer) override {
+ registered_model_metadata_.insert_or_assign(target, model_metadata);
+ registered_observers_.AddObserver(observer);
+ }
+
+ void RemoveObserverForOptimizationTargetModel(
+ proto::OptimizationTarget target,
+ OptimizationTargetModelObserver* observer) override {
+ registered_observers_.RemoveObserver(observer);
+ }
+
+ bool DidRegisterForTarget(
+ proto::OptimizationTarget target,
+ absl::optional<proto::Any>* out_model_metadata) const {
+ auto it = registered_model_metadata_.find(target);
+ if (it == registered_model_metadata_.end())
+ return false;
+ *out_model_metadata = registered_model_metadata_.at(target);
+ return true;
+ }
+
+ void PushModelInfoToObservers(const ModelInfo& model_info) {
+ for (auto& observer : registered_observers_) {
+ observer.OnModelUpdated(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES,
+ model_info);
+ }
+ }
+
+ private:
+ base::flat_map<proto::OptimizationTarget, absl::optional<proto::Any>>
+ registered_model_metadata_;
+ base::ObserverList<OptimizationTargetModelObserver> registered_observers_;
+};
+
+class PageEntitiesModelExecutorImplTest : public testing::Test {
+ public:
+ void SetUp() override {
+ model_observer_tracker_ = std::make_unique<ModelObserverTracker>();
+ model_executor_ = std::make_unique<PageEntitiesModelExecutorImpl>(
+ model_observer_tracker_.get());
+
+ // Wait for PageEntitiesModelExecutor to set everything up.
+ task_environment_.RunUntilIdle();
+ }
+
+ void TearDown() override {
+ model_executor_.reset();
+ model_observer_tracker_.reset();
+
+ // Wait for PageEntitiesModelExecutor to clean everything up.
+ task_environment_.RunUntilIdle();
+ }
+
+ absl::optional<std::vector<ScoredEntityMetadata>> ExecuteHumanReadableModel(
+ const std::string& text) {
+ absl::optional<std::vector<ScoredEntityMetadata>> entity_metadata;
+
+ base::RunLoop run_loop;
+ model_executor_->HumanReadableExecuteModelWithInput(
+ text, base::BindOnce(
+ [](base::RunLoop* run_loop,
+ absl::optional<std::vector<ScoredEntityMetadata>>*
+ out_entity_metadata,
+ const absl::optional<std::vector<ScoredEntityMetadata>>&
+ entity_metadata) {
+ *out_entity_metadata = entity_metadata;
+ run_loop->Quit();
+ },
+ &run_loop, &entity_metadata));
+ run_loop.Run();
+
+ // Sort the result by score to make validating the output easier.
+ if (entity_metadata) {
+ std::sort(
+ entity_metadata->begin(), entity_metadata->end(),
+ [](const ScoredEntityMetadata& a, const ScoredEntityMetadata& b) {
+ return a.score > b.score;
+ });
+ }
+ return entity_metadata;
+ }
+
+ absl::optional<EntityMetadata> GetMetadataForEntityId(
+ const std::string& entity_id) {
+ absl::optional<EntityMetadata> entity_metadata;
+
+ base::RunLoop run_loop;
+ model_executor_->GetMetadataForEntityId(
+ entity_id,
+ base::BindOnce(
+ [](base::RunLoop* run_loop,
+ absl::optional<EntityMetadata>* out_entity_metadata,
+ const absl::optional<EntityMetadata>& entity_metadata) {
+ *out_entity_metadata = entity_metadata;
+ run_loop->Quit();
+ },
+ &run_loop, &entity_metadata));
+ run_loop.Run();
+
+ return entity_metadata;
+ }
+
+ ModelObserverTracker* model_observer_tracker() const {
+ return model_observer_tracker_.get();
+ }
+
+ base::FilePath GetModelTestDataDir() {
+ base::FilePath source_root_dir;
+ base::PathService::Get(base::DIR_SOURCE_ROOT, &source_root_dir);
+ return source_root_dir.AppendASCII("components")
+ .AppendASCII("optimization_guide")
+ .AppendASCII("internal")
+ .AppendASCII("testdata");
+ }
+
+ void PushModelInfoToObservers(const ModelInfo& model_info) {
+ model_observer_tracker_->PushModelInfoToObservers(model_info);
+ task_environment_.RunUntilIdle();
+ }
+
+ private:
+ base::test::TaskEnvironment task_environment_;
+ std::unique_ptr<ModelObserverTracker> model_observer_tracker_;
+ std::unique_ptr<PageEntitiesModelExecutorImpl> model_executor_;
+};
+
+TEST_F(PageEntitiesModelExecutorImplTest, CreateNoMetadata) {
+ std::unique_ptr<ModelInfo> model_info = TestModelInfoBuilder().Build();
+ ASSERT_TRUE(model_info);
+ PushModelInfoToObservers(*model_info);
+
+ // We expect that there will be no model to evaluate even for this input that
+ // has output in the test model.
+ EXPECT_EQ(ExecuteHumanReadableModel("Taylor Swift singer"), absl::nullopt);
+}
+
+TEST_F(PageEntitiesModelExecutorImplTest, CreateMetadataWrongType) {
+ proto::Any any;
+ any.set_type_url(any.GetTypeName());
+ proto::FieldTrial garbage;
+ garbage.SerializeToString(any.mutable_value());
+
+ proto::PredictionModel model;
+ model.mutable_model()->set_download_url(
+ FilePathToString(GetModelTestDataDir().AppendASCII("model.tflite")));
+ model.mutable_model_info()->set_version(123);
+ *model.mutable_model_info()->mutable_model_metadata() = any;
+ std::unique_ptr<ModelInfo> model_info = ModelInfo::Create(model);
+ ASSERT_TRUE(model_info);
+ PushModelInfoToObservers(*model_info);
+
+ // We expect that there will be no model to evaluate even for this input that
+ // has output in the test model.
+ EXPECT_EQ(ExecuteHumanReadableModel("Taylor Swift singer"), absl::nullopt);
+}
+
+TEST_F(PageEntitiesModelExecutorImplTest, CreateNoSlices) {
+ proto::Any any;
+ proto::PageEntitiesModelMetadata metadata;
+ any.set_type_url(metadata.GetTypeName());
+ metadata.SerializeToString(any.mutable_value());
+
+ proto::PredictionModel model;
+ model.mutable_model()->set_download_url(
+ FilePathToString(GetModelTestDataDir().AppendASCII("model.tflite")));
+ model.mutable_model_info()->set_version(123);
+ *model.mutable_model_info()->mutable_model_metadata() = any;
+ std::unique_ptr<ModelInfo> model_info = ModelInfo::Create(model);
+ ASSERT_TRUE(model_info);
+ PushModelInfoToObservers(*model_info);
+
+ // We expect that there will be no model to evaluate even for this input that
+ // has output in the test model.
+ EXPECT_EQ(ExecuteHumanReadableModel("Taylor Swift singer"), absl::nullopt);
+}
+
+TEST_F(PageEntitiesModelExecutorImplTest, CreateMissingFiles) {
+ proto::Any any;
+ proto::PageEntitiesModelMetadata metadata;
+ metadata.add_slice("global");
+ any.set_type_url(metadata.GetTypeName());
+ metadata.SerializeToString(any.mutable_value());
+
+ base::FilePath dir_path = GetModelTestDataDir();
+ base::flat_set<std::string> expected_additional_files = {
+ FilePathToString(dir_path.AppendASCII("model_metadata.pb")),
+ FilePathToString(dir_path.AppendASCII("word_embeddings")),
+ FilePathToString(dir_path.AppendASCII("global-entities_names")),
+ FilePathToString(dir_path.AppendASCII("global-entities_metadata")),
+ FilePathToString(dir_path.AppendASCII("global-entities_names_filter")),
+ FilePathToString(dir_path.AppendASCII("global-entities_prefixes_filter")),
+ };
+ // Remove one file for each iteration and make sure it fails.
+ for (const auto& missing_file_name : expected_additional_files) {
+ // Make a copy of the expected files and remove the one file from the set.
+ base::flat_set<std::string> additional_files = expected_additional_files;
+ additional_files.erase(missing_file_name);
+
+ proto::PredictionModel model;
+ model.mutable_model()->set_download_url(
+ FilePathToString(dir_path.AppendASCII("model.tflite")));
+ model.mutable_model_info()->set_version(123);
+ *model.mutable_model_info()->mutable_model_metadata() = any;
+ for (const auto& additional_file : additional_files) {
+ model.mutable_model_info()->add_additional_files()->set_file_path(
+ additional_file);
+ }
+ std::unique_ptr<ModelInfo> model_info = ModelInfo::Create(model);
+ ASSERT_TRUE(model_info);
+ PushModelInfoToObservers(*model_info);
+
+ // We expect that there will be no model to evaluate even for this input
+ // that has output in the test model.
+ EXPECT_EQ(ExecuteHumanReadableModel("Taylor Swift singer"), absl::nullopt);
+ }
+}
+
+TEST_F(PageEntitiesModelExecutorImplTest, GetMetadataForEntityIdNoModel) {
+ EXPECT_EQ(GetMetadataForEntityId("/m/0dl567"), absl::nullopt);
+}
+
+TEST_F(PageEntitiesModelExecutorImplTest, ExecuteHumanReadableModelNoModel) {
+ EXPECT_EQ(ExecuteHumanReadableModel("Taylor Swift singer"), absl::nullopt);
+}
+
+TEST_F(PageEntitiesModelExecutorImplTest,
+ SetsUpModelCorrectlyBasedOnFeatureParams) {
+ absl::optional<proto::Any> registered_model_metadata;
+ EXPECT_TRUE(model_observer_tracker()->DidRegisterForTarget(
+ proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, &registered_model_metadata));
+ EXPECT_TRUE(registered_model_metadata.has_value());
+ absl::optional<proto::PageEntitiesModelMetadata>
+ page_entities_model_metadata =
+ ParsedAnyMetadata<proto::PageEntitiesModelMetadata>(
+ *registered_model_metadata);
+ EXPECT_TRUE(page_entities_model_metadata.has_value());
+}
+
+} // namespace
+} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/page_topics_model_executor.cc b/chromium/components/optimization_guide/core/page_topics_model_executor.cc
index b1c1707bc8e..dbc9a425889 100644
--- a/chromium/components/optimization_guide/core/page_topics_model_executor.cc
+++ b/chromium/components/optimization_guide/core/page_topics_model_executor.cc
@@ -17,7 +17,7 @@ namespace {
// The ID of the NONE category in the taxonomy. This node always exists.
// Semantically, the none category is attached to data for which we can say
// with certainty that no single label in the taxonomy is appropriate.
-const char kNoneCategoryId[] = "-2";
+const int32_t kNoneCategoryId = -2;
} // namespace
@@ -52,7 +52,7 @@ void PageTopicsModelExecutor::PostprocessCategoriesToBatchAnnotationResult(
const absl::optional<std::vector<tflite::task::core::Category>>& output) {
DCHECK_EQ(annotation_type, AnnotationType::kPageTopics);
- absl::optional<std::vector<WeightedString>> categories;
+ absl::optional<std::vector<WeightedIdentifier>> categories;
if (output) {
categories = ExtractCategoriesFromModelOutput(*output);
}
@@ -60,7 +60,7 @@ void PageTopicsModelExecutor::PostprocessCategoriesToBatchAnnotationResult(
BatchAnnotationResult::CreatePageTopicsResult(input, categories));
}
-absl::optional<std::vector<WeightedString>>
+absl::optional<std::vector<WeightedIdentifier>>
PageTopicsModelExecutor::ExtractCategoriesFromModelOutput(
const std::vector<tflite::task::core::Category>& model_output) const {
absl::optional<proto::PageTopicsModelMetadata> model_metadata =
@@ -79,7 +79,7 @@ PageTopicsModelExecutor::ExtractCategoriesFromModelOutput(
.category_name())
: absl::nullopt;
- std::vector<std::pair<std::string, float>> category_candidates;
+ std::vector<std::pair<int32_t, float>> category_candidates;
for (const auto& category : model_output) {
if (visibility_category_name &&
@@ -89,8 +89,8 @@ PageTopicsModelExecutor::ExtractCategoriesFromModelOutput(
// Assume everything else is for categories.
int category_id;
if (base::StringToInt(category.class_name, &category_id)) {
- category_candidates.emplace_back(std::make_pair(
- category.class_name, static_cast<float>(category.score)));
+ category_candidates.emplace_back(
+ std::make_pair(category_id, static_cast<float>(category.score)));
}
}
@@ -103,20 +103,19 @@ PageTopicsModelExecutor::ExtractCategoriesFromModelOutput(
model_metadata->output_postprocessing_params().category_params();
// Determine the categories with the highest weights.
- std::sort(category_candidates.begin(), category_candidates.end(),
- [](const std::pair<std::string, float>& a,
- const std::pair<std::string, float>& b) {
- return a.second > b.second;
- });
+ std::sort(
+ category_candidates.begin(), category_candidates.end(),
+ [](const std::pair<int32_t, float>& a,
+ const std::pair<int32_t, float>& b) { return a.second > b.second; });
size_t max_categories = static_cast<size_t>(category_params.max_categories());
float total_weight = 0.0;
float sum_positive_scores = 0.0;
absl::optional<std::pair<size_t, float>> none_idx_and_weight;
- std::vector<std::pair<std::string, float>> categories;
+ std::vector<std::pair<int32_t, float>> categories;
categories.reserve(max_categories);
for (size_t i = 0; i < category_candidates.size() && i < max_categories;
i++) {
- std::pair<std::string, float> candidate = category_candidates[i];
+ std::pair<int32_t, float> candidate = category_candidates[i];
categories.push_back(candidate);
total_weight += candidate.second;
@@ -132,7 +131,7 @@ PageTopicsModelExecutor::ExtractCategoriesFromModelOutput(
if (category_params.min_category_weight() > 0) {
categories.erase(
std::remove_if(categories.begin(), categories.end(),
- [&](const std::pair<std::string, float>& category) {
+ [&](const std::pair<int32_t, float>& category) {
return category.second <
category_params.min_category_weight();
}),
@@ -159,19 +158,19 @@ PageTopicsModelExecutor::ExtractCategoriesFromModelOutput(
categories.erase(
std::remove_if(
categories.begin(), categories.end(),
- [&](const std::pair<std::string, float>& category) {
+ [&](const std::pair<int32_t, float>& category) {
return (category.second / normalization_factor) <
category_params.min_normalized_weight_within_top_n();
}),
categories.end());
- std::vector<WeightedString> final_categories;
+ std::vector<WeightedIdentifier> final_categories;
final_categories.reserve(categories.size());
for (const auto& category : categories) {
// We expect the weight to be between 0 and 1.
DCHECK(category.second >= 0.0 && category.second <= 1.0);
final_categories.emplace_back(
- WeightedString(category.first, category.second));
+ WeightedIdentifier(category.first, category.second));
}
DCHECK_LE(final_categories.size(), max_categories);
diff --git a/chromium/components/optimization_guide/core/page_topics_model_executor.h b/chromium/components/optimization_guide/core/page_topics_model_executor.h
index fa396fafc4e..30e263872e5 100644
--- a/chromium/components/optimization_guide/core/page_topics_model_executor.h
+++ b/chromium/components/optimization_guide/core/page_topics_model_executor.h
@@ -42,7 +42,8 @@ class PageTopicsModelExecutor : public PageContentAnnotationJobExecutor,
// Extracts the scored categories from the output of the model.
// Public for testing.
- absl::optional<std::vector<WeightedString>> ExtractCategoriesFromModelOutput(
+ absl::optional<std::vector<WeightedIdentifier>>
+ ExtractCategoriesFromModelOutput(
const std::vector<tflite::task::core::Category>& model_output) const;
private:
diff --git a/chromium/components/optimization_guide/core/page_topics_model_executor_unittest.cc b/chromium/components/optimization_guide/core/page_topics_model_executor_unittest.cc
index 7fefca3bdc0..6ad49e17067 100644
--- a/chromium/components/optimization_guide/core/page_topics_model_executor_unittest.cc
+++ b/chromium/components/optimization_guide/core/page_topics_model_executor_unittest.cc
@@ -125,13 +125,13 @@ TEST_F(
{"0", 0.0001}, {"1", 0.1}, {"not an int", 0.9}, {"2", 0.2}, {"3", 0.3},
};
- absl::optional<std::vector<WeightedString>> categories =
+ absl::optional<std::vector<WeightedIdentifier>> categories =
model_executor()->ExtractCategoriesFromModelOutput(model_output);
ASSERT_TRUE(categories);
EXPECT_THAT(*categories,
- testing::UnorderedElementsAre(WeightedString("1", 0.1),
- WeightedString("2", 0.2),
- WeightedString("3", 0.3)));
+ testing::UnorderedElementsAre(WeightedIdentifier(1, 0.1),
+ WeightedIdentifier(2, 0.2),
+ WeightedIdentifier(3, 0.3)));
}
TEST_F(PageTopicsModelExecutorTest,
@@ -157,7 +157,7 @@ TEST_F(PageTopicsModelExecutorTest,
{"1", 0.2},
};
- absl::optional<std::vector<WeightedString>> categories =
+ absl::optional<std::vector<WeightedIdentifier>> categories =
model_executor()->ExtractCategoriesFromModelOutput(model_output);
EXPECT_FALSE(categories);
}
@@ -183,13 +183,13 @@ TEST_F(PageTopicsModelExecutorTest,
{"-2", 0.1}, {"0", 0.3}, {"1", 0.2}, {"2", 0.4}, {"3", 0.05},
};
- absl::optional<std::vector<WeightedString>> categories =
+ absl::optional<std::vector<WeightedIdentifier>> categories =
model_executor()->ExtractCategoriesFromModelOutput(model_output);
ASSERT_TRUE(categories);
EXPECT_THAT(*categories,
- testing::UnorderedElementsAre(WeightedString("0", 0.3),
- WeightedString("1", 0.2),
- WeightedString("2", 0.4)));
+ testing::UnorderedElementsAre(WeightedIdentifier(0, 0.3),
+ WeightedIdentifier(1, 0.2),
+ WeightedIdentifier(2, 0.4)));
}
TEST_F(PageTopicsModelExecutorTest,
@@ -216,13 +216,13 @@ TEST_F(PageTopicsModelExecutorTest,
{"3", 0.05},
};
- absl::optional<std::vector<WeightedString>> categories =
+ absl::optional<std::vector<WeightedIdentifier>> categories =
model_executor()->ExtractCategoriesFromModelOutput(model_output);
ASSERT_TRUE(categories);
EXPECT_THAT(*categories,
- testing::UnorderedElementsAre(WeightedString("0", 0.3),
- WeightedString("1", 0.25),
- WeightedString("2", 0.4)));
+ testing::UnorderedElementsAre(WeightedIdentifier(0, 0.3),
+ WeightedIdentifier(1, 0.25),
+ WeightedIdentifier(2, 0.4)));
}
TEST_F(PageTopicsModelExecutorTest,
@@ -260,10 +260,10 @@ TEST_F(PageTopicsModelExecutorTest,
&topics_result),
AnnotationType::kPageTopics, "input", model_output);
EXPECT_EQ(topics_result, BatchAnnotationResult::CreatePageTopicsResult(
- "input", std::vector<WeightedString>{
- WeightedString("0", 0.3),
- WeightedString("1", 0.25),
- WeightedString("2", 0.4),
+ "input", std::vector<WeightedIdentifier>{
+ WeightedIdentifier(0, 0.3),
+ WeightedIdentifier(1, 0.25),
+ WeightedIdentifier(2, 0.4),
}));
}
diff --git a/chromium/components/optimization_guide/core/prediction_model.cc b/chromium/components/optimization_guide/core/prediction_model.cc
deleted file mode 100644
index 1cf077bd71a..00000000000
--- a/chromium/components/optimization_guide/core/prediction_model.cc
+++ /dev/null
@@ -1,82 +0,0 @@
-// Copyright 2020 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "components/optimization_guide/core/prediction_model.h"
-
-#include <utility>
-
-#include "components/optimization_guide/core/decision_tree_prediction_model.h"
-
-namespace optimization_guide {
-
-// static
-std::unique_ptr<PredictionModel> PredictionModel::Create(
- const proto::PredictionModel& prediction_model) {
- // TODO(crbug/1009123): Add a histogram to record if the provided model is
- // constructed successfully or not.
- // TODO(crbug/1009123): Adding timing metrics around initialization due to
- // potential validation overhead.
- if (!prediction_model.has_model())
- return nullptr;
-
- if (!prediction_model.has_model_info())
- return nullptr;
-
- if (!prediction_model.model_info().has_version())
- return nullptr;
-
- // Enforce that only one ModelType is specified for the PredictionModel.
- if (prediction_model.model_info().supported_model_types_size() != 1) {
- return nullptr;
- }
-
- // Check that the client supports this type of model and is not an unknown
- // type.
- if (!proto::ModelType_IsValid(
- prediction_model.model_info().supported_model_types(0)) ||
- prediction_model.model_info().supported_model_types(0) ==
- proto::ModelType::MODEL_TYPE_UNKNOWN) {
- return nullptr;
- }
-
- std::unique_ptr<PredictionModel> model;
- // The Decision Tree model type is currently the only supported model type.
- if (prediction_model.model_info().supported_model_types(0) !=
- proto::ModelType::MODEL_TYPE_DECISION_TREE) {
- return nullptr;
- }
- model = std::make_unique<DecisionTreePredictionModel>(prediction_model);
-
- // Any constructed model must be validated for correctness according to its
- // model type before being returned.
- if (!model->ValidatePredictionModel())
- return nullptr;
-
- return model;
-}
-
-namespace {
-
-std::vector<std::string> ComputeModelFeatures(
- const proto::ModelInfo& model_info) {
- std::vector<std::string> features;
- features.reserve(model_info.supported_host_model_features_size());
- // Insert all the host model features for the owned |model_|.
- for (const auto& host_model_feature :
- model_info.supported_host_model_features()) {
- features.push_back(host_model_feature);
- }
- return features;
-}
-
-} // namespace
-
-PredictionModel::PredictionModel(const proto::PredictionModel& prediction_model)
- : model_(prediction_model.model()),
- model_features_(ComputeModelFeatures(prediction_model.model_info())),
- version_(prediction_model.model_info().version()) {}
-
-PredictionModel::~PredictionModel() = default;
-
-} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/prediction_model.h b/chromium/components/optimization_guide/core/prediction_model.h
deleted file mode 100644
index c4c3c024d0b..00000000000
--- a/chromium/components/optimization_guide/core/prediction_model.h
+++ /dev/null
@@ -1,70 +0,0 @@
-// Copyright 2020 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_PREDICTION_MODEL_H_
-#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_PREDICTION_MODEL_H_
-
-#include <stdint.h>
-#include <memory>
-#include <string>
-
-#include "base/containers/flat_map.h"
-#include "base/containers/flat_set.h"
-#include "components/optimization_guide/core/optimization_guide_enums.h"
-#include "components/optimization_guide/proto/models.pb.h"
-
-namespace optimization_guide {
-
-// A PredictionModel supported by the optimization guide that makes an
-// OptimizationTargetDecision by evaluating a prediction model.
-class PredictionModel {
- public:
- PredictionModel(const PredictionModel&) = delete;
- PredictionModel& operator=(const PredictionModel&) = delete;
-
- virtual ~PredictionModel();
-
- // Creates an Prediction model of the correct ModelType specified in
- // |prediction_model|. The validation overhead of this factory can be high and
- // should should be called in the background.
- static std::unique_ptr<PredictionModel> Create(
- const proto::PredictionModel& prediction_model);
-
- // Returns the OptimizationTargetDecision by evaluating the |model_|
- // using the provided |model_features|. |prediction_score| will be populated
- // with the score output by the model.
- virtual OptimizationTargetDecision Predict(
- const base::flat_map<std::string, float>& model_features,
- double* prediction_score) = 0;
-
- // Provide the version of the |model_| by |this|.
- int64_t GetVersion() const { return version_; }
-
- // Provide the model features required for evaluation of the |model_| by
- // |this|.
- const base::flat_set<std::string>& GetModelFeatures() const {
- return model_features_;
- }
-
- protected:
- explicit PredictionModel(const proto::PredictionModel& prediction_model);
-
- // The in-memory model used for prediction.
- const proto::Model model_;
-
- private:
- // Determines if the |model_| is complete and can be successfully evaluated by
- // |this|.
- virtual bool ValidatePredictionModel() const = 0;
-
- // The set of features required by the |model_| to be evaluated.
- const base::flat_set<std::string> model_features_;
-
- // The version of the |model_|.
- const int64_t version_;
-};
-
-} // namespace optimization_guide
-
-#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_PREDICTION_MODEL_H_
diff --git a/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.cc b/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.cc
index 79282f12248..1aba12cfa54 100644
--- a/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.cc
+++ b/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.cc
@@ -22,7 +22,6 @@
#include "net/http/http_response_headers.h"
#include "net/http/http_status_code.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
-#include "services/network/public/cpp/network_connection_tracker.h"
#include "services/network/public/cpp/resource_request.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "services/network/public/cpp/simple_url_loader.h"
@@ -32,16 +31,14 @@ namespace optimization_guide {
PredictionModelFetcherImpl::PredictionModelFetcherImpl(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
- const GURL& optimization_guide_service_get_models_url,
- network::NetworkConnectionTracker* network_connection_tracker)
+ const GURL& optimization_guide_service_get_models_url)
: optimization_guide_service_get_models_url_(
net::AppendOrReplaceQueryParameter(
optimization_guide_service_get_models_url,
"key",
optimization_guide::features::
GetOptimizationGuideServiceAPIKey())),
- url_loader_factory_(url_loader_factory),
- network_connection_tracker_(network_connection_tracker) {
+ url_loader_factory_(url_loader_factory) {
CHECK(optimization_guide_service_get_models_url_.SchemeIs(url::kHttpsScheme));
}
@@ -55,11 +52,6 @@ bool PredictionModelFetcherImpl::FetchOptimizationGuideServiceModels(
ModelsFetchedCallback models_fetched_callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- if (network_connection_tracker_->IsOffline()) {
- std::move(models_fetched_callback).Run(absl::nullopt);
- return false;
- }
-
if (url_loader_)
return false;
diff --git a/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.h b/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.h
index 2d2557e8c15..969af65685a 100644
--- a/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.h
+++ b/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.h
@@ -19,7 +19,6 @@
#include "url/gurl.h"
namespace network {
-class NetworkConnectionTracker;
class SharedURLLoaderFactory;
class SimpleURLLoader;
} // namespace network
@@ -34,8 +33,7 @@ class PredictionModelFetcherImpl : public PredictionModelFetcher {
public:
PredictionModelFetcherImpl(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
- const GURL& optimization_guide_service_get_models_url,
- network::NetworkConnectionTracker* network_connection_tracker);
+ const GURL& optimization_guide_service_get_models_url);
PredictionModelFetcherImpl(const PredictionModelFetcherImpl&) = delete;
PredictionModelFetcherImpl& operator=(const PredictionModelFetcherImpl&) =
@@ -82,10 +80,6 @@ class PredictionModelFetcherImpl : public PredictionModelFetcher {
// Used for creating a |url_loader_| when needed for request hints.
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
- // Listens to changes around the network connection. Not owned. Guaranteed to
- // outlive |this|.
- raw_ptr<network::NetworkConnectionTracker> network_connection_tracker_;
-
SEQUENCE_CHECKER(sequence_checker_);
};
diff --git a/chromium/components/optimization_guide/core/prediction_model_fetcher_unittest.cc b/chromium/components/optimization_guide/core/prediction_model_fetcher_unittest.cc
index 1b59d314c01..a4ca4b85255 100644
--- a/chromium/components/optimization_guide/core/prediction_model_fetcher_unittest.cc
+++ b/chromium/components/optimization_guide/core/prediction_model_fetcher_unittest.cc
@@ -21,7 +21,6 @@
#include "net/base/url_util.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
-#include "services/network/test/test_network_connection_tracker.h"
#include "services/network/test/test_url_loader_factory.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
@@ -37,11 +36,9 @@ class PredictionModelFetcherTest : public testing::Test {
: task_environment_(base::test::TaskEnvironment::MainThreadType::UI),
shared_url_loader_factory_(
base::MakeRefCounted<network::WeakWrapperSharedURLLoaderFactory>(
- &test_url_loader_factory_)),
- network_tracker_(network::TestNetworkConnectionTracker::GetInstance()) {
+ &test_url_loader_factory_)) {
prediction_model_fetcher_ = std::make_unique<PredictionModelFetcherImpl>(
- shared_url_loader_factory_, GURL(optimization_guide_service_url),
- network_tracker_);
+ shared_url_loader_factory_, GURL(optimization_guide_service_url));
}
PredictionModelFetcherTest(const PredictionModelFetcherTest&) = delete;
@@ -58,16 +55,6 @@ class PredictionModelFetcherTest : public testing::Test {
bool models_fetched() { return models_fetched_; }
- void SetConnectionOffline() {
- network_tracker_->SetConnectionType(
- network::mojom::ConnectionType::CONNECTION_NONE);
- }
-
- void SetConnectionOnline() {
- network_tracker_->SetConnectionType(
- network::mojom::ConnectionType::CONNECTION_4G);
- }
-
protected:
bool FetchModels(const std::vector<proto::ModelInfo> models_request_info,
const std::vector<proto::FieldTrial>& active_field_trials,
@@ -116,7 +103,6 @@ class PredictionModelFetcherTest : public testing::Test {
scoped_refptr<network::SharedURLLoaderFactory> shared_url_loader_factory_;
network::TestURLLoaderFactory test_url_loader_factory_;
- raw_ptr<network::TestNetworkConnectionTracker> network_tracker_;
};
TEST_F(PredictionModelFetcherTest, FetchOptimizationGuideServiceModels) {
@@ -173,26 +159,6 @@ TEST_F(PredictionModelFetcherTest, FetchReturnBadResponse) {
EXPECT_FALSE(models_fetched());
}
-TEST_F(PredictionModelFetcherTest, FetchAttemptWhenNetworkOffline) {
- SetConnectionOffline();
- std::string response_content;
- proto::ModelInfo model_info;
- model_info.set_optimization_target(
- proto::OptimizationTarget::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
- EXPECT_FALSE(FetchModels({model_info}, /*active_field_trials=*/{},
- proto::RequestContext::CONTEXT_BATCH_UPDATE_MODELS,
- "en-US"));
- EXPECT_FALSE(models_fetched());
-
- SetConnectionOnline();
- EXPECT_TRUE(FetchModels({model_info}, /*active_field_trials=*/{},
- proto::RequestContext::CONTEXT_BATCH_UPDATE_MODELS,
- "en-US"));
- VerifyHasPendingFetchRequests();
- EXPECT_TRUE(SimulateResponse(response_content, net::HTTP_OK));
- EXPECT_TRUE(models_fetched());
-}
-
TEST_F(PredictionModelFetcherTest, EmptyModelInfo) {
base::HistogramTester histogram_tester;
std::string response_content;
diff --git a/chromium/components/optimization_guide/core/prediction_model_unittest.cc b/chromium/components/optimization_guide/core/prediction_model_unittest.cc
deleted file mode 100644
index c928abdd5cf..00000000000
--- a/chromium/components/optimization_guide/core/prediction_model_unittest.cc
+++ /dev/null
@@ -1,134 +0,0 @@
-// Copyright 2020 The Chromium Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#include "components/optimization_guide/core/prediction_model.h"
-
-#include <utility>
-
-#include "components/optimization_guide/proto/models.pb.h"
-#include "testing/gtest/include/gtest/gtest.h"
-
-namespace optimization_guide {
-
-TEST(PredictionModelTest, ValidPredictionModel) {
- proto::PredictionModel prediction_model;
- prediction_model.mutable_model()->mutable_threshold()->set_value(5.0);
-
- proto::DecisionTree decision_tree_model = proto::DecisionTree();
- decision_tree_model.set_weight(2.0);
-
- proto::TreeNode* tree_node = decision_tree_model.add_nodes();
- tree_node->mutable_node_id()->set_value(0);
- tree_node->mutable_binary_node()->mutable_left_child_id()->set_value(1);
- tree_node->mutable_binary_node()->mutable_right_child_id()->set_value(2);
- tree_node->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->mutable_feature_id()
- ->mutable_id()
- ->set_value("agg1");
- tree_node->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->set_type(proto::InequalityTest::LESS_OR_EQUAL);
- tree_node->mutable_binary_node()
- ->mutable_inequality_left_child_test()
- ->mutable_threshold()
- ->set_float_value(1.0);
-
- tree_node = decision_tree_model.add_nodes();
- tree_node->mutable_node_id()->set_value(1);
- tree_node->mutable_leaf()->mutable_vector()->add_value()->set_double_value(
- 2.);
-
- tree_node = decision_tree_model.add_nodes();
- tree_node->mutable_node_id()->set_value(2);
- tree_node->mutable_leaf()->mutable_vector()->add_value()->set_double_value(
- 4.);
-
- *prediction_model.mutable_model()->mutable_decision_tree() =
- decision_tree_model;
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_host_model_features("agg1");
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(prediction_model);
-
- EXPECT_EQ(1, model->GetVersion());
- EXPECT_EQ(1u, model->GetModelFeatures().size());
- EXPECT_TRUE(model->GetModelFeatures().count("agg1"));
-}
-
-TEST(PredictionModelTest, NoModel) {
- proto::PredictionModel prediction_model;
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(prediction_model);
- EXPECT_FALSE(model);
-}
-
-TEST(PredictionModelTest, NoModelVersion) {
- proto::PredictionModel prediction_model;
-
- proto::DecisionTree* decision_tree_model =
- prediction_model.mutable_model()->mutable_decision_tree();
- decision_tree_model->set_weight(2.0);
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(prediction_model);
- EXPECT_FALSE(model);
-}
-
-TEST(PredictionModelTest, NoModelType) {
- proto::PredictionModel prediction_model;
-
- proto::DecisionTree* decision_tree_model =
- prediction_model.mutable_model()->mutable_decision_tree();
- decision_tree_model->set_weight(2.0);
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(std::move(prediction_model));
- EXPECT_FALSE(model);
-}
-
-TEST(PredictionModelTest, UnknownModelType) {
- proto::PredictionModel prediction_model;
-
- proto::DecisionTree* decision_tree_model =
- prediction_model.mutable_model()->mutable_decision_tree();
- decision_tree_model->set_weight(2.0);
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(proto::ModelType::MODEL_TYPE_UNKNOWN);
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(prediction_model);
- EXPECT_FALSE(model);
-}
-
-TEST(PredictionModelTest, MultipleModelTypes) {
- proto::PredictionModel prediction_model;
-
- proto::DecisionTree* decision_tree_model =
- prediction_model.mutable_model()->mutable_decision_tree();
- decision_tree_model->set_weight(2.0);
-
- proto::ModelInfo* model_info = prediction_model.mutable_model_info();
- model_info->set_version(1);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
- model_info->add_supported_model_types(proto::ModelType::MODEL_TYPE_UNKNOWN);
-
- std::unique_ptr<PredictionModel> model =
- PredictionModel::Create(prediction_model);
- EXPECT_FALSE(model);
-}
-
-} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/store_update_data.cc b/chromium/components/optimization_guide/core/store_update_data.cc
index 0afd12d3e80..d5c636ef16b 100644
--- a/chromium/components/optimization_guide/core/store_update_data.cc
+++ b/chromium/components/optimization_guide/core/store_update_data.cc
@@ -38,40 +38,6 @@ StoreUpdateData::CreatePredictionModelStoreUpdateData(base::Time expiry_time) {
return base::WrapUnique<StoreUpdateData>(new StoreUpdateData(expiry_time));
}
-// static
-std::unique_ptr<StoreUpdateData>
-StoreUpdateData::CreateHostModelFeaturesStoreUpdateData(
- base::Time host_model_features_update_time,
- base::Time expiry_time) {
- std::unique_ptr<StoreUpdateData> host_model_features_update_data(
- new StoreUpdateData(host_model_features_update_time, expiry_time));
- return host_model_features_update_data;
-}
-
-StoreUpdateData::StoreUpdateData(base::Time host_model_features_update_time,
- base::Time expiry_time)
- : update_time_(host_model_features_update_time),
- expiry_time_(expiry_time),
- entries_to_save_(std::make_unique<EntryVector>()) {
- entry_key_prefix_ =
- OptimizationGuideStore::GetHostModelFeaturesEntryKeyPrefix();
- proto::StoreEntry metadata_host_model_features_entry;
- metadata_host_model_features_entry.set_entry_type(
- static_cast<proto::StoreEntryType>(
- OptimizationGuideStore::StoreEntryType::kMetadata));
- metadata_host_model_features_entry.set_update_time_secs(
- host_model_features_update_time.ToDeltaSinceWindowsEpoch().InSeconds());
- entries_to_save_->emplace_back(
- OptimizationGuideStore::GetMetadataTypeEntryKey(
- OptimizationGuideStore::MetadataType::kHostModelFeatures),
- std::move(metadata_host_model_features_entry));
-
- // |this| may be modified on another thread after construction but all
- // future modifications, from that call forward, must be made on the same
- // thread.
- DETACH_FROM_SEQUENCE(sequence_checker_);
-}
-
StoreUpdateData::StoreUpdateData(base::Time expiry_time)
: expiry_time_(expiry_time),
entries_to_save_(std::make_unique<EntryVector>()) {
@@ -161,28 +127,6 @@ void StoreUpdateData::MoveHintIntoUpdateData(proto::Hint&& hint) {
std::move(entry_proto));
}
-void StoreUpdateData::CopyHostModelFeaturesIntoUpdateData(
- const proto::HostModelFeatures& host_model_features) {
- // All future modifications must be made by the same thread. Note, |this| may
- // have been constructed on another thread.
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- DCHECK(!entry_key_prefix_.empty());
- DCHECK(expiry_time_);
-
- // To avoid any unnecessary copying, the host model feature data is moved into
- // proto::StoreEntry.
- OptimizationGuideStore::EntryKey host_model_features_entry_key =
- entry_key_prefix_ + host_model_features.host();
- proto::StoreEntry entry_proto;
- entry_proto.set_entry_type(static_cast<proto::StoreEntryType>(
- OptimizationGuideStore::StoreEntryType::kHostModelFeatures));
- entry_proto.set_expiry_time_secs(
- expiry_time_->ToDeltaSinceWindowsEpoch().InSeconds());
- entry_proto.mutable_host_model_features()->CopyFrom(host_model_features);
- entries_to_save_->emplace_back(std::move(host_model_features_entry_key),
- std::move(entry_proto));
-}
-
void StoreUpdateData::CopyPredictionModelIntoUpdateData(
const proto::PredictionModel& prediction_model) {
// All future modifications must be made by the same thread. Note, |this| may
@@ -200,8 +144,19 @@ void StoreUpdateData::CopyPredictionModelIntoUpdateData(
proto::StoreEntry entry_proto;
entry_proto.set_entry_type(static_cast<proto::StoreEntryType>(
OptimizationGuideStore::StoreEntryType::kPredictionModel));
+
+ base::TimeDelta expiry_duration;
+ if (prediction_model.model_info().has_valid_duration()) {
+ expiry_duration =
+ base::Seconds(prediction_model.model_info().valid_duration().seconds());
+ } else {
+ expiry_duration = features::StoredFetchedHintsFreshnessDuration();
+ }
+ expiry_time_ = base::Time::Now() + expiry_duration;
entry_proto.set_expiry_time_secs(
- expiry_time_->ToDeltaSinceWindowsEpoch().InSeconds());
+ expiry_time_.value().ToDeltaSinceWindowsEpoch().InSeconds());
+ entry_proto.set_keep_beyond_valid_duration(
+ prediction_model.model_info().keep_beyond_valid_duration());
entry_proto.mutable_prediction_model()->CopyFrom(prediction_model);
entries_to_save_->emplace_back(std::move(prediction_model_entry_key),
std::move(entry_proto));
diff --git a/chromium/components/optimization_guide/core/store_update_data.h b/chromium/components/optimization_guide/core/store_update_data.h
index 68a098a8ec3..05dbc9a872e 100644
--- a/chromium/components/optimization_guide/core/store_update_data.h
+++ b/chromium/components/optimization_guide/core/store_update_data.h
@@ -16,7 +16,6 @@
namespace optimization_guide {
namespace proto {
class Hint;
-class HostModelFeatures;
class PredictionModel;
class StoreEntry;
} // namespace proto
@@ -24,8 +23,7 @@ class StoreEntry;
using EntryVector =
leveldb_proto::ProtoDatabase<proto::StoreEntry>::KeyEntryVector;
-// Holds hint, prediction model, or host model features data for updating the
-// OptimizationGuideStore.
+// Holds hint or prediction model data for updating the OptimizationGuideStore.
class StoreUpdateData {
public:
StoreUpdateData(const StoreUpdateData&) = delete;
@@ -45,12 +43,6 @@ class StoreUpdateData {
static std::unique_ptr<StoreUpdateData> CreatePredictionModelStoreUpdateData(
base::Time expiry_time);
- // Creates an update data object for a host model features update.
- static std::unique_ptr<StoreUpdateData>
- CreateHostModelFeaturesStoreUpdateData(
- base::Time host_model_features_update_time,
- base::Time expiry_time);
-
// Returns the component version of a component hint update.
const absl::optional<base::Version> component_version() const {
return component_version_;
@@ -66,10 +58,6 @@ class StoreUpdateData {
// called, |hint| is no longer valid.
void MoveHintIntoUpdateData(proto::Hint&& hint);
- // Copies |host_model_features| into this update data.
- void CopyHostModelFeaturesIntoUpdateData(
- const proto::HostModelFeatures& host_model_features);
-
// Copies |prediction_model| into this update data.
void CopyPredictionModelIntoUpdateData(
const proto::PredictionModel& prediction_model);
@@ -81,8 +69,6 @@ class StoreUpdateData {
StoreUpdateData(absl::optional<base::Version> component_version,
absl::optional<base::Time> fetch_update_time,
absl::optional<base::Time> expiry_time);
- StoreUpdateData(base::Time host_model_features_update_time,
- base::Time expiry_time);
explicit StoreUpdateData(base::Time expiry_time);
// The component version of the update data for a component update.
diff --git a/chromium/components/optimization_guide/core/store_update_data_unittest.cc b/chromium/components/optimization_guide/core/store_update_data_unittest.cc
index 785033692bd..e36087ea00c 100644
--- a/chromium/components/optimization_guide/core/store_update_data_unittest.cc
+++ b/chromium/components/optimization_guide/core/store_update_data_unittest.cc
@@ -120,11 +120,13 @@ TEST(StoreUpdateDataTest, BuildPredictionModelUpdateData) {
model_info->set_version(1);
model_info->set_optimization_target(
proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
- model_info->add_supported_model_types(
- proto::ModelType::MODEL_TYPE_DECISION_TREE);
+ model_info->add_supported_model_engine_versions(
+ proto::ModelEngineVersion::MODEL_ENGINE_VERSION_DECISION_TREE);
+ model_info->set_keep_beyond_valid_duration(false);
- base::Time expected_expiry_time =
- base::Time::Now() + features::StoredModelsInactiveDuration();
+ model_info->mutable_valid_duration()->set_seconds(3);
+
+ base::Time expected_expiry_time = base::Time::Now() + base::Seconds(3);
std::unique_ptr<StoreUpdateData> prediction_model_update =
StoreUpdateData::CreatePredictionModelStoreUpdateData(
expected_expiry_time);
@@ -143,39 +145,14 @@ TEST(StoreUpdateDataTest, BuildPredictionModelUpdateData) {
found_prediction_model_entry = true;
EXPECT_EQ(expected_expiry_time.ToDeltaSinceWindowsEpoch().InSeconds(),
store_entry.expiry_time_secs());
+ EXPECT_EQ(store_entry.keep_beyond_valid_duration(),
+ model_info->keep_beyond_valid_duration());
break;
}
}
EXPECT_TRUE(found_prediction_model_entry);
}
-TEST(StoreUpdateDataTest, BuildHostModelFeaturesUpdateData) {
- // Verify creating a Prediction Model update data.
- base::Time host_model_features_update_time = base::Time::Now();
-
- proto::HostModelFeatures host_model_features;
- host_model_features.set_host("foo.com");
- proto::ModelFeature* model_feature = host_model_features.add_model_features();
- model_feature->set_feature_name("host_feat1");
- model_feature->set_double_value(2.0);
-
- std::unique_ptr<StoreUpdateData> host_model_features_update =
- StoreUpdateData::CreateHostModelFeaturesStoreUpdateData(
- host_model_features_update_time,
- host_model_features_update_time +
- optimization_guide::features::
- StoredHostModelFeaturesFreshnessDuration());
- host_model_features_update->CopyHostModelFeaturesIntoUpdateData(
- std::move(host_model_features));
- EXPECT_FALSE(host_model_features_update->component_version().has_value());
- EXPECT_TRUE(host_model_features_update->update_time().has_value());
- EXPECT_EQ(host_model_features_update_time,
- *host_model_features_update->update_time());
- // Verify there are 2 store entries, 1 for the metadata entry and 1 for the
- // added host model features entry.
- EXPECT_EQ(2ul, host_model_features_update->TakeUpdateEntries()->size());
-}
-
} // namespace
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/test_model_executor.cc b/chromium/components/optimization_guide/core/test_model_executor.cc
index dbd8c4072b5..07c75aafcd7 100644
--- a/chromium/components/optimization_guide/core/test_model_executor.cc
+++ b/chromium/components/optimization_guide/core/test_model_executor.cc
@@ -7,13 +7,13 @@
namespace optimization_guide {
void TestModelExecutor::SendForExecution(
- ExecutionCallback ui_callback_on_complete,
+ ExecutionCallback callback_on_complete,
base::TimeTicks start_time,
const std::vector<float>& args) {
std::vector<float> results = std::vector<float>();
for (auto& arg : args)
results.push_back(arg);
- std::move(ui_callback_on_complete).Run(std::move(results));
+ std::move(callback_on_complete).Run(std::move(results));
}
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/test_model_executor.h b/chromium/components/optimization_guide/core/test_model_executor.h
index 42dd3cfcff1..987942810e8 100644
--- a/chromium/components/optimization_guide/core/test_model_executor.h
+++ b/chromium/components/optimization_guide/core/test_model_executor.h
@@ -6,7 +6,6 @@
#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_TEST_MODEL_EXECUTOR_H_
#include "components/optimization_guide/core/model_executor.h"
-#include "third_party/abseil-cpp/absl/status/status.h"
namespace optimization_guide {
@@ -16,7 +15,7 @@ class TestModelExecutor
TestModelExecutor() = default;
~TestModelExecutor() override = default;
- void InitializeAndMoveToBackgroundThread(
+ void InitializeAndMoveToExecutionThread(
proto::OptimizationTarget,
scoped_refptr<base::SequencedTaskRunner>,
scoped_refptr<base::SequencedTaskRunner>) override {}
@@ -29,7 +28,7 @@ class TestModelExecutor
using ExecutionCallback =
base::OnceCallback<void(const absl::optional<std::vector<float>>&)>;
- void SendForExecution(ExecutionCallback ui_callback_on_complete,
+ void SendForExecution(ExecutionCallback callback_on_complete,
base::TimeTicks start_time,
const std::vector<float>& args) override;
};
diff --git a/chromium/components/optimization_guide/core/test_model_info_builder.cc b/chromium/components/optimization_guide/core/test_model_info_builder.cc
index 9efa9d4892d..517462239da 100644
--- a/chromium/components/optimization_guide/core/test_model_info_builder.cc
+++ b/chromium/components/optimization_guide/core/test_model_info_builder.cc
@@ -4,8 +4,8 @@
#include "components/optimization_guide/core/test_model_info_builder.h"
+#include "components/optimization_guide/core/model_util.h"
#include "components/optimization_guide/core/optimization_guide_test_util.h"
-#include "components/optimization_guide/core/optimization_guide_util.h"
namespace optimization_guide {
diff --git a/chromium/components/optimization_guide/core/test_tflite_model_executor.cc b/chromium/components/optimization_guide/core/test_tflite_model_executor.cc
index 9b7ea803949..a310697de19 100644
--- a/chromium/components/optimization_guide/core/test_tflite_model_executor.cc
+++ b/chromium/components/optimization_guide/core/test_tflite_model_executor.cc
@@ -8,17 +8,19 @@
namespace optimization_guide {
-absl::Status TestTFLiteModelExecutor::Preprocess(
+bool TestTFLiteModelExecutor::Preprocess(
const std::vector<TfLiteTensor*>& input_tensors,
const std::vector<float>& input) {
- tflite::task::core::PopulateTensor<float>(input, input_tensors[0]);
- return absl::OkStatus();
+ return tflite::task::core::PopulateTensor<float>(input, input_tensors[0])
+ .ok();
}
-std::vector<float> TestTFLiteModelExecutor::Postprocess(
+absl::optional<std::vector<float>> TestTFLiteModelExecutor::Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors) {
std::vector<float> data;
- tflite::task::core::PopulateVector<float>(output_tensors[0], &data);
+ absl::Status status =
+ tflite::task::core::PopulateVector<float>(output_tensors[0], &data);
+ DCHECK(status.ok());
return data;
}
diff --git a/chromium/components/optimization_guide/core/test_tflite_model_executor.h b/chromium/components/optimization_guide/core/test_tflite_model_executor.h
index ac33fd33603..39c99e4582a 100644
--- a/chromium/components/optimization_guide/core/test_tflite_model_executor.h
+++ b/chromium/components/optimization_guide/core/test_tflite_model_executor.h
@@ -16,10 +16,10 @@ class TestTFLiteModelExecutor
~TestTFLiteModelExecutor() override = default;
protected:
- absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
- const std::vector<float>& input) override;
+ bool Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
+ const std::vector<float>& input) override;
- std::vector<float> Postprocess(
+ absl::optional<std::vector<float>> Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors) override;
};
diff --git a/chromium/components/optimization_guide/core/tflite_model_executor.h b/chromium/components/optimization_guide/core/tflite_model_executor.h
index fe615dc6ee9..7b733101746 100644
--- a/chromium/components/optimization_guide/core/tflite_model_executor.h
+++ b/chromium/components/optimization_guide/core/tflite_model_executor.h
@@ -17,9 +17,9 @@
#include "base/time/time.h"
#include "base/trace_event/trace_event.h"
#include "components/optimization_guide/core/execution_status.h"
+#include "components/optimization_guide/core/model_enums.h"
#include "components/optimization_guide/core/model_executor.h"
-#include "components/optimization_guide/core/optimization_guide_enums.h"
-#include "components/optimization_guide/core/optimization_guide_util.h"
+#include "components/optimization_guide/core/model_util.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
#include "third_party/tflite/src/tensorflow/lite/c/common.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h"
@@ -34,8 +34,7 @@ class ScopedExecutionStatusResultRecorder {
public:
explicit ScopedExecutionStatusResultRecorder(
proto::OptimizationTarget optimization_target)
- : optimization_target_(optimization_target),
- start_time_(base::TimeTicks::Now()) {}
+ : optimization_target_(optimization_target) {}
~ScopedExecutionStatusResultRecorder() {
base::UmaHistogramEnumeration(
@@ -43,12 +42,6 @@ class ScopedExecutionStatusResultRecorder {
optimization_guide::GetStringNameForOptimizationTarget(
optimization_target_),
status_);
-
- base::UmaHistogramTimes(
- "OptimizationGuide.ModelExecutor.ModelLoadingDuration." +
- optimization_guide::GetStringNameForOptimizationTarget(
- optimization_target_),
- base::TimeTicks::Now() - start_time_);
}
ExecutionStatus* mutable_status() { return &status_; }
@@ -61,9 +54,6 @@ class ScopedExecutionStatusResultRecorder {
// The OptimizationTarget of the model being executed.
const proto::OptimizationTarget optimization_target_;
- // The time at which this instance was constructed.
- const base::TimeTicks start_time_;
-
ExecutionStatus status_ = ExecutionStatus::kUnknown;
};
@@ -88,26 +78,26 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> {
}
// Should be called on the same sequence as the ctor, but once called |this|
- // must only be used from a background thread/sequence.
- void InitializeAndMoveToBackgroundThread(
+ // must only be used from the |execution_task_runner| thread/sequence.
+ void InitializeAndMoveToExecutionThread(
proto::OptimizationTarget optimization_target,
- scoped_refptr<base::SequencedTaskRunner> background_task_runner,
+ scoped_refptr<base::SequencedTaskRunner> execution_task_runner,
scoped_refptr<base::SequencedTaskRunner> reply_task_runner) override {
- DCHECK(!background_task_runner_);
+ DCHECK(!execution_task_runner_);
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK_NE(optimization_target,
proto::OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN);
DETACH_FROM_SEQUENCE(sequence_checker_);
optimization_target_ = optimization_target;
- background_task_runner_ = background_task_runner;
+ execution_task_runner_ = execution_task_runner;
reply_task_runner_ = reply_task_runner;
}
// Called when a model file is available to load. Depending on feature flags,
// the model may or may not be immediately loaded.
void UpdateModelFile(const base::FilePath& file_path) override {
- DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
+ DCHECK(execution_task_runner_->RunsTasksInCurrentSequence());
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
UnloadModel();
@@ -130,7 +120,7 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> {
// called. False is the default behavior (see class comment).
void SetShouldUnloadModelOnComplete(
bool should_unload_model_on_complete) override {
- DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
+ DCHECK(execution_task_runner_->RunsTasksInCurrentSequence());
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
should_unload_model_on_complete_ = should_unload_model_on_complete;
}
@@ -142,21 +132,21 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> {
"OptimizationTarget",
optimization_guide::GetStringNameForOptimizationTarget(
optimization_target_));
- DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
+ DCHECK(execution_task_runner_->RunsTasksInCurrentSequence());
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
loaded_model_.reset();
model_fb_.reset();
}
- // Starts the execution of the model. When complete, |ui_callback_on_complete|
- // will be run on the UI thread with the output of the model.
+ // Starts the execution of the model. When complete, |callback_on_complete|
+ // will be run via |reply_task_runner_| with the output of the model.
using ExecutionCallback =
base::OnceCallback<void(const absl::optional<OutputType>&)>;
- void SendForExecution(ExecutionCallback ui_callback_on_complete,
+ void SendForExecution(ExecutionCallback callback_on_complete,
base::TimeTicks start_time,
InputTypes... args) override {
- DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
+ DCHECK(execution_task_runner_->RunsTasksInCurrentSequence());
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(reply_task_runner_);
@@ -175,7 +165,7 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> {
if (!loaded_model_ && !LoadModelFile(status_recorder.mutable_status())) {
reply_task_runner_->PostTask(
FROM_HERE,
- base::BindOnce(std::move(ui_callback_on_complete), absl::nullopt));
+ base::BindOnce(std::move(callback_on_complete), absl::nullopt));
// Some error status is expected, and derived classes should have set the
// status.
DCHECK_NE(status_recorder.status(), ExecutionStatus::kUnknown);
@@ -213,19 +203,13 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> {
base::TimeTicks::Now() - execute_start_time);
}
- DCHECK(ui_callback_on_complete);
+ DCHECK(callback_on_complete);
reply_task_runner_->PostTask(
- FROM_HERE, base::BindOnce(std::move(ui_callback_on_complete), output));
+ FROM_HERE, base::BindOnce(std::move(callback_on_complete), output));
OnExecutionComplete();
}
- // IMPORTANT: These WeakPointers must only be dereferenced on the background
- // thread.
- base::WeakPtr<TFLiteModelExecutor> GetBackgroundWeakPtr() {
- return background_weak_ptr_factory_.GetWeakPtr();
- }
-
TFLiteModelExecutor(const TFLiteModelExecutor&) = delete;
TFLiteModelExecutor& operator=(const TFLiteModelExecutor&) = delete;
@@ -252,7 +236,7 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> {
"OptimizationTarget",
optimization_guide::GetStringNameForOptimizationTarget(
optimization_target_));
- DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
+ DCHECK(execution_task_runner_->RunsTasksInCurrentSequence());
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
UnloadModel();
@@ -267,6 +251,8 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> {
return false;
}
+ base::TimeTicks loading_start_time = base::TimeTicks::Now();
+
std::unique_ptr<base::MemoryMappedFile> model_fb =
std::make_unique<base::MemoryMappedFile>();
if (!model_fb->Initialize(*model_file_path_)) {
@@ -277,11 +263,28 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> {
loaded_model_ = BuildModelExecutionTask(model_fb_.get(), out_status);
+ if (!!loaded_model_) {
+ // We only want to record successful loading times.
+ base::UmaHistogramTimes(
+ "OptimizationGuide.ModelExecutor.ModelLoadingDuration2." +
+ optimization_guide::GetStringNameForOptimizationTarget(
+ optimization_target_),
+ base::TimeTicks::Now() - loading_start_time);
+ }
+
+ // Local histogram used in integration testing.
+ base::BooleanHistogram::FactoryGet(
+ "OptimizationGuide.ModelExecutor.ModelLoadedSuccessfully." +
+ optimization_guide::GetStringNameForOptimizationTarget(
+ optimization_target_),
+ base::Histogram::kNoFlags)
+ ->Add(!!loaded_model_);
+
return !!loaded_model_;
}
void OnExecutionComplete() {
- DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
+ DCHECK(execution_task_runner_->RunsTasksInCurrentSequence());
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (should_unload_model_on_complete_) {
UnloadModel();
@@ -293,7 +296,7 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> {
bool should_unload_model_on_complete_ = true;
- scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
+ scoped_refptr<base::SequencedTaskRunner> execution_task_runner_;
scoped_refptr<base::SequencedTaskRunner> reply_task_runner_;
@@ -320,8 +323,6 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> {
GUARDED_BY_CONTEXT(sequence_checker_);
SEQUENCE_CHECKER(sequence_checker_);
-
- base::WeakPtrFactory<TFLiteModelExecutor> background_weak_ptr_factory_{this};
};
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/core/tflite_model_executor_unittest.cc b/chromium/components/optimization_guide/core/tflite_model_executor_unittest.cc
index 1c519e51639..c0f5f76f91e 100644
--- a/chromium/components/optimization_guide/core/tflite_model_executor_unittest.cc
+++ b/chromium/components/optimization_guide/core/tflite_model_executor_unittest.cc
@@ -187,6 +187,11 @@ TEST_F(TFLiteModelExecutorTest, ExecuteWithLoadedModel) {
optimization_guide::GetStringNameForOptimizationTarget(
proto::OptimizationTarget::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD),
ExecutionStatus::kSuccess, 1);
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.ModelExecutor.ModelLoadedSuccessfully." +
+ optimization_guide::GetStringNameForOptimizationTarget(
+ proto::OptimizationTarget::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD),
+ true, 1);
}
TEST_F(TFLiteModelExecutorTest, ExecuteTwiceWithLoadedModel) {
@@ -228,6 +233,11 @@ TEST_F(TFLiteModelExecutorTest, ExecuteTwiceWithLoadedModel) {
optimization_guide::GetStringNameForOptimizationTarget(
proto::OptimizationTarget::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD),
ExecutionStatus::kSuccess, 1);
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.ModelExecutor.ModelLoadedSuccessfully." +
+ optimization_guide::GetStringNameForOptimizationTarget(
+ proto::OptimizationTarget::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD),
+ true, 1);
// Second run.
run_loop = std::make_unique<base::RunLoop>();
@@ -253,6 +263,11 @@ TEST_F(TFLiteModelExecutorTest, ExecuteTwiceWithLoadedModel) {
optimization_guide::GetStringNameForOptimizationTarget(
proto::OptimizationTarget::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD),
ExecutionStatus::kSuccess, 2);
+ histogram_tester.ExpectUniqueSample(
+ "OptimizationGuide.ModelExecutor.ModelLoadedSuccessfully." +
+ optimization_guide::GetStringNameForOptimizationTarget(
+ proto::OptimizationTarget::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD),
+ true, 2);
histogram_tester.ExpectTotalCount(
"OptimizationGuide.ModelExecutor.TaskExecutionLatency." +
diff --git a/chromium/components/optimization_guide/core/tflite_op_resolver.cc b/chromium/components/optimization_guide/core/tflite_op_resolver.cc
index 1cd89433031..8674eb8d700 100644
--- a/chromium/components/optimization_guide/core/tflite_op_resolver.cc
+++ b/chromium/components/optimization_guide/core/tflite_op_resolver.cc
@@ -371,6 +371,10 @@ TFLiteOpResolver::TFLiteOpResolver() {
tflite::ops::builtin::Register_BATCH_MATMUL(),
/* min_version = */ 1,
/* max_version = */ 4);
+ AddBuiltin(tflite::BuiltinOperator_GELU,
+ tflite::ops::builtin::Register_GELU(),
+ /* min_version = */ 1,
+ /* max_version = */ 2);
}
} // namespace optimization_guide
diff --git a/chromium/components/optimization_guide/features.gni b/chromium/components/optimization_guide/features.gni
index 26d60df3f0d..2ed17e3165e 100644
--- a/chromium/components/optimization_guide/features.gni
+++ b/chromium/components/optimization_guide/features.gni
@@ -13,6 +13,18 @@ declare_args() {
# You can set the variable 'build_with_internal_optimization_guide' to true
# even in a developer build in args.gn. Setting this variable explicitly to true will
# cause your build to fail if the internal files are missing.
+ #
+ # If changing the value of this, you MUST also update the following files depending on the
+ # platform:
+ # ChromeOS: //lib/chrome_util.py in the Chromite repo (ex: https://crrev.com/c/3437291)
+ # Linux: Internal archive files. //chrome/installer/linux/common/installer.include handles the
+ # relevant files not being present.
+ # Mac: //chrome/installer/mac/signing/parts.py
+ # Windows: //chrome/installer/mini_installer/chrome.release and internal archive files
+ #
+ # The library this pulls in depends on open-source LevelDB which is not supported for Fuchsia.
+ # Android and iOS should just work but are not included in the set we release for, so we do
+ # not needlessly increase the binary.
build_with_internal_optimization_guide =
- is_chrome_branded && !is_android && !is_ios
+ is_chrome_branded && !is_android && !is_ios && !is_fuchsia
}
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/resources/BUILD.gn b/chromium/components/optimization_guide/optimization_guide_internals/resources/BUILD.gn
new file mode 100644
index 00000000000..22359bed596
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/resources/BUILD.gn
@@ -0,0 +1,71 @@
+# Copyright 2022 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+import("//tools/grit/grit_rule.gni")
+import("//tools/polymer/html_to_js.gni")
+import("//tools/typescript/ts_library.gni")
+import("//ui/webui/resources/tools/generate_grd.gni")
+
+tsc_folder = "tsc"
+
+grit("resources") {
+ # These arguments are needed since the grd is generated at build time.
+ enable_input_discovery_for_gn_analyze = false
+ source = "$target_gen_dir/resources.grd"
+ deps = [ ":build_grd" ]
+
+ outputs = [
+ "grit/optimization_guide_internals_resources.h",
+ "grit/optimization_guide_internals_resources_map.cc",
+ "grit/optimization_guide_internals_resources_map.h",
+ "optimization_guide_internals_resources.pak",
+ ]
+
+ output_dir = "$root_gen_dir/components"
+}
+
+generate_grd("build_grd") {
+ grd_prefix = "optimization_guide_internals"
+ out_grd = "$target_gen_dir/resources.grd"
+ deps = [ ":build_ts" ]
+ manifest_files = [ "$target_gen_dir/tsconfig.manifest" ]
+ input_files = [ "optimization_guide_internals.html" ]
+ input_files_base_dir = rebase_path(".", "//")
+}
+
+html_to_js("web_components") {
+ js_files = [ "optimization_guide_internals.ts" ]
+}
+
+copy("copy_proxy") {
+ sources = [ "optimization_guide_internals_browser_proxy.ts" ]
+ outputs = [ "$target_gen_dir/{{source_file_part}}" ]
+}
+
+copy("copy_mojo") {
+ deps = [ "//components/optimization_guide/optimization_guide_internals/webui:mojo_bindings_webui_js" ]
+ mojom_folder = "$root_gen_dir/mojom-webui/components/optimization_guide/optimization_guide_internals/webui/"
+ sources = [ "$mojom_folder/optimization_guide_internals.mojom-webui.js" ]
+ outputs = [ "$target_gen_dir/{{source_file_part}}" ]
+}
+
+ts_library("build_ts") {
+ root_dir = "$target_gen_dir"
+ out_dir = "$target_gen_dir/$tsc_folder"
+ tsconfig_base = "tsconfig_base.json"
+ in_files = [
+ "optimization_guide_internals.ts",
+ "optimization_guide_internals_browser_proxy.ts",
+ "optimization_guide_internals.mojom-webui.js",
+ ]
+ deps = [
+ "//ui/webui/resources:library",
+ "//ui/webui/resources/js/browser_command:build_ts",
+ ]
+ extra_deps = [
+ ":copy_mojo",
+ ":copy_proxy",
+ ":web_components",
+ ]
+}
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/resources/OWNERS b/chromium/components/optimization_guide/optimization_guide_internals/resources/OWNERS
new file mode 100644
index 00000000000..1ee724d06f4
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/resources/OWNERS
@@ -0,0 +1 @@
+file://components/optimization_guide/OWNERS
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals.html b/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals.html
new file mode 100644
index 00000000000..8237d8a8f29
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals.html
@@ -0,0 +1,29 @@
+<!doctype html>
+<html lang="en" dir="ltr">
+ <head>
+ <style>
+ .segment {
+ border: 1px outset black;
+ margin: 2px 2px 2px 2px;
+ }
+ </style>
+ <meta charset="utf-8">
+ <title>Optimization Guide Internals</title>
+ <meta name="viewport" content="width=device-width">
+ <link rel="stylesheet" href="chrome://resources/css/text_defaults.css">
+ </head>
+ <body>
+ <h1>Optimization Guide Internals - Debug Logs</h1>
+ <button id="log-messages-dump">Dump</button>
+ <table id="log-message-container">
+ <thead>
+ <tr>
+ <th>Time</th>
+ <th>Source Location</th>
+ <th>Log Message</th>
+ </tr>
+ </thead>
+ </table>
+ <script type="module" src="optimization_guide_internals.js"></script>
+ </body>
+</html> \ No newline at end of file
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals.ts b/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals.ts
new file mode 100644
index 00000000000..df91b5bcf79
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals.ts
@@ -0,0 +1,80 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+import {$} from 'chrome://resources/js/util.m.js';
+import {Time} from 'chrome://resources/mojo/mojo/public/mojom/base/time.mojom-webui.js';
+
+import {OptimizationGuideInternalsBrowserProxy} from './optimization_guide_internals_browser_proxy.js';
+
+// Contains all the log events received when the internals page is open.
+const logMessages:
+ {eventTime: string, sourceLocation: string, message: string}[] = [];
+
+/**
+ * Converts a mojo time to a JS time.
+ * @param {!mojoBase.mojom.Time} mojoTime
+ * @return {!Date}
+ */
+function convertMojoTimeToJS(mojoTime: Time) {
+ // The JS Date() is based off of the number of milliseconds since the
+ // UNIX epoch (1970-01-01 00::00:00 UTC), while |internalValue| of the
+ // base::Time (represented in mojom.Time) represents the number of
+ // microseconds since the Windows FILETIME epoch (1601-01-01 00:00:00 UTC).
+ // This computes the final JS time by computing the epoch delta and the
+ // conversion from microseconds to milliseconds.
+ const windowsEpoch = Date.UTC(1601, 0, 1, 0, 0, 0, 0);
+ const unixEpoch = Date.UTC(1970, 0, 1, 0, 0, 0, 0);
+ // |epochDeltaInMs| equals to base::Time::kTimeTToMicrosecondsOffset.
+ const epochDeltaInMs = unixEpoch - windowsEpoch;
+ const timeInMs = Number(mojoTime.internalValue) / 1000;
+
+ return new Date(timeInMs - epochDeltaInMs);
+}
+
+/**
+ * The callback to button#log-messages-dump to save the logs to a file.
+ */
+function onLogMessagesDump() {
+ const data = JSON.stringify(logMessages);
+ const blob = new Blob([data], {'type': 'text/json'});
+ const url = URL.createObjectURL(blob);
+ const filename = 'optimization_guide_internals_logs_dump.json';
+
+ const a = document.createElement('a');
+ a.setAttribute('href', url);
+ a.setAttribute('download', filename);
+
+ const event = document.createEvent('MouseEvent');
+ event.initMouseEvent(
+ 'click', true, true, window, 0, 0, 0, 0, 0, false, false, false, false, 0,
+ null);
+ a.dispatchEvent(event);
+}
+
+function getProxy(): OptimizationGuideInternalsBrowserProxy {
+ return OptimizationGuideInternalsBrowserProxy.getInstance();
+}
+
+
+function initialize() {
+ const logMessageContainer = $('log-message-container') as HTMLTableElement;
+
+ $('log-messages-dump').addEventListener('click', onLogMessagesDump);
+
+ getProxy().getCallbackRouter().onLogMessageAdded.addListener(
+ (eventTime: Time, sourceFile: string, sourceLine: number,
+ message: string) => {
+ const eventTimeStr = convertMojoTimeToJS(eventTime).toISOString();
+ const sourceLocation = `${sourceFile}(${sourceLine})`;
+ logMessages.push({eventTime: eventTimeStr, sourceLocation, message});
+ if (logMessageContainer) {
+ const logmessage = logMessageContainer.insertRow();
+ logmessage.insertCell().innerHTML = eventTimeStr;
+ logmessage.insertCell().innerHTML = sourceLocation;
+ logmessage.insertCell().innerHTML = message;
+ }
+ });
+}
+
+document.addEventListener('DOMContentLoaded', initialize); \ No newline at end of file
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals_browser_proxy.ts b/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals_browser_proxy.ts
new file mode 100644
index 00000000000..b7636abbe61
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals_browser_proxy.ts
@@ -0,0 +1,26 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+import {PageCallbackRouter, PageHandlerFactory} from './optimization_guide_internals.mojom-webui.js';
+
+export class OptimizationGuideInternalsBrowserProxy {
+ private callbackRouter: PageCallbackRouter;
+
+ constructor() {
+ this.callbackRouter = new PageCallbackRouter();
+ const factory = PageHandlerFactory.getRemote();
+ factory.createPageHandler(this.callbackRouter.$.bindNewPipeAndPassRemote());
+ }
+
+ static getInstance(): OptimizationGuideInternalsBrowserProxy {
+ return instance ||
+ (instance = new OptimizationGuideInternalsBrowserProxy());
+ }
+
+ getCallbackRouter(): PageCallbackRouter {
+ return this.callbackRouter;
+ }
+}
+
+let instance: OptimizationGuideInternalsBrowserProxy|null = null; \ No newline at end of file
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/resources/tsconfig_base.json b/chromium/components/optimization_guide/optimization_guide_internals/resources/tsconfig_base.json
new file mode 100644
index 00000000000..cbf406ef813
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/resources/tsconfig_base.json
@@ -0,0 +1,6 @@
+{
+ "extends": "../../../../tools/typescript/tsconfig_base.json",
+ "compilerOptions": {
+ "allowJs": true
+ }
+} \ No newline at end of file
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/BUILD.gn b/chromium/components/optimization_guide/optimization_guide_internals/webui/BUILD.gn
new file mode 100644
index 00000000000..f5952e8ff2f
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/BUILD.gn
@@ -0,0 +1,30 @@
+# Copyright 2022 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+import("//mojo/public/tools/bindings/mojom.gni")
+
+static_library("webui") {
+ sources = [
+ "optimization_guide_internals_page_handler_impl.cc",
+ "optimization_guide_internals_page_handler_impl.h",
+ "optimization_guide_internals_ui.cc",
+ "optimization_guide_internals_ui.h",
+ "url_constants.cc",
+ "url_constants.h",
+ ]
+ deps = [
+ "//base",
+ "//components/optimization_guide/core",
+ "//components/optimization_guide/optimization_guide_internals/resources:resources",
+ "//components/optimization_guide/optimization_guide_internals/webui:mojo_bindings",
+ "//third_party/abseil-cpp:absl",
+ "//ui/base",
+ "//ui/webui",
+ ]
+}
+mojom("mojo_bindings") {
+ sources = [ "optimization_guide_internals.mojom" ]
+ webui_module_path = "/"
+ public_deps = [ "//mojo/public/mojom/base" ]
+}
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/DEPS b/chromium/components/optimization_guide/optimization_guide_internals/webui/DEPS
new file mode 100644
index 00000000000..930de395723
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/DEPS
@@ -0,0 +1,7 @@
+include_rules = [
+ "+mojo/public/cpp/bindings",
+ "+ui/base/webui/resource_path.h",
+ "+ui/webui/mojo_web_ui_controller.h",
+ "+components/grit/optimization_guide_internals_resources.h",
+ "+components/grit/optimization_guide_internals_resources_map.h",
+]
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/OWNERS b/chromium/components/optimization_guide/optimization_guide_internals/webui/OWNERS
new file mode 100644
index 00000000000..5245dc478f1
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/OWNERS
@@ -0,0 +1,4 @@
+file://components/optimization_guide/OWNERS
+
+per-file *.mojom=set noparent
+per-file *.mojom=file://ipc/SECURITY_OWNERS \ No newline at end of file
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals.mojom b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals.mojom
new file mode 100644
index 00000000000..b4eec080fa0
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals.mojom
@@ -0,0 +1,24 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+
+module optimization_guide_internals.mojom;
+
+import "mojo/public/mojom/base/time.mojom";
+
+// Used by the WebUI page to bootstrap bidirectional communication.
+interface PageHandlerFactory {
+ // The WebUI calls this method when the page is first initialized.
+ CreatePageHandler(pending_remote<Page> page);
+};
+
+// Renderer-side handler for internal page to process the updates from
+// the OptimizationGuide service.
+interface Page {
+ // Notifies the page of a log event from the OptimizationGuide service.
+ OnLogMessageAdded(mojo_base.mojom.Time event_time,
+ string source_file,
+ int64 source_line,
+ string message);
+};
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.cc b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.cc
new file mode 100644
index 00000000000..40ebfe372f0
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.cc
@@ -0,0 +1,31 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.h"
+
+#include "base/time/time.h"
+
+OptimizationGuideInternalsPageHandlerImpl::
+ OptimizationGuideInternalsPageHandlerImpl(
+ mojo::PendingRemote<optimization_guide_internals::mojom::Page> page,
+ OptimizationGuideLogger* optimization_guide_logger)
+ : page_(std::move(page)),
+ optimization_guide_logger_(optimization_guide_logger) {
+ if (optimization_guide_logger_)
+ optimization_guide_logger_->AddObserver(this);
+}
+
+OptimizationGuideInternalsPageHandlerImpl::
+ ~OptimizationGuideInternalsPageHandlerImpl() {
+ if (optimization_guide_logger_)
+ optimization_guide_logger_->RemoveObserver(this);
+}
+
+void OptimizationGuideInternalsPageHandlerImpl::OnLogMessageAdded(
+ base::Time event_time,
+ const std::string& source_file,
+ int source_line,
+ const std::string& message) {
+ page_->OnLogMessageAdded(event_time, source_file, source_line, message);
+}
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.h b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.h
new file mode 100644
index 00000000000..6816e3b68ec
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.h
@@ -0,0 +1,44 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_OPTIMIZATION_GUIDE_INTERNALS_PAGE_HANDLER_IMPL_H_
+#define COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_OPTIMIZATION_GUIDE_INTERNALS_PAGE_HANDLER_IMPL_H_
+
+#include <string>
+
+#include "components/optimization_guide/core/optimization_guide_logger.h"
+#include "components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals.mojom.h"
+#include "mojo/public/cpp/bindings/remote.h"
+
+// Handler for the internals page to receive and forward the log messages.
+class OptimizationGuideInternalsPageHandlerImpl
+ : public OptimizationGuideLogger::Observer {
+ public:
+ OptimizationGuideInternalsPageHandlerImpl(
+ mojo::PendingRemote<optimization_guide_internals::mojom::Page> page,
+ OptimizationGuideLogger* optimization_guide_logger);
+ ~OptimizationGuideInternalsPageHandlerImpl() override;
+
+ OptimizationGuideInternalsPageHandlerImpl(
+ const OptimizationGuideInternalsPageHandlerImpl&) = delete;
+ OptimizationGuideInternalsPageHandlerImpl& operator=(
+ const OptimizationGuideInternalsPageHandlerImpl&) = delete;
+
+ private:
+ // optimization_guide::OptimizationGuideLogger::Observer overrides.
+ void OnLogMessageAdded(base::Time event_time,
+ const std::string& source_file,
+ int source_line,
+ const std::string& message) override;
+
+ mojo::Remote<optimization_guide_internals::mojom::Page> page_;
+
+ // Logger to receive the debug logs from the optimization guide service. Not
+ // owned. Guaranteed to outlive |this|, since the logger is owned by the
+ // optimization guide keyed service, while |this| is part of
+ // RenderFrameHostImpl::WebUIImpl.
+ raw_ptr<OptimizationGuideLogger> optimization_guide_logger_;
+};
+
+#endif // COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_OPTIMIZATION_GUIDE_INTERNALS_PAGE_HANDLER_IMPL_H_
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.cc b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.cc
new file mode 100644
index 00000000000..e69b01edbff
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.cc
@@ -0,0 +1,38 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.h"
+
+#include "components/grit/optimization_guide_internals_resources.h"
+#include "components/grit/optimization_guide_internals_resources_map.h"
+#include "components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.h"
+
+OptimizationGuideInternalsUI::OptimizationGuideInternalsUI(
+ content::WebUI* web_ui,
+ OptimizationGuideLogger* optimization_guide_logger,
+ SetupWebUIDataSourceCallback set_up_data_source_callback)
+ : MojoWebUIController(web_ui, /*enable_chrome_send=*/true),
+ optimization_guide_logger_(optimization_guide_logger) {
+ std::move(set_up_data_source_callback)
+ .Run(base::make_span(kOptimizationGuideInternalsResources,
+ kOptimizationGuideInternalsResourcesSize),
+ IDR_OPTIMIZATION_GUIDE_INTERNALS_OPTIMIZATION_GUIDE_INTERNALS_HTML);
+}
+
+OptimizationGuideInternalsUI::~OptimizationGuideInternalsUI() = default;
+
+void OptimizationGuideInternalsUI::BindInterface(
+ mojo::PendingReceiver<
+ optimization_guide_internals::mojom::PageHandlerFactory> receiver) {
+ optimization_guide_internals_page_factory_receiver_.Bind(std::move(receiver));
+}
+
+void OptimizationGuideInternalsUI::CreatePageHandler(
+ mojo::PendingRemote<optimization_guide_internals::mojom::Page> page) {
+ optimization_guide_internals_page_handler_ =
+ std::make_unique<OptimizationGuideInternalsPageHandlerImpl>(
+ std::move(page), optimization_guide_logger_);
+}
+
+WEB_UI_CONTROLLER_TYPE_IMPL(OptimizationGuideInternalsUI)
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.h b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.h
new file mode 100644
index 00000000000..4421de745c8
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.h
@@ -0,0 +1,59 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_OPTIMIZATION_GUIDE_INTERNALS_UI_H_
+#define COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_OPTIMIZATION_GUIDE_INTERNALS_UI_H_
+
+#include "components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals.mojom.h"
+#include "mojo/public/cpp/bindings/pending_receiver.h"
+#include "ui/base/webui/resource_path.h"
+#include "ui/webui/mojo_web_ui_controller.h"
+
+class OptimizationGuideLogger;
+class OptimizationGuideInternalsPageHandlerImpl;
+
+// The WebUI controller for chrome://optimization-guide-internals.
+class OptimizationGuideInternalsUI
+ : public ui::MojoWebUIController,
+ public optimization_guide_internals::mojom::PageHandlerFactory {
+ public:
+ using SetupWebUIDataSourceCallback =
+ base::OnceCallback<void(base::span<const webui::ResourcePath> resources,
+ int default_resource)>;
+
+ explicit OptimizationGuideInternalsUI(
+ content::WebUI* web_ui,
+ OptimizationGuideLogger* optimization_guide_logger,
+ SetupWebUIDataSourceCallback set_up_data_source_callback);
+ ~OptimizationGuideInternalsUI() override;
+
+ OptimizationGuideInternalsUI(const OptimizationGuideInternalsUI&) = delete;
+ OptimizationGuideInternalsUI& operator=(const OptimizationGuideInternalsUI&) =
+ delete;
+
+ void BindInterface(
+ mojo::PendingReceiver<
+ optimization_guide_internals::mojom::PageHandlerFactory> receiver);
+
+ private:
+ // optimization_guide_internals::mojom::PageHandlerFactory impls.
+ void CreatePageHandler(
+ mojo::PendingRemote<optimization_guide_internals::mojom::Page> page)
+ override;
+
+ std::unique_ptr<OptimizationGuideInternalsPageHandlerImpl>
+ optimization_guide_internals_page_handler_;
+ mojo::Receiver<optimization_guide_internals::mojom::PageHandlerFactory>
+ optimization_guide_internals_page_factory_receiver_{this};
+
+ // Logger to receive the debug logs from the optimization guide service. Not
+ // owned. Guaranteed to outlive |this|, since the logger is owned by the
+ // optimization guide keyed service, while |this| is part of
+ // RenderFrameHostImpl::WebUIImpl.
+ raw_ptr<OptimizationGuideLogger> optimization_guide_logger_;
+
+ WEB_UI_CONTROLLER_TYPE_DECL();
+};
+
+#endif // COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_OPTIMIZATION_GUIDE_INTERNALS_UI_H_
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/url_constants.cc b/chromium/components/optimization_guide/optimization_guide_internals/webui/url_constants.cc
new file mode 100644
index 00000000000..3fc092ccddb
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/url_constants.cc
@@ -0,0 +1,12 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/optimization_guide/optimization_guide_internals/webui/url_constants.h"
+
+namespace optimization_guide_internals {
+
+const char kChromeUIOptimizationGuideInternalsHost[] =
+ "optimization-guide-internals";
+
+}
diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/url_constants.h b/chromium/components/optimization_guide/optimization_guide_internals/webui/url_constants.h
new file mode 100644
index 00000000000..14a21360f89
--- /dev/null
+++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/url_constants.h
@@ -0,0 +1,15 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_URL_CONSTANTS_H_
+#define COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_URL_CONSTANTS_H_
+
+namespace optimization_guide_internals {
+
+// The host of the optimization guide internals page URL.
+extern const char kChromeUIOptimizationGuideInternalsHost[];
+
+} // namespace optimization_guide_internals
+
+#endif // COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_URL_CONSTANTS_H_
diff --git a/chromium/components/optimization_guide/proto/BUILD.gn b/chromium/components/optimization_guide/proto/BUILD.gn
index 288e9fd2c25..2253201609c 100644
--- a/chromium/components/optimization_guide/proto/BUILD.gn
+++ b/chromium/components/optimization_guide/proto/BUILD.gn
@@ -17,6 +17,7 @@ proto_library("optimization_guide_proto") {
"loading_predictor_metadata.proto",
"models.proto",
"page_entities_metadata.proto",
+ "page_entities_model_metadata.proto",
"page_topics_model_metadata.proto",
"performance_hints_metadata.proto",
"public_image_metadata.proto",
diff --git a/chromium/components/optimization_guide/proto/hint_cache.proto b/chromium/components/optimization_guide/proto/hint_cache.proto
index a27430d3e53..e89dec22aa3 100644
--- a/chromium/components/optimization_guide/proto/hint_cache.proto
+++ b/chromium/components/optimization_guide/proto/hint_cache.proto
@@ -59,4 +59,6 @@ message StoreEntry {
optional PredictionModel prediction_model = 6;
// The actual HostModelFeature data.
optional HostModelFeatures host_model_features = 7;
+ // Whether to delete a model once expiry_time_secs is past.
+ optional bool keep_beyond_valid_duration = 8;
}
diff --git a/chromium/components/optimization_guide/proto/models.proto b/chromium/components/optimization_guide/proto/models.proto
index c69cb127eec..a916421b570 100644
--- a/chromium/components/optimization_guide/proto/models.proto
+++ b/chromium/components/optimization_guide/proto/models.proto
@@ -200,7 +200,7 @@ message AdditionalModelFile {
// Metadata for a prediction model for a specific optimization target.
//
-// Next ID: 8
+// Next ID: 10
message ModelInfo {
reserved 3;
@@ -208,8 +208,9 @@ message ModelInfo {
optional OptimizationTarget optimization_target = 1;
// The version of the model, which is specific to the optimization target.
optional int64 version = 2;
- // The set of model types the requesting client can use to make predictions.
- repeated ModelType supported_model_types = 4;
+ // The set of model engine versions the requesting client can use to do model
+ // inference.
+ repeated ModelEngineVersion supported_model_engine_versions = 4;
// The set of host model features that are referenced by the model.
//
// Note that this should only be populated if part of the response.
@@ -222,6 +223,11 @@ message ModelInfo {
// This does not need to be sent to the server in the request for an update to
// this model. The server will ignore this if sent.
repeated AdditionalModelFile additional_files = 7;
+ // How long the model will remain valid in client storage. If
+ // |keep_beyond_valid_duration| is true, will be ignored.
+ optional Duration valid_duration = 8;
+ // Whether to delete the model once valid_duration has passed.
+ optional bool keep_beyond_valid_duration = 9;
// Mechanism used for model owners to attach metadata to the request or
// response.
//
@@ -267,31 +273,38 @@ enum OptimizationTarget {
// Target for determining topics present on a page.
// TODO(crbug/1266504): Remove PAGE_TOPICS in favor of this target.
OPTIMIZATION_TARGET_PAGE_TOPICS_V2 = 15;
+ // Target for segmentation: Determine users with low engagement with chrome.
+ OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT = 16;
}
-// The types of models that can be evaluated.
+// The model engine versions that can be used to do model inference.
//
// Please only update these enums when a new major version of TFLite rolls.
//
// For example: v1.2.3
// ^
// Change when this number increments.
-enum ModelType {
- MODEL_TYPE_UNKNOWN = 0;
+enum ModelEngineVersion {
+ MODEL_ENGINE_VERSION_UNKNOWN = 0;
// A decision tree.
- MODEL_TYPE_DECISION_TREE = 1;
+ MODEL_ENGINE_VERSION_DECISION_TREE = 1;
// A model using only operations that are supported by TensorflowLite 2.3.0.
- MODEL_TYPE_TFLITE_2_3_0 = 2;
+ MODEL_ENGINE_VERSION_TFLITE_2_3_0 = 2;
// A model using only operations that are supported by TensorflowLite 2.3.0
// with updated FULLY_CONNECTED and BATCH_MUL versions for quantized models.
- MODEL_TYPE_TFLITE_2_3_0_1 = 3;
+ MODEL_ENGINE_VERSION_TFLITE_2_3_0_1 = 3;
// TensorflowLite version 2.4.2, and a bit more up to internal rev number
// 381280669.
- MODEL_TYPE_TFLITE_2_4 = 4;
+ MODEL_ENGINE_VERSION_TFLITE_2_4 = 4;
// TensorflowLite version 2.7.*. This is where regular ~HEAD rolls started.
- MODEL_TYPE_TFLITE_2_7 = 5;
+ MODEL_ENGINE_VERSION_TFLITE_2_7 = 5;
// A model using only operations that are supported by TensorflowLite 2.8.0.
- MODEL_TYPE_TFLITE_2_8 = 6;
+ MODEL_ENGINE_VERSION_TFLITE_2_8 = 6;
+ // A model using only operations that are supported by TensorflowLite 2.9.0.
+ MODEL_ENGINE_VERSION_TFLITE_2_9 = 7;
+ // A model using only operations that are supported by TensorflowLite 2.9.0.
+ // This adds GELU to the supported ops in Optimiziation Guide.
+ MODEL_ENGINE_VERSION_TFLITE_2_9_0_1 = 8;
}
// A set of model features and the host that it applies to.
diff --git a/chromium/components/optimization_guide/proto/page_entities_metadata.proto b/chromium/components/optimization_guide/proto/page_entities_metadata.proto
index 96a8479f179..1b5fa334a18 100644
--- a/chromium/components/optimization_guide/proto/page_entities_metadata.proto
+++ b/chromium/components/optimization_guide/proto/page_entities_metadata.proto
@@ -24,7 +24,18 @@ message Entity {
// entities on the page.
//
// It is only populated for the PAGE_ENTITIES optimization type.
+//
+// Note that the meaning of metadata here is in relation to a page load.
message PageEntitiesMetadata {
// A set of entities that are expected to be present on the page.
repeated Entity entities = 1;
}
+
+// The metadata associated with an |Entity|.
+//
+// Each |Entity| has some attached metadata about it which may be stored on
+// device for later lookup. Notably, this includes it's human-readable name as
+// opposed to the opaque entity_id.
+message EntityMetadataStorage {
+ optional string entity_name = 1;
+} \ No newline at end of file
diff --git a/chromium/components/optimization_guide/proto/page_entities_model_metadata.proto b/chromium/components/optimization_guide/proto/page_entities_model_metadata.proto
new file mode 100644
index 00000000000..cf239e37352
--- /dev/null
+++ b/chromium/components/optimization_guide/proto/page_entities_model_metadata.proto
@@ -0,0 +1,25 @@
+// Copyright 2022 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+syntax = "proto2";
+
+option optimize_for = LITE_RUNTIME;
+option java_package = "org.chromium.components.optimization_guide.proto";
+option java_outer_classname = "PageEntitiesModelMetadataProto";
+
+package optimization_guide.proto;
+
+message PageEntitiesModelMetadata {
+ // The maximum model format feature flag that is supported.
+ //
+ // If sent from the server, this is the maximum model format feature flag the
+ // returned model supports. If sent from the client, this is the maximum
+ // model format feature flag the client knows how to evaluate.
+ optional int32 max_model_format_feature_flag = 1;
+
+ // The slices to load into the entity annotator.
+ //
+ // Will only be populated by the server.
+ repeated string slice = 2;
+} \ No newline at end of file
diff --git a/chromium/components/optimization_guide/proto/page_topics_model_metadata.proto b/chromium/components/optimization_guide/proto/page_topics_model_metadata.proto
index 59e4a0441be..42516c9358a 100644
--- a/chromium/components/optimization_guide/proto/page_topics_model_metadata.proto
+++ b/chromium/components/optimization_guide/proto/page_topics_model_metadata.proto
@@ -50,7 +50,24 @@ message PageTopicsOutputPostprocessingParams {
optional PageTopicsCategoryPostprocessingParams category_params = 2;
}
+message Topic {
+ // The user-visible string of the taxonomy topic.
+ optional string topic_name = 1;
+ // The id of the topic.
+ optional int64 topic_id = 2;
+}
+
+message TopicTaxonomy {
+ // The version of this specific taxonomy, which is separate from the model
+ // version.
+ optional int64 version = 1;
+ // The topics supported by this taxonomy.
+ repeated Topic topics = 2;
+}
+
message PageTopicsModelMetadata {
+ reserved 4;
+
// The version of the model sent by the server, and thus, will only be
// populated by the server.
optional int64 version = 1;
@@ -64,4 +81,6 @@ message PageTopicsModelMetadata {
// populated by the server.
optional PageTopicsOutputPostprocessingParams output_postprocessing_params =
3;
+ // The taxonomy used by this model.
+ optional TopicTaxonomy topic_taxonomy = 5;
}