diff options
Diffstat (limited to 'chromium/components/optimization_guide')
120 files changed, 4186 insertions, 2950 deletions
diff --git a/chromium/components/optimization_guide/DEPS b/chromium/components/optimization_guide/DEPS index c22edf8e983..ba9c08904b4 100644 --- a/chromium/components/optimization_guide/DEPS +++ b/chromium/components/optimization_guide/DEPS @@ -1,6 +1,4 @@ include_rules = [ - "+components/data_reduction_proxy/core/browser", - "+components/data_reduction_proxy/core/common", "+components/leveldb_proto", "+components/prefs", "+components/sync_preferences", diff --git a/chromium/components/optimization_guide/content/browser/BUILD.gn b/chromium/components/optimization_guide/content/browser/BUILD.gn index 0d26cc4b455..09a22e1ea29 100644 --- a/chromium/components/optimization_guide/content/browser/BUILD.gn +++ b/chromium/components/optimization_guide/content/browser/BUILD.gn @@ -54,10 +54,6 @@ static_library("browser") { "//third_party/tflite_support", "//third_party/tflite_support:tflite_support_proto", ] - - if (build_with_internal_optimization_guide) { - deps = [ "//components/optimization_guide/internal" ] - } } } @@ -83,7 +79,9 @@ source_set("unit_tests") { "page_text_dump_result_unittest.cc", "page_text_observer_unittest.cc", ] - if (build_with_tflite_lib) { + + # crbug.com/1279884 Flaky on CrOS + if (!is_chromeos && build_with_tflite_lib) { sources += [ "page_content_annotations_model_manager_unittest.cc" ] } deps = [ diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.cc b/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.cc index 0deb1b9d4aa..52339bd11a7 100644 --- a/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.cc +++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.cc @@ -4,6 +4,7 @@ #include "components/optimization_guide/content/browser/page_content_annotations_model_manager.h" +#include "base/metrics/histogram_functions.h" #include "base/metrics/histogram_macros_local.h" #include "base/strings/string_number_conversions.h" #include "base/task/sequenced_task_runner.h" @@ -17,7 +18,7 @@ #include "content/public/browser/browser_thread.h" #if BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE) -#include "components/optimization_guide/internal/page_entities_model_executor_impl.h" +#include "components/optimization_guide/core/page_entities_model_executor_impl.h" #endif namespace optimization_guide { @@ -47,34 +48,23 @@ GetOrCreateCurrentContentModelAnnotations( return std::make_unique<history::VisitContentModelAnnotations>(); } -void PretendToExecuteJob(base::OnceClosure callback, - std::unique_ptr<PageContentAnnotationJob> job) { - while (absl::optional<std::string> input = job->GetNextInput()) { - job->PostNewResult( - BatchAnnotationResult::CreatePageTopicsResult(*input, absl::nullopt)); - } - // Note to future self: The ordering of these callbacks being run will be - // important once actually being run on an executor. - job->OnComplete(); - std::move(callback).Run(); -} - } // namespace PageContentAnnotationsModelManager::PageContentAnnotationsModelManager( const std::string& application_locale, - OptimizationGuideModelProvider* optimization_guide_model_provider) { - for (auto opt_target : - features::GetPageContentModelsToExecute(application_locale)) { - if (opt_target == proto::OPTIMIZATION_TARGET_PAGE_TOPICS) { - SetUpPageTopicsModel(optimization_guide_model_provider); - ordered_models_to_execute_.push_back(opt_target); - } else if (opt_target == proto::OPTIMIZATION_TARGET_PAGE_ENTITIES) { - SetUpPageEntitiesModel(optimization_guide_model_provider); - ordered_models_to_execute_.push_back(opt_target); - } else { - // TODO(crbug/1228790): Add histogram for if this happens. - } + OptimizationGuideModelProvider* optimization_guide_model_provider) + : optimization_guide_model_provider_(optimization_guide_model_provider) { + if (features::ShouldExecutePageVisibilityModelOnPageContent( + application_locale)) { + SetUpPageTopicsModel(optimization_guide_model_provider); + ordered_models_to_execute_.push_back( + proto::OPTIMIZATION_TARGET_PAGE_TOPICS); + } + if (features::ShouldExecutePageEntitiesModelOnPageContent( + application_locale)) { + SetUpPageEntitiesModel(optimization_guide_model_provider); + ordered_models_to_execute_.push_back( + proto::OPTIMIZATION_TARGET_PAGE_ENTITIES); } } @@ -234,6 +224,9 @@ void PageContentAnnotationsModelManager::SetUpPageTopicsV2Model( if (!features::PageTopicsBatchAnnotationsEnabled()) return; + if (on_demand_page_topics_model_executor_) + return; + on_demand_page_topics_model_executor_ = std::make_unique<PageTopicsModelExecutor>( optimization_guide_model_provider, @@ -247,6 +240,9 @@ void PageContentAnnotationsModelManager::SetUpPageVisibilityModel( if (!features::PageVisibilityBatchAnnotationsEnabled()) return; + if (page_visibility_model_executor_) + return; + page_visibility_model_executor_ = std::make_unique<PageVisibilityModelExecutor>( optimization_guide_model_provider, @@ -477,24 +473,32 @@ void PageContentAnnotationsModelManager:: out_content_annotations->categories = final_categories; } -void PageContentAnnotationsModelManager::NotifyWhenModelAvailable( +void PageContentAnnotationsModelManager::RequestAndNotifyWhenModelAvailable( AnnotationType type, base::OnceCallback<void(bool)> callback) { - if (type == AnnotationType::kPageTopics && - on_demand_page_topics_model_executor_) { - on_demand_page_topics_model_executor_->AddOnModelUpdatedCallback( - base::BindOnce(std::move(callback), true)); - return; + if (type == AnnotationType::kPageTopics) { + // No-op if the executor is already setup. + SetUpPageTopicsV2Model(optimization_guide_model_provider_); + + if (on_demand_page_topics_model_executor_) { + on_demand_page_topics_model_executor_->AddOnModelUpdatedCallback( + base::BindOnce(std::move(callback), true)); + return; + } } - if (type == AnnotationType::kContentVisibility && - page_visibility_model_executor_) { - page_visibility_model_executor_->AddOnModelUpdatedCallback( - base::BindOnce(std::move(callback), true)); - return; + if (type == AnnotationType::kContentVisibility) { + // No-op if the executor is already setup. + SetUpPageVisibilityModel(optimization_guide_model_provider_); + + if (page_visibility_model_executor_) { + page_visibility_model_executor_->AddOnModelUpdatedCallback( + base::BindOnce(std::move(callback), true)); + return; + } } - // TODO(crbug/1249632): Add support for page entities. + // TODO(crbug/1278828): Add support for page entities. std::move(callback).Run(false); } @@ -506,6 +510,11 @@ PageContentAnnotationsModelManager::GetModelInfoForType( on_demand_page_topics_model_executor_) { return on_demand_page_topics_model_executor_->GetModelInfo(); } + if (type == AnnotationType::kContentVisibility && + page_visibility_model_executor_) { + return page_visibility_model_executor_->GetModelInfo(); + } + // TODO(crbug/1278828): Add support for page entities. return absl::nullopt; } @@ -513,6 +522,11 @@ void PageContentAnnotationsModelManager::Annotate( BatchAnnotationCallback callback, const std::vector<std::string>& inputs, AnnotationType annotation_type) { + base::UmaHistogramCounts100( + "OptimizationGuide.PageContentAnnotations.BatchRequestedSize." + + AnnotationTypeToString(annotation_type), + inputs.size()); + std::unique_ptr<PageContentAnnotationJob> job = std::make_unique<PageContentAnnotationJob>(std::move(callback), inputs, annotation_type); @@ -592,11 +606,15 @@ void PageContentAnnotationsModelManager::MaybeStartNextAnnotationJob() { return; } - // TODO(crbug/1249632): Actually run the model instead. - content::GetUIThreadTaskRunner({})->PostTask( - FROM_HERE, - base::BindOnce(&PretendToExecuteJob, std::move(on_job_complete_callback), - std::move(job))); + // TODO(crbug/1278828): Add support for page entities. + if (job->type() == AnnotationType::kPageEntities) { + job->FillWithNullOutputs(); + job->OnComplete(); + job.reset(); + std::move(on_job_complete_callback).Run(); + return; + } + NOTREACHED(); } } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.h b/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.h index 44ce32c15d3..dc43f852d23 100644 --- a/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.h +++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager.h @@ -5,6 +5,7 @@ #ifndef COMPONENTS_OPTIMIZATION_GUIDE_CONTENT_BROWSER_PAGE_CONTENT_ANNOTATIONS_MODEL_MANAGER_H_ #define COMPONENTS_OPTIMIZATION_GUIDE_CONTENT_BROWSER_PAGE_CONTENT_ANNOTATIONS_MODEL_MANAGER_H_ +#include "base/memory/raw_ptr.h" #include "components/history/core/browser/url_row.h" #include "components/optimization_guide/content/browser/page_content_annotator.h" #include "components/optimization_guide/core/bert_model_handler.h" @@ -49,6 +50,7 @@ class PageContentAnnotationsModelManager : public PageContentAnnotator { // This will execute all supported models of the PageContentAnnotationsService // feature and is only used by the History service code path. See the below // |Annotate| for the publicly available Annotation code path. + // TODO(crbug/1278833): Remove this. void Annotate(const std::string& text, PageContentAnnotatedCallback callback); // PageContentAnnotator: @@ -56,16 +58,16 @@ class PageContentAnnotationsModelManager : public PageContentAnnotator { const std::vector<std::string>& inputs, AnnotationType annotation_type) override; - // Runs |callback| with true when the model that powers |BatchAnnotate| for - // the given annotation type is ready to execute. If the model is ready now, - // the callback is run immediately. If the model will never become ready, due - // to feature flags for example, the callback run with false. - void NotifyWhenModelAvailable(AnnotationType type, - base::OnceCallback<void(bool)> callback); + // Requests that the given model for |type| be loaded in the background and + // then runs |callback| with true when the model is ready to execute. If the + // model is ready now, the callback is run immediately. If the model file will + // never be available, the callback is run with false. + void RequestAndNotifyWhenModelAvailable( + AnnotationType type, + base::OnceCallback<void(bool)> callback); // Returns the model info associated with the given AnnotationType, if it is // available and loaded. - // TODO(crbug/1249632): Add support for more than just page topics. absl::optional<ModelInfo> GetModelInfoForType(AnnotationType type) const; // Returns the version of the page topics model that is currently being used @@ -92,7 +94,7 @@ class PageContentAnnotationsModelManager : public PageContentAnnotator { // All publicly posted jobs will have this priority level. kNormal = 1, - // TODO(crbug/1249632): Add a kHigh value for internal jobs. + // TODO(crbug/1278833): Add a kHigh value for internal jobs. // Always keep this last and as the highest priority + 1. This value is // passed to the priority queue ctor as "how many level of priorities are @@ -215,7 +217,7 @@ class PageContentAnnotationsModelManager : public PageContentAnnotator { // Runs the next job in |job_queue_| if there is any. void MaybeStartNextAnnotationJob(); - // Called when a job finishes executing. + // Called when a |job| finishes executing, just before it is deleted. void OnJobExecutionComplete(); // The model executor responsible for executing the page topics model. @@ -254,6 +256,9 @@ class PageContentAnnotationsModelManager : public PageContentAnnotator { // The current state of the running job, if any. JobExecutionState job_state_ = JobExecutionState::kIdle; + // The model provider, not owned. + raw_ptr<OptimizationGuideModelProvider> optimization_guide_model_provider_; + base::WeakPtrFactory<PageContentAnnotationsModelManager> weak_ptr_factory_{ this}; }; diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager_unittest.cc b/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager_unittest.cc index caae58be379..10237a060e6 100644 --- a/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager_unittest.cc +++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_model_manager_unittest.cc @@ -11,6 +11,7 @@ #include "base/test/metrics/histogram_tester.h" #include "base/test/scoped_feature_list.h" #include "base/test/scoped_run_loop_timeout.h" +#include "build/build_config.h" #include "components/optimization_guide/core/execution_status.h" #include "components/optimization_guide/core/optimization_guide_features.h" #include "components/optimization_guide/core/page_entities_model_executor.h" @@ -117,9 +118,8 @@ class FakePageEntitiesModelExecutor : public PageEntitiesModelExecutor { class PageContentAnnotationsModelManagerTest : public testing::Test { public: PageContentAnnotationsModelManagerTest() { - scoped_feature_list_.InitAndEnableFeatureWithParameters( - features::kPageContentAnnotations, - {{"models_to_execute_v2", "OPTIMIZATION_TARGET_PAGE_TOPICS"}}); + scoped_feature_list_.InitAndEnableFeature( + features::kPageVisibilityPageContentAnnotations); } ~PageContentAnnotationsModelManagerTest() override = default; @@ -172,7 +172,8 @@ class PageContentAnnotationsModelManagerTest : public testing::Test { } void SetupPageTopicsV2ModelExecutor() { - model_manager()->SetUpPageTopicsV2Model(model_observer_tracker()); + model_manager()->RequestAndNotifyWhenModelAvailable( + AnnotationType::kPageTopics, base::DoNothing()); // If the feature flag is disabled, the executor won't have been created so // skip everything else. if (!model_manager()->on_demand_page_topics_model_executor_) @@ -204,7 +205,8 @@ class PageContentAnnotationsModelManagerTest : public testing::Test { void SendPageVisibilityModelToExecutor( const absl::optional<proto::Any>& model_metadata) { - model_manager()->SetUpPageVisibilityModel(model_observer_tracker()); + model_manager()->RequestAndNotifyWhenModelAvailable( + AnnotationType::kContentVisibility, base::DoNothing()); // If the feature flag is disabled, the executor won't have been created so // skip everything else. if (!model_manager()->page_visibility_model_executor_) @@ -562,7 +564,13 @@ TEST_F(PageContentAnnotationsModelManagerTest, EXPECT_FALSE(GetMetadataForEntityId("someid").has_value()); } -TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageTopics) { +// TODO(crbug.com/1286473): Flaky on Chrome OS. +#if BUILDFLAG(IS_CHROMEOS) +#define MAYBE_BatchAnnotate_PageTopics DISABLED_BatchAnnotate_PageTopics +#else +#define MAYBE_BatchAnnotate_PageTopics BatchAnnotate_PageTopics +#endif +TEST_F(PageContentAnnotationsModelManagerTest, MAYBE_BatchAnnotate_PageTopics) { SetupPageTopicsV2ModelExecutor(); // Running the actual model can take a while. @@ -591,6 +599,18 @@ TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageTopics) { "OptimizationGuide.ModelExecutor.ExecutionStatus.PageTopicsV2", ExecutionStatus::kSuccess, 1); + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.PageContentAnnotations.BatchRequestedSize.PageTopics", + 1, 1); + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.PageContentAnnotations.BatchSuccess.PageTopics", true, + 1); + histogram_tester.ExpectTotalCount( + "OptimizationGuide.PageContentAnnotations.JobExecutionTime.PageTopics", + 1); + histogram_tester.ExpectTotalCount( + "OptimizationGuide.PageContentAnnotations.JobScheduleTime.PageTopics", 1); + EXPECT_TRUE(model_observer_tracker()->DidRegisterForTarget( proto::OptimizationTarget::OPTIMIZATION_TARGET_PAGE_TOPICS_V2, nullptr)); @@ -628,6 +648,17 @@ TEST_F(PageContentAnnotationsModelManagerTest, base::RunLoop().RunUntilIdle(); histogram_tester.ExpectTotalCount( "OptimizationGuide.ModelExecutor.ExecutionStatus.PageTopicsV2", 0); + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.PageContentAnnotations.BatchRequestedSize.PageTopics", + 1, 1); + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.PageContentAnnotations.BatchSuccess.PageTopics", false, + 1); + histogram_tester.ExpectTotalCount( + "OptimizationGuide.PageContentAnnotations.JobExecutionTime.PageTopics", + 1); + histogram_tester.ExpectTotalCount( + "OptimizationGuide.PageContentAnnotations.JobScheduleTime.PageTopics", 1); EXPECT_FALSE(model_observer_tracker()->DidRegisterForTarget( proto::OptimizationTarget::OPTIMIZATION_TARGET_PAGE_TOPICS_V2, nullptr)); @@ -641,6 +672,7 @@ TEST_F(PageContentAnnotationsModelManagerTest, } TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageEntities) { + base::HistogramTester histogram_tester; base::RunLoop run_loop; std::vector<BatchAnnotationResult> result; BatchAnnotationCallback callback = base::BindOnce( @@ -654,10 +686,22 @@ TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageEntities) { model_manager()->Annotate(std::move(callback), {"input"}, AnnotationType::kPageEntities); + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.PageContentAnnotations.BatchRequestedSize." + "PageEntities", + 1, 1); + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.PageContentAnnotations.BatchSuccess.PageEntities", + false, 1); + histogram_tester.ExpectTotalCount( + "OptimizationGuide.PageContentAnnotations.JobExecutionTime.PageEntities", + 1); + histogram_tester.ExpectTotalCount( + "OptimizationGuide.PageContentAnnotations.JobScheduleTime.PageEntities", + 1); + run_loop.Run(); - // TODO(crbug/1249632): Check the corresponding output once the model is being - // run. ASSERT_EQ(result.size(), 1U); EXPECT_EQ(result[0].input(), "input"); EXPECT_EQ(result[0].topics(), absl::nullopt); @@ -665,7 +709,15 @@ TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageEntities) { EXPECT_EQ(result[0].visibility_score(), absl::nullopt); } -TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageVisibility) { +// TODO(crbug.com/1286473): Flaky on Chrome OS. +#if BUILDFLAG(IS_CHROMEOS) +#define MAYBE_BatchAnnotate_PageVisibility DISABLED_BatchAnnotate_PageVisibility +#else +#define MAYBE_BatchAnnotate_PageVisibility BatchAnnotate_PageVisibility +#endif +TEST_F(PageContentAnnotationsModelManagerTest, + MAYBE_BatchAnnotate_PageVisibility) { + base::HistogramTester histogram_tester; proto::Any any_metadata; any_metadata.set_type_url( "type.googleapis.com/com.foo.PageTopicsModelMetadata"); @@ -697,6 +749,21 @@ TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageVisibility) { EXPECT_TRUE(model_observer_tracker()->DidRegisterForTarget( proto::OptimizationTarget::OPTIMIZATION_TARGET_PAGE_VISIBILITY, nullptr)); + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.PageContentAnnotations.BatchRequestedSize." + "ContentVisibility", + 1, 1); + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.PageContentAnnotations.BatchSuccess.ContentVisibility", + true, 1); + histogram_tester.ExpectTotalCount( + "OptimizationGuide.PageContentAnnotations.JobExecutionTime." + "ContentVisibility", + 1); + histogram_tester.ExpectTotalCount( + "OptimizationGuide.PageContentAnnotations.JobScheduleTime." + "ContentVisibility", + 1); ASSERT_EQ(result.size(), 1U); EXPECT_EQ(result[0].input(), "input"); @@ -707,6 +774,7 @@ TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageVisibility) { TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_PageVisibilityDisabled) { + base::HistogramTester histogram_tester; base::test::ScopedFeatureList scoped_feature_list; scoped_feature_list.InitAndDisableFeature( features::kPageVisibilityBatchAnnotations); @@ -739,6 +807,21 @@ TEST_F(PageContentAnnotationsModelManagerTest, EXPECT_FALSE(model_observer_tracker()->DidRegisterForTarget( proto::OptimizationTarget::OPTIMIZATION_TARGET_PAGE_VISIBILITY, nullptr)); + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.PageContentAnnotations.BatchRequestedSize." + "ContentVisibility", + 1, 1); + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.PageContentAnnotations.BatchSuccess.ContentVisibility", + false, 1); + histogram_tester.ExpectTotalCount( + "OptimizationGuide.PageContentAnnotations.JobExecutionTime." + "ContentVisibility", + 1); + histogram_tester.ExpectTotalCount( + "OptimizationGuide.PageContentAnnotations.JobScheduleTime." + "ContentVisibility", + 1); ASSERT_EQ(result.size(), 1U); EXPECT_EQ(result[0].input(), "input"); @@ -747,7 +830,14 @@ TEST_F(PageContentAnnotationsModelManagerTest, EXPECT_EQ(result[0].visibility_score(), absl::nullopt); } -TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_CalledTwice) { +// TODO(crbug.com/1286473): Flaky on Chrome OS. +#if BUILDFLAG(IS_CHROMEOS) +#define MAYBE_BatchAnnotate_CalledTwice DISABLED_BatchAnnotate_CalledTwice +#else +#define MAYBE_BatchAnnotate_CalledTwice BatchAnnotate_CalledTwice +#endif +TEST_F(PageContentAnnotationsModelManagerTest, + MAYBE_BatchAnnotate_CalledTwice) { SetupPageTopicsV2ModelExecutor(); base::HistogramTester histogram_tester; @@ -790,6 +880,18 @@ TEST_F(PageContentAnnotationsModelManagerTest, BatchAnnotate_CalledTwice) { EXPECT_TRUE(model_observer_tracker()->DidRegisterForTarget( proto::OptimizationTarget::OPTIMIZATION_TARGET_PAGE_TOPICS_V2, nullptr)); + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.PageContentAnnotations.BatchRequestedSize.PageTopics", + 1, 2); + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.PageContentAnnotations.BatchSuccess.PageTopics", true, + 2); + histogram_tester.ExpectTotalCount( + "OptimizationGuide.PageContentAnnotations.JobExecutionTime.PageTopics", + 2); + histogram_tester.ExpectTotalCount( + "OptimizationGuide.PageContentAnnotations.JobScheduleTime.PageTopics", 2); + // The model should have only been loaded once and then used for both jobs. histogram_tester.ExpectUniqueSample( "OptimizationGuide.ModelExecutor.ModelAvailableToLoad.PageTopicsV2", true, @@ -816,6 +918,8 @@ TEST_F(PageContentAnnotationsModelManagerTest, GetModelInfoForType) { model_manager()->GetModelInfoForType(AnnotationType::kContentVisibility)); SetupPageTopicsV2ModelExecutor(); + EXPECT_TRUE( + model_manager()->GetModelInfoForType(AnnotationType::kPageTopics)); proto::Any any_metadata; any_metadata.set_type_url( @@ -829,59 +933,61 @@ TEST_F(PageContentAnnotationsModelManagerTest, GetModelInfoForType) { SendPageVisibilityModelToExecutor(any_metadata); EXPECT_TRUE( - model_manager()->GetModelInfoForType(AnnotationType::kPageTopics)); - EXPECT_FALSE( model_manager()->GetModelInfoForType(AnnotationType::kContentVisibility)); } TEST_F(PageContentAnnotationsModelManagerTest, - NotifyWhenModelAvailable_NotAvailable) { - absl::optional<bool> topics_callback_success; - absl::optional<bool> visibility_callback_success; + NotifyWhenModelAvailable_TopicsOnly) { + SetupPageTopicsV2ModelExecutor(); - model_manager()->NotifyWhenModelAvailable( + base::RunLoop topics_run_loop; + bool topics_callback_success = false; + + model_manager()->RequestAndNotifyWhenModelAvailable( AnnotationType::kPageTopics, - base::BindOnce([](absl::optional<bool>* out_success, - bool success) { *out_success = success; }, - &topics_callback_success)); - model_manager()->NotifyWhenModelAvailable( - AnnotationType::kContentVisibility, - base::BindOnce([](absl::optional<bool>* out_success, - bool success) { *out_success = success; }, - &visibility_callback_success)); - - ASSERT_TRUE(topics_callback_success); - ASSERT_TRUE(visibility_callback_success); - EXPECT_FALSE(*topics_callback_success); - EXPECT_FALSE(*visibility_callback_success); + base::BindOnce( + [](base::RunLoop* run_loop, bool* out_success, bool success) { + *out_success = success; + run_loop->Quit(); + }, + &topics_run_loop, &topics_callback_success)); + + topics_run_loop.Run(); + + EXPECT_TRUE(topics_callback_success); } TEST_F(PageContentAnnotationsModelManagerTest, - NotifyWhenModelAvailable_TopicsOnly) { - SetupPageTopicsV2ModelExecutor(); + NotifyWhenModelAvailable_VisibilityOnly) { + proto::Any any_metadata; + any_metadata.set_type_url( + "type.googleapis.com/com.foo.PageTopicsModelMetadata"); + proto::PageTopicsModelMetadata page_topics_model_metadata; + page_topics_model_metadata.set_version(123); + page_topics_model_metadata.mutable_output_postprocessing_params() + ->mutable_visibility_params() + ->set_category_name("DO NOT EVALUATE"); + page_topics_model_metadata.SerializeToString(any_metadata.mutable_value()); + SendPageVisibilityModelToExecutor(any_metadata); - absl::optional<bool> topics_callback_success; - absl::optional<bool> visibility_callback_success; + base::RunLoop visibility_run_loop; + bool visibility_callback_success = false; - model_manager()->NotifyWhenModelAvailable( - AnnotationType::kPageTopics, - base::BindOnce([](absl::optional<bool>* out_success, - bool success) { *out_success = success; }, - &topics_callback_success)); - model_manager()->NotifyWhenModelAvailable( + model_manager()->RequestAndNotifyWhenModelAvailable( AnnotationType::kContentVisibility, - base::BindOnce([](absl::optional<bool>* out_success, - bool success) { *out_success = success; }, - &visibility_callback_success)); - - ASSERT_TRUE(topics_callback_success); - ASSERT_TRUE(visibility_callback_success); - EXPECT_TRUE(*topics_callback_success); - EXPECT_FALSE(*visibility_callback_success); + base::BindOnce( + [](base::RunLoop* run_loop, bool* out_success, bool success) { + *out_success = success; + run_loop->Quit(); + }, + &visibility_run_loop, &visibility_callback_success)); + + visibility_run_loop.Run(); + + EXPECT_TRUE(visibility_callback_success); } -TEST_F(PageContentAnnotationsModelManagerTest, - NotifyWhenModelAvailable_VisibilityOnly) { +TEST_F(PageContentAnnotationsModelManagerTest, NotifyWhenModelAvailable_Both) { proto::Any any_metadata; any_metadata.set_type_url( "type.googleapis.com/com.foo.PageTopicsModelMetadata"); @@ -893,33 +999,44 @@ TEST_F(PageContentAnnotationsModelManagerTest, page_topics_model_metadata.SerializeToString(any_metadata.mutable_value()); SendPageVisibilityModelToExecutor(any_metadata); - absl::optional<bool> topics_callback_success; - absl::optional<bool> visibility_callback_success; + SetupPageTopicsV2ModelExecutor(); + + base::RunLoop topics_run_loop; + base::RunLoop visibility_run_loop; + bool topics_callback_success = false; + bool visibility_callback_success = false; - model_manager()->NotifyWhenModelAvailable( + model_manager()->RequestAndNotifyWhenModelAvailable( AnnotationType::kPageTopics, - base::BindOnce([](absl::optional<bool>* out_success, - bool success) { *out_success = success; }, - &topics_callback_success)); - model_manager()->NotifyWhenModelAvailable( + base::BindOnce( + [](base::RunLoop* run_loop, bool* out_success, bool success) { + *out_success = success; + run_loop->Quit(); + }, + &topics_run_loop, &topics_callback_success)); + model_manager()->RequestAndNotifyWhenModelAvailable( AnnotationType::kContentVisibility, - base::BindOnce([](absl::optional<bool>* out_success, - bool success) { *out_success = success; }, - &visibility_callback_success)); - - ASSERT_TRUE(topics_callback_success); - ASSERT_TRUE(visibility_callback_success); - EXPECT_FALSE(*topics_callback_success); - EXPECT_TRUE(*visibility_callback_success); + base::BindOnce( + [](base::RunLoop* run_loop, bool* out_success, bool success) { + *out_success = success; + run_loop->Quit(); + }, + &visibility_run_loop, &visibility_callback_success)); + + topics_run_loop.Run(); + visibility_run_loop.Run(); + + EXPECT_TRUE(topics_callback_success); + EXPECT_TRUE(visibility_callback_success); } class PageContentAnnotationsModelManagerEntitiesOnlyTest : public PageContentAnnotationsModelManagerTest { public: PageContentAnnotationsModelManagerEntitiesOnlyTest() { - scoped_feature_list_.InitAndEnableFeatureWithParameters( - features::kPageContentAnnotations, - {{"models_to_execute_v2", "OPTIMIZATION_TARGET_PAGE_ENTITIES"}}); + scoped_feature_list_.InitWithFeatures( + {features::kPageEntitiesPageContentAnnotations}, + {features::kPageVisibilityPageContentAnnotations}); } private: @@ -999,11 +1116,10 @@ class PageContentAnnotationsModelManagerMultipleModelsTest : public PageContentAnnotationsModelManagerTest { public: PageContentAnnotationsModelManagerMultipleModelsTest() { - scoped_feature_list_.InitAndEnableFeatureWithParameters( - features::kPageContentAnnotations, - {{"models_to_execute_v2", - "OPTIMIZATION_TARGET_PAGE_ENTITIES,OPTIMIZATION_TARGET_PAGE_" - "TOPICS"}}); + scoped_feature_list_.InitWithFeatures( + {features::kPageEntitiesPageContentAnnotations, + features::kPageVisibilityPageContentAnnotations}, + {}); } private: diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_service.cc b/chromium/components/optimization_guide/content/browser/page_content_annotations_service.cc index 8d1421253f2..a88b7181f85 100644 --- a/chromium/components/optimization_guide/content/browser/page_content_annotations_service.cc +++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_service.cc @@ -4,14 +4,22 @@ #include "components/optimization_guide/content/browser/page_content_annotations_service.h" +#include "base/callback_helpers.h" #include "base/metrics/histogram_functions.h" +#include "base/metrics/histogram_macros_local.h" +#include "base/rand_util.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" +#include "base/time/default_tick_clock.h" +#include "base/timer/timer.h" #include "components/history/core/browser/history_service.h" +#include "components/leveldb_proto/public/proto_database_provider.h" +#include "components/optimization_guide/core/local_page_entities_metadata_provider.h" #include "components/optimization_guide/core/noisy_metrics_recorder.h" #include "components/optimization_guide/core/optimization_guide_enums.h" #include "components/optimization_guide/core/optimization_guide_features.h" #include "components/optimization_guide/core/optimization_guide_model_provider.h" +#include "components/optimization_guide/core/optimization_guide_switches.h" #include "content/public/browser/navigation_entry.h" #include "content/public/browser/web_contents.h" #include "services/metrics/public/cpp/metrics_utils.h" @@ -71,14 +79,27 @@ void MaybeRecordVisibilityUKM( } #endif /* BUILDFLAG(BUILD_WITH_TFLITE_LIB) */ +const char kDummyTextBlob[] = + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod " + "tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim " + "veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea " + "commodo consequat. Duis aute irure dolor in reprehenderit in voluptate " + "velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint " + "occaecat cupidatat non proident, sunt in culpa qui officia deserunt " + "mollit anim id est laborum"; + } // namespace PageContentAnnotationsService::PageContentAnnotationsService( const std::string& application_locale, OptimizationGuideModelProvider* optimization_guide_model_provider, - history::HistoryService* history_service) + history::HistoryService* history_service, + leveldb_proto::ProtoDatabaseProvider* database_provider, + const base::FilePath& database_dir, + scoped_refptr<base::SequencedTaskRunner> background_task_runner) : last_annotated_history_visits_( - features::MaxContentAnnotationRequestsCached()) { + features::MaxContentAnnotationRequestsCached()), + annotated_text_cache_(features::MaxVisitAnnotationCacheSize()) { DCHECK(optimization_guide_model_provider); DCHECK(history_service); history_service_ = history_service; @@ -87,12 +108,28 @@ PageContentAnnotationsService::PageContentAnnotationsService( application_locale, optimization_guide_model_provider); annotator_ = model_manager_.get(); #endif + + if (features::UseLocalPageEntitiesMetadataProvider()) { + local_page_entities_metadata_provider_ = + std::make_unique<LocalPageEntitiesMetadataProvider>(); + local_page_entities_metadata_provider_->Initialize( + database_provider, database_dir, background_task_runner); + } + + if (features::BatchAnnotationsValidationEnabled()) { + validation_timer_ = std::make_unique<base::OneShotTimer>( + base::DefaultTickClock::GetInstance()); + validation_timer_->Start( + FROM_HERE, features::BatchAnnotationValidationStartupDelay(), + base::BindRepeating( + &PageContentAnnotationsService::RunBatchAnnotationValidation, + weak_ptr_factory_.GetWeakPtr())); + } } PageContentAnnotationsService::~PageContentAnnotationsService() = default; -void PageContentAnnotationsService::Annotate(const HistoryVisit& visit, - const std::string& text) { +void PageContentAnnotationsService::Annotate(const HistoryVisit& visit) { if (last_annotated_history_visits_.Peek(visit) != last_annotated_history_visits_.end()) { // We have already been requested to annotate this visit, so don't submit @@ -102,13 +139,88 @@ void PageContentAnnotationsService::Annotate(const HistoryVisit& visit, last_annotated_history_visits_.Put(visit, true); #if BUILDFLAG(BUILD_WITH_TFLITE_LIB) - model_manager_->Annotate( - text, - base::BindOnce(&PageContentAnnotationsService::OnPageContentAnnotated, - weak_ptr_factory_.GetWeakPtr(), visit)); + if (!visit.text_to_annotate) + return; + // Used for testing. + LOCAL_HISTOGRAM_BOOLEAN( + "PageContentAnnotations.AnnotateVisit.AnnotationRequested", true); + + auto it = annotated_text_cache_.Peek(*visit.text_to_annotate); + if (it != annotated_text_cache_.end()) { + // We have annotations the text for this visit, so return that immediately + // rather than re-executing the model. + // + // TODO(crbug.com/1291275): If the model was updated, the cached value could + // be stale so we should invalidate the cache on model updates. + OnPageContentAnnotated(visit, it->second); + base::UmaHistogramBoolean( + "OptimizationGuide.PageContentAnnotations.AnnotateVisitResultCached", + true); + return; + } + visits_to_annotate_.emplace_back(visit); + base::UmaHistogramBoolean( + "OptimizationGuide.PageContentAnnotations.AnnotateVisitResultCached", + false); + if (visits_to_annotate_.size() >= features::AnnotateVisitBatchSize()) { + if (current_visit_annotation_batch_.empty()) { + // Used for testing. + LOCAL_HISTOGRAM_BOOLEAN( + "PageContentAnnotations.AnnotateVisit.BatchAnnotationStarted", true); + current_visit_annotation_batch_ = std::move(visits_to_annotate_); + AnnotateVisitBatch(); + return; + } + // The queue is full and an batch annotation is actively being done so + // we will remove the "oldest" visit. + visits_to_annotate_.erase(visits_to_annotate_.begin()); + // Used for testing. + LOCAL_HISTOGRAM_BOOLEAN( + "PageContentAnnotations.AnnotateVisit.QueueFullVisitDropped", true); + } + // Used for testing. + LOCAL_HISTOGRAM_BOOLEAN( + "PageContentAnnotations.AnnotateVisit.AnnotationRequestQueued", true); #endif } +#if BUILDFLAG(BUILD_WITH_TFLITE_LIB) +void PageContentAnnotationsService::AnnotateVisitBatch() { + DCHECK(!current_visit_annotation_batch_.empty()); + + if (switches::StopHistoryVisitBatchAnnotateForTesting()) { + // Code beyond this is tested in multiple places. This just ensures the + // calls up to this point can be more easily configured. + return; + } + + if (current_visit_annotation_batch_.empty()) { + return; + } + auto visit = current_visit_annotation_batch_.back(); + DCHECK(visit.text_to_annotate); + if (visit.text_to_annotate) { + model_manager_->Annotate( + *(visit.text_to_annotate), + base::BindOnce(&PageContentAnnotationsService::OnBatchVisitAnnotated, + weak_ptr_factory_.GetWeakPtr(), visit)); + } +} + +void PageContentAnnotationsService::OnBatchVisitAnnotated( + const HistoryVisit& visit, + const absl::optional<history::VisitContentModelAnnotations>& + content_annotations) { + OnPageContentAnnotated(visit, content_annotations); + DCHECK_EQ(visit.navigation_id, + current_visit_annotation_batch_.back().navigation_id); + current_visit_annotation_batch_.pop_back(); + if (!current_visit_annotation_batch_.empty()) { + AnnotateVisitBatch(); + } +} +#endif + void PageContentAnnotationsService::OverridePageContentAnnotatorForTesting( PageContentAnnotator* annotator) { annotator_ = annotator; @@ -135,17 +247,27 @@ absl::optional<ModelInfo> PageContentAnnotationsService::GetModelInfoForType( #endif } -void PageContentAnnotationsService::NotifyWhenModelAvailable( +void PageContentAnnotationsService::RequestAndNotifyWhenModelAvailable( AnnotationType type, base::OnceCallback<void(bool)> callback) { #if BUILDFLAG(BUILD_WITH_TFLITE_LIB) DCHECK(model_manager_); - model_manager_->NotifyWhenModelAvailable(type, std::move(callback)); + model_manager_->RequestAndNotifyWhenModelAvailable(type, std::move(callback)); #else std::move(callback).Run(false); #endif } +void PageContentAnnotationsService::PersistSearchMetadata( + const HistoryVisit& visit, + const SearchMetadata& search_metadata) { + QueryURL(visit, + base::BindOnce(&history::HistoryService::AddSearchMetadataForVisit, + history_service_->AsWeakPtr(), + search_metadata.normalized_url, + search_metadata.search_terms)); +} + void PageContentAnnotationsService::ExtractRelatedSearches( const HistoryVisit& visit, content::WebContents* web_contents) { @@ -166,6 +288,10 @@ void PageContentAnnotationsService::OnPageContentAnnotated( if (!content_annotations) return; + if (annotated_text_cache_.Peek(*visit.text_to_annotate) == + annotated_text_cache_.end()) { + annotated_text_cache_.Put(*visit.text_to_annotate, *content_annotations); + } MaybeRecordVisibilityUKM(visit, content_annotations); if (!features::ShouldWriteContentAnnotationsToHistoryService()) @@ -258,6 +384,13 @@ void PageContentAnnotationsService::OnURLQueried( void PageContentAnnotationsService::GetMetadataForEntityId( const std::string& entity_id, EntityMetadataRetrievedCallback callback) { + if (features::UseLocalPageEntitiesMetadataProvider()) { + DCHECK(local_page_entities_metadata_provider_); + local_page_entities_metadata_provider_->GetMetadataForEntityId( + entity_id, std::move(callback)); + return; + } + #if BUILDFLAG(BUILD_WITH_TFLITE_LIB) model_manager_->GetMetadataForEntityId(entity_id, std::move(callback)); #else @@ -277,14 +410,51 @@ void PageContentAnnotationsService::PersistRemotePageEntities( history_service_->AsWeakPtr(), annotations)); } +void PageContentAnnotationsService::RunBatchAnnotationValidation() { + DCHECK(features::BatchAnnotationsValidationEnabled()); + DCHECK(validation_timer_); + validation_timer_.reset(); + + std::vector<std::string> dummy_inputs; + dummy_inputs.reserve(features::BatchAnnotationsValidationBatchSize()); + for (size_t i = 0; i < features::BatchAnnotationsValidationBatchSize(); i++) { + // Pick a random substring of the dummy blob so that we can't do any caching + // or deduping. + size_t half_length = std::strlen(kDummyTextBlob) / 2; + size_t rand_start = base::RandInt(0, half_length - 1); + dummy_inputs.emplace_back( + std::string(kDummyTextBlob + rand_start, half_length)); + } + + LOCAL_HISTOGRAM_COUNTS_100( + "OptimizationGuide.PageContentAnnotationsService.ValidationRun", + dummy_inputs.size()); + + BatchAnnotate(base::DoNothing(), dummy_inputs, + AnnotationType::kContentVisibility); +} + // static HistoryVisit PageContentAnnotationsService::CreateHistoryVisitFromWebContents( content::WebContents* web_contents, int64_t navigation_id) { - HistoryVisit visit = { + HistoryVisit visit( web_contents->GetController().GetLastCommittedEntry()->GetTimestamp(), - web_contents->GetLastCommittedURL(), navigation_id}; + web_contents->GetLastCommittedURL(), navigation_id); return visit; } +HistoryVisit::HistoryVisit() = default; + +HistoryVisit::HistoryVisit(base::Time nav_entry_timestamp, + GURL url, + int64_t navigation_id) { + this->nav_entry_timestamp = nav_entry_timestamp; + this->url = url; + this->navigation_id = navigation_id; +} + +HistoryVisit::~HistoryVisit() = default; +HistoryVisit::HistoryVisit(const HistoryVisit&) = default; + } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_service.h b/chromium/components/optimization_guide/content/browser/page_content_annotations_service.h index 4e361811f85..7931d37b30e 100644 --- a/chromium/components/optimization_guide/content/browser/page_content_annotations_service.h +++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_service.h @@ -9,12 +9,14 @@ #include "base/callback_forward.h" #include "base/containers/lru_cache.h" +#include "base/files/file_path.h" #include "base/hash/hash.h" #include "base/memory/raw_ptr.h" #include "base/memory/weak_ptr.h" #include "base/strings/strcat.h" #include "base/strings/string_number_conversions.h" #include "base/task/cancelable_task_tracker.h" +#include "base/task/sequenced_task_runner.h" #include "components/continuous_search/browser/search_result_extractor_client.h" #include "components/continuous_search/browser/search_result_extractor_client_status.h" #include "components/continuous_search/common/public/mojom/continuous_search.mojom.h" @@ -26,8 +28,13 @@ #include "components/optimization_guide/core/model_info.h" #include "components/optimization_guide/core/page_content_annotations_common.h" #include "components/optimization_guide/machine_learning_tflite_buildflags.h" +#include "third_party/abseil-cpp/absl/types/optional.h" #include "url/gurl.h" +namespace base { +class OneShotTimer; +} // namespace base + namespace content { class WebContents; } // namespace content @@ -36,8 +43,13 @@ namespace history { class HistoryService; } // namespace history +namespace leveldb_proto { +class ProtoDatabaseProvider; +} // namespace leveldb_proto + namespace optimization_guide { +class LocalPageEntitiesMetadataProvider; class OptimizationGuideModelProvider; class PageContentAnnotationsModelManager; class PageContentAnnotationsServiceBrowserTest; @@ -45,9 +57,15 @@ class PageContentAnnotationsWebContentsObserver; // The information used by HistoryService to identify a visit to a URL. struct HistoryVisit { + HistoryVisit(); + HistoryVisit(base::Time nav_entry_timestamp, GURL url, int64_t navigation_id); + ~HistoryVisit(); + HistoryVisit(const HistoryVisit&); + base::Time nav_entry_timestamp; GURL url; - int64_t navigation_id; + int64_t navigation_id = 0; + absl::optional<std::string> text_to_annotate; struct Comp { bool operator()(const HistoryVisit& lhs, const HistoryVisit& rhs) const { @@ -58,6 +76,12 @@ struct HistoryVisit { }; }; +// The information about a search visit to store in HistoryService. +struct SearchMetadata { + GURL normalized_url; + std::u16string search_terms; +}; + // A KeyedService that annotates page content. class PageContentAnnotationsService : public KeyedService, public EntityMetadataProvider { @@ -65,38 +89,44 @@ class PageContentAnnotationsService : public KeyedService, PageContentAnnotationsService( const std::string& application_locale, OptimizationGuideModelProvider* optimization_guide_model_provider, - history::HistoryService* history_service); + history::HistoryService* history_service, + leveldb_proto::ProtoDatabaseProvider* database_provider, + const base::FilePath& database_dir, + scoped_refptr<base::SequencedTaskRunner> background_task_runner); ~PageContentAnnotationsService() override; PageContentAnnotationsService(const PageContentAnnotationsService&) = delete; PageContentAnnotationsService& operator=( const PageContentAnnotationsService&) = delete; // This is the main entry point for page content annotations by external - // callers. + // callers. Callers must call |RequestAndNotifyWhenModelAvailable| as close to + // session start as possible to allow time for the model file to be + // downloaded. void BatchAnnotate(BatchAnnotationCallback callback, const std::vector<std::string>& inputs, AnnotationType annotation_type); - // Overrides the PageContentAnnotator for testing. See - // test_page_content_annotator.h for an implementation designed for testing. - void OverridePageContentAnnotatorForTesting(PageContentAnnotator* annotator); + // Requests that the given model for |type| be loaded in the background and + // then runs |callback| with true when the model is ready to execute. If the + // model is ready now, the callback is run immediately. If the model file will + // never be available, the callback is run with false. + void RequestAndNotifyWhenModelAvailable( + AnnotationType type, + base::OnceCallback<void(bool)> callback); // Returns the model info for the given annotation type, if the model file is // available. absl::optional<ModelInfo> GetModelInfoForType(AnnotationType type) const; - // Runs |callback| with true when the model that powers |BatchAnnotate| for - // the given annotation type is ready to execute. If the model is ready now, - // the callback is run immediately. If the model file will never be available, - // the callback is run with false. - void NotifyWhenModelAvailable(AnnotationType type, - base::OnceCallback<void(bool)> callback); - // EntityMetadataProvider: void GetMetadataForEntityId( const std::string& entity_id, EntityMetadataRetrievedCallback callback) override; + // Overrides the PageContentAnnotator for testing. See + // test_page_content_annotator.h for an implementation designed for testing. + void OverridePageContentAnnotatorForTesting(PageContentAnnotator* annotator); + private: #if BUILDFLAG(BUILD_WITH_TFLITE_LIB) // Callback invoked when |visit| has been annotated. @@ -105,7 +135,20 @@ class PageContentAnnotationsService : public KeyedService, const absl::optional<history::VisitContentModelAnnotations>& content_annotations); + // Runs the page annotation models available to |model_manager_| on all the + // visits within |current_visit_annotation_batch_|. + void AnnotateVisitBatch(); + + // Callback run after the annotations for a |visit| of a batch has been + // determined. |current_visit_annotation_batch_| is updated to remove + // the annotated visit and will trigger the next visit to be annotated. + void OnBatchVisitAnnotated( + const HistoryVisit& visit, + const absl::optional<history::VisitContentModelAnnotations>& + content_annotations); + std::unique_ptr<PageContentAnnotationsModelManager> model_manager_; + #endif // The annotator to use for requests to |BatchAnnotate|. In prod, this is @@ -123,13 +166,17 @@ class PageContentAnnotationsService : public KeyedService, friend class PageContentAnnotationsWebContentsObserver; friend class PageContentAnnotationsServiceBrowserTest; // Virtualized for testing. - virtual void Annotate(const HistoryVisit& visit, const std::string& text); + virtual void Annotate(const HistoryVisit& visit); // Creates a HistoryVisit based on the current state of |web_contents|. static HistoryVisit CreateHistoryVisitFromWebContents( content::WebContents* web_contents, int64_t navigation_id); + // Persist |search_metadata| for |visit| in |history_service_|. + virtual void PersistSearchMetadata(const HistoryVisit& visit, + const SearchMetadata& search_metadata); + // Requests |search_result_extractor_client_| to extract related searches from // the Google SRP DOM associated with |web_contents|. // @@ -165,6 +212,17 @@ class PageContentAnnotationsService : public KeyedService, PersistAnnotationsCallback callback, history::QueryURLResult url_result); + // Runs a batch annotation validation, that is calls |BatchAnnotate| with + // dummy input and discards the output. + void RunBatchAnnotationValidation(); + + // A metadata-only provider for page entities (as opposed to |model_manager_| + // which does both entity model execution and metadata providing) that uses a + // local database to provide the metadata for a given entity id. This is only + // non-null and initialized when its feature flag is enabled. + std::unique_ptr<LocalPageEntitiesMetadataProvider> + local_page_entities_metadata_provider_; + // The history service to write content annotations to. Not owned. Guaranteed // to outlive |this|. raw_ptr<history::HistoryService> history_service_; @@ -180,6 +238,23 @@ class PageContentAnnotationsService : public KeyedService, base::LRUCache<HistoryVisit, bool, HistoryVisit::Comp> last_annotated_history_visits_; + // A LRU cache of the annotation results for visits. If the text of the visit + // is in the cache, the cached model annotations will be used. + base::HashingLRUCache<std::string, history::VisitContentModelAnnotations> + annotated_text_cache_; + + // The set of visits to be annotated, this is added to by Annotate requests + // from the web content observer. These will be annotated when the set is full + // and annotations can be scheduled with minimal impact to browsing. + std::vector<HistoryVisit> visits_to_annotate_; + + // The batch of visits being annotated. If this is empty, it is assumed that + // no visits are actively be annotated and a new batch can be started. + std::vector<HistoryVisit> current_visit_annotation_batch_; + + // Is only ever set when the feature is enabled. + std::unique_ptr<base::OneShotTimer> validation_timer_; + base::WeakPtrFactory<PageContentAnnotationsService> weak_ptr_factory_{this}; }; diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.cc b/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.cc index 8eea9f2744e..d91df2d5163 100644 --- a/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.cc +++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.cc @@ -11,6 +11,7 @@ #include "components/optimization_guide/content/browser/optimization_guide_decider.h" #include "components/optimization_guide/content/browser/page_content_annotations_service.h" #include "components/optimization_guide/core/optimization_guide_features.h" +#include "components/optimization_guide/core/optimization_guide_switches.h" #include "components/optimization_guide/proto/page_entities_metadata.pb.h" #include "components/search_engines/template_url_service.h" #include "content/public/browser/navigation_entry.h" @@ -21,24 +22,38 @@ namespace optimization_guide { namespace { -// Returns the search query if |url| is a valid Search URL according to +// Returns search metadata if |url| is a valid Search URL according to // |template_url_service|. -absl::optional<std::u16string> ExtractSearchTerms( +absl::optional<SearchMetadata> ExtractSearchMetadata( const TemplateURLService* template_url_service, const GURL& url) { - const TemplateURL* default_search_provider = - template_url_service->GetDefaultSearchProvider(); + if (!template_url_service) + return absl::nullopt; + + const TemplateURL* template_url = + template_url_service->GetTemplateURLForHost(url.host()); const SearchTermsData& search_terms_data = template_url_service->search_terms_data(); std::u16string search_terms; - if (default_search_provider && - default_search_provider->ExtractSearchTermsFromURL(url, search_terms_data, - &search_terms) && - !search_terms.empty()) { - return search_terms; - } - return absl::nullopt; + bool is_valid_search_url = template_url && + template_url->ExtractSearchTermsFromURL( + url, search_terms_data, &search_terms) && + !search_terms.empty(); + if (!is_valid_search_url) + return absl::nullopt; + + const std::u16string& normalized_search_query = + base::i18n::ToLower(base::CollapseWhitespace(search_terms, false)); + TemplateURLRef::SearchTermsArgs search_terms_args(normalized_search_query); + const TemplateURLRef& search_url_ref = template_url->url_ref(); + if (!search_url_ref.SupportsReplacement(search_terms_data)) + return absl::nullopt; + + return SearchMetadata{ + GURL(search_url_ref.ReplaceSearchTerms(search_terms_args, + search_terms_data)), + base::i18n::ToLower(base::CollapseWhitespace(search_terms, false))}; } // Data scoped to a single page. PageData has the same lifetime as the page's @@ -147,16 +162,24 @@ void PageContentAnnotationsWebContentsObserver::DidFinishNavigation( web_contents()); } - absl::optional<std::u16string> search_terms = - ExtractSearchTerms(template_url_service_, navigation_handle->GetURL()); - if (search_terms) { + absl::optional<SearchMetadata> search_metadata = ExtractSearchMetadata( + template_url_service_, navigation_handle->GetURL()); + if (search_metadata) { if (page_data) { page_data->set_annotation_was_requested(); } - const std::u16string& normalized_search_query = - base::i18n::ToLower(base::CollapseWhitespace(*search_terms, false)); - page_content_annotations_service_->Annotate( - history_visit, base::UTF16ToUTF8(normalized_search_query)); + history_visit.text_to_annotate = + base::UTF16ToUTF8(search_metadata->search_terms); + page_content_annotations_service_->Annotate(history_visit); + page_content_annotations_service_->PersistSearchMetadata( + history_visit, *search_metadata); + + if (switches::ShouldLogPageContentAnnotationsInput()) { + LOG(ERROR) << "Annotating search terms: \n" + << "URL: " << navigation_handle->GetURL() << "\n" + << "Text: " << *(history_visit.text_to_annotate); + } + return; } } @@ -168,8 +191,15 @@ void PageContentAnnotationsWebContentsObserver::DidFinishNavigation( page_data->set_annotation_was_requested(); } // Annotate the title instead. - page_content_annotations_service_->Annotate( - history_visit, base::UTF16ToUTF8(web_contents()->GetTitle())); + history_visit.text_to_annotate = + base::UTF16ToUTF8(web_contents()->GetTitle()); + page_content_annotations_service_->Annotate(history_visit); + + if (switches::ShouldLogPageContentAnnotationsInput()) { + LOG(ERROR) << "Annotating same document navigation: \n" + << "URL: " << navigation_handle->GetURL() << "\n" + << "Text: " << *(history_visit.text_to_annotate); + } } } @@ -191,8 +221,15 @@ void PageContentAnnotationsWebContentsObserver::TitleWasSet( optimization_guide::HistoryVisit history_visit = optimization_guide:: PageContentAnnotationsService::CreateHistoryVisitFromWebContents( web_contents(), page_data->navigation_id()); - page_content_annotations_service_->Annotate( - history_visit, base::UTF16ToUTF8(entry->GetTitleForDisplay())); + history_visit.text_to_annotate = + base::UTF16ToUTF8(entry->GetTitleForDisplay()); + page_content_annotations_service_->Annotate(history_visit); + + if (switches::ShouldLogPageContentAnnotationsInput()) { + LOG(ERROR) << "Annotating main frame navigation: \n" + << "URL: " << entry->GetURL() << "\n" + << "Text: " << *(history_visit.text_to_annotate); + } } std::unique_ptr<PageTextObserver::ConsumerTextDumpRequest> @@ -230,7 +267,7 @@ PageContentAnnotationsWebContentsObserver::MaybeRequestFrameTextDump( } void PageContentAnnotationsWebContentsObserver::OnTextDumpReceived( - const HistoryVisit& visit, + HistoryVisit visit, const PageTextDumpResult& result) { DCHECK(!features::ShouldAnnotateTitleInsteadOfPageContent()); @@ -241,12 +278,22 @@ void PageContentAnnotationsWebContentsObserver::OnTextDumpReceived( // If the page had AMP frames, then only use that content. Otherwise, use the // mainframe. if (result.GetAMPTextContent()) { - page_content_annotations_service_->Annotate(visit, - *result.GetAMPTextContent()); + visit.text_to_annotate = *result.GetAMPTextContent(); + page_content_annotations_service_->Annotate(visit); + if (switches::ShouldLogPageContentAnnotationsInput()) { + LOG(ERROR) << "Annotating AMP text content: \n" + << "URL: " << visit.url << "\n" + << "Text: " << *(visit.text_to_annotate); + } return; } - page_content_annotations_service_->Annotate( - visit, *result.GetMainFrameTextContent()); + visit.text_to_annotate = *result.GetMainFrameTextContent(); + page_content_annotations_service_->Annotate(visit); + if (switches::ShouldLogPageContentAnnotationsInput()) { + LOG(ERROR) << "Annotating main frame text content: \n" + << "URL: " << visit.url << "\n" + << "Text: " << *(visit.text_to_annotate); + } } void PageContentAnnotationsWebContentsObserver::OnRemotePageEntitiesReceived( diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.h b/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.h index c0417821386..de084693c39 100644 --- a/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.h +++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer.h @@ -62,8 +62,7 @@ class PageContentAnnotationsWebContentsObserver content::NavigationHandle* navigation_handle) override; // Callback invoked when a text dump has been received for the |visit|. - void OnTextDumpReceived(const HistoryVisit& visit, - const PageTextDumpResult& result); + void OnTextDumpReceived(HistoryVisit visit, const PageTextDumpResult& result); // Callback invoked when the page entities have been received from // |optimization_guide_decider_| for |visit|. diff --git a/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer_unittest.cc b/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer_unittest.cc index c08c20f7276..49397fca3e7 100644 --- a/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer_unittest.cc +++ b/chromium/components/optimization_guide/content/browser/page_content_annotations_web_contents_observer_unittest.cc @@ -62,11 +62,14 @@ class FakePageContentAnnotationsService : public PageContentAnnotationsService { history::HistoryService* history_service) : PageContentAnnotationsService("en-US", optimization_guide_model_provider, - history_service) {} + history_service, + nullptr, + base::FilePath(), + nullptr) {} ~FakePageContentAnnotationsService() override = default; - void Annotate(const HistoryVisit& visit, const std::string& text) override { - last_annotation_request_.emplace(std::make_pair(visit, text)); + void Annotate(const HistoryVisit& visit) override { + last_annotation_request_.emplace(visit); } void ExtractRelatedSearches(const HistoryVisit& visit, @@ -75,8 +78,7 @@ class FakePageContentAnnotationsService : public PageContentAnnotationsService { std::make_pair(visit, web_contents)); } - absl::optional<std::pair<HistoryVisit, std::string>> last_annotation_request() - const { + absl::optional<HistoryVisit> last_annotation_request() const { return last_annotation_request_; } @@ -103,14 +105,24 @@ class FakePageContentAnnotationsService : public PageContentAnnotationsService { return last_entities_persistence_request_; } + void PersistSearchMetadata(const HistoryVisit& visit, + const SearchMetadata& search_metadata) override { + last_search_metadata_ = search_metadata; + } + + absl::optional<SearchMetadata> last_search_metadata_persisted() const { + return last_search_metadata_; + } + private: - absl::optional<std::pair<HistoryVisit, std::string>> last_annotation_request_; + absl::optional<HistoryVisit> last_annotation_request_; absl::optional<std::pair<HistoryVisit, content::WebContents*>> last_related_searches_extraction_request_; absl::optional< std::pair<HistoryVisit, std::vector<history::VisitContentModelAnnotations::Category>>> last_entities_persistence_request_; + absl::optional<SearchMetadata> last_search_metadata_; }; class FakeOptimizationGuideDecider : public TestOptimizationGuideDecider { @@ -322,11 +334,11 @@ TEST_F(PageContentAnnotationsWebContentsObserverTest, result.AddFrameTextDumpResult(frame_result); std::move(request->callback).Run(std::move(result)); - absl::optional<std::pair<HistoryVisit, std::string>> last_annotation_request = + absl::optional<HistoryVisit> last_annotation_request = service()->last_annotation_request(); EXPECT_TRUE(last_annotation_request.has_value()); - EXPECT_EQ(last_annotation_request->first.url, GURL("http://test.com")); - EXPECT_EQ(last_annotation_request->second, "some text"); + EXPECT_EQ(last_annotation_request->url, GURL("http://test.com")); + EXPECT_EQ(last_annotation_request->text_to_annotate, "some text"); service()->ClearLastAnnotationRequest(); @@ -355,11 +367,11 @@ TEST_F(PageContentAnnotationsWebContentsObserverTest, navigation_simulator->CommitSameDocument(); // The title should be what is requested to be annotated. - absl::optional<std::pair<HistoryVisit, std::string>> last_annotation_request = + absl::optional<HistoryVisit> last_annotation_request = service()->last_annotation_request(); EXPECT_TRUE(last_annotation_request.has_value()); - EXPECT_EQ(last_annotation_request->first.url, url2); - EXPECT_EQ(last_annotation_request->second, "Title"); + EXPECT_EQ(last_annotation_request->url, url2); + EXPECT_EQ(last_annotation_request->text_to_annotate, "Title"); } TEST_F(PageContentAnnotationsWebContentsObserverTest, @@ -369,12 +381,19 @@ TEST_F(PageContentAnnotationsWebContentsObserverTest, web_contents(), GURL("http://default-engine.com/search?q=a")); // The search query should be what is requested to be annotated. - absl::optional<std::pair<HistoryVisit, std::string>> last_annotation_request = + absl::optional<HistoryVisit> last_annotation_request = service()->last_annotation_request(); ASSERT_TRUE(last_annotation_request.has_value()); - EXPECT_EQ(last_annotation_request->first.url, + EXPECT_EQ(last_annotation_request->url, + GURL("http://default-engine.com/search?q=a")); + EXPECT_EQ(last_annotation_request->text_to_annotate, "a"); + + absl::optional<SearchMetadata> last_search_metadata_persisted = + service()->last_search_metadata_persisted(); + ASSERT_TRUE(last_search_metadata_persisted.has_value()); + EXPECT_EQ(last_search_metadata_persisted->normalized_url, GURL("http://default-engine.com/search?q=a")); - EXPECT_EQ(last_annotation_request->second, "a"); + EXPECT_EQ(last_search_metadata_persisted->search_terms, u"a"); } TEST_F(PageContentAnnotationsWebContentsObserverTest, @@ -460,11 +479,11 @@ TEST_F(PageContentAnnotationsWebContentsObserverAnnotateTitleTest, navigation_simulator->CommitSameDocument(); // The title should be what is requested to be annotated. - absl::optional<std::pair<HistoryVisit, std::string>> last_annotation_request = + absl::optional<HistoryVisit> last_annotation_request = service()->last_annotation_request(); EXPECT_TRUE(last_annotation_request.has_value()); - EXPECT_EQ(last_annotation_request->first.url, url2); - EXPECT_EQ(last_annotation_request->second, "Title"); + EXPECT_EQ(last_annotation_request->url, url2); + EXPECT_EQ(last_annotation_request->text_to_annotate, "Title"); service()->ClearLastAnnotationRequest(); @@ -489,12 +508,11 @@ TEST_F(PageContentAnnotationsWebContentsObserverAnnotateTitleTest, title); // The title should be what is requested to be annotated. - absl::optional<std::pair<HistoryVisit, std::string>> last_annotation_request = + absl::optional<HistoryVisit> last_annotation_request = service()->last_annotation_request(); EXPECT_TRUE(last_annotation_request.has_value()); - EXPECT_EQ(last_annotation_request->first.url, - GURL("http://www.foo.com/someurl")); - EXPECT_EQ(last_annotation_request->second, "Title"); + EXPECT_EQ(last_annotation_request->url, GURL("http://www.foo.com/someurl")); + EXPECT_EQ(last_annotation_request->text_to_annotate, "Title"); service()->ClearLastAnnotationRequest(); diff --git a/chromium/components/optimization_guide/content/browser/test_page_content_annotator.cc b/chromium/components/optimization_guide/content/browser/test_page_content_annotator.cc index 5df206b2994..f1c12bed097 100644 --- a/chromium/components/optimization_guide/content/browser/test_page_content_annotator.cc +++ b/chromium/components/optimization_guide/content/browser/test_page_content_annotator.cc @@ -17,7 +17,7 @@ void TestPageContentAnnotator::Annotate(BatchAnnotationCallback callback, if (annotation_type == AnnotationType::kPageTopics) { for (const std::string& input : inputs) { auto it = topics_by_input_.find(input); - absl::optional<std::vector<WeightedString>> output; + absl::optional<std::vector<WeightedIdentifier>> output; if (it != topics_by_input_.end()) { output = it->second; } @@ -54,7 +54,7 @@ void TestPageContentAnnotator::Annotate(BatchAnnotationCallback callback, } void TestPageContentAnnotator::UsePageTopics( - const base::flat_map<std::string, std::vector<WeightedString>>& + const base::flat_map<std::string, std::vector<WeightedIdentifier>>& topics_by_input) { topics_by_input_ = topics_by_input; } diff --git a/chromium/components/optimization_guide/content/browser/test_page_content_annotator.h b/chromium/components/optimization_guide/content/browser/test_page_content_annotator.h index dc52ac14b03..a38e1ccd5c4 100644 --- a/chromium/components/optimization_guide/content/browser/test_page_content_annotator.h +++ b/chromium/components/optimization_guide/content/browser/test_page_content_annotator.h @@ -23,7 +23,7 @@ class TestPageContentAnnotator : public PageContentAnnotator { // The given page topics are used for the matching BatchAnnotationResults by // input string. If the input is not found, the output is left as nullopt. void UsePageTopics( - const base::flat_map<std::string, std::vector<WeightedString>>& + const base::flat_map<std::string, std::vector<WeightedIdentifier>>& topics_by_input); // The given page entities are used for the matching BatchAnnotationResults by @@ -43,7 +43,7 @@ class TestPageContentAnnotator : public PageContentAnnotator { AnnotationType annotation_type) override; private: - base::flat_map<std::string, std::vector<WeightedString>> topics_by_input_; + base::flat_map<std::string, std::vector<WeightedIdentifier>> topics_by_input_; base::flat_map<std::string, std::vector<ScoredEntityMetadata>> entities_by_input_; base::flat_map<std::string, double> visibility_scores_for_input_; diff --git a/chromium/components/optimization_guide/core/BUILD.gn b/chromium/components/optimization_guide/core/BUILD.gn index e66d20fc9ba..90eaa54e934 100644 --- a/chromium/components/optimization_guide/core/BUILD.gn +++ b/chromium/components/optimization_guide/core/BUILD.gn @@ -1,5 +1,5 @@ -# Copyright 2017 The Chromium Authors. All rights reserved. -# Use of this source code is governed by a BSD-style license that can be +# Copyright 2017 The Chromium Authors.All rights reserved. +# Use of this source code is governed by a BSD - style license that can be # found in the LICENSE file. if (is_android) { @@ -32,14 +32,66 @@ static_library("entities") { ] } +static_library("model_executor") { + sources = [ + "execution_status.cc", + "execution_status.h", + "model_enums.h", + "model_executor.h", + "model_info.cc", + "model_info.h", + "model_util.cc", + "model_util.h", + ] + if (build_with_tflite_lib) { + sources += [ + "base_model_executor.h", + "base_model_executor_helpers.h", + "bert_model_executor.cc", + "bert_model_executor.h", + "tflite_model_executor.h", + ] + } + + public_deps = [ + "//components/optimization_guide:machine_learning_tflite_buildflags", + "//third_party/re2", + ] + if (build_with_tflite_lib) { + public_deps += [ + "//components/optimization_guide/core:machine_learning", + "//third_party/abseil-cpp:absl", + "//third_party/tflite", + "//third_party/tflite:tflite_public_headers", + "//third_party/tflite_support", + "//third_party/tflite_support:tflite_support_proto", + ] + } + deps = [ + "//base", + "//components/optimization_guide/proto:optimization_guide_proto", + "//net", + "//url", + ] +} + +if (build_with_tflite_lib) { + static_library("machine_learning") { + sources = [ + "tflite_op_resolver.cc", + "tflite_op_resolver.h", + ] + deps = [ + "//third_party/tflite", + "//third_party/tflite:tflite_public_headers", + ] + } +} + static_library("core") { sources = [ "command_line_top_host_provider.cc", "command_line_top_host_provider.h", - "decision_tree_prediction_model.cc", - "decision_tree_prediction_model.h", - "execution_status.cc", - "execution_status.h", "hint_cache.cc", "hint_cache.h", "hints_component_info.h", @@ -54,12 +106,11 @@ static_library("core") { "hints_processing_util.cc", "hints_processing_util.h", "insertion_ordered_set.h", + "local_page_entities_metadata_provider.cc", + "local_page_entities_metadata_provider.h", "memory_hint.cc", "memory_hint.h", - "model_executor.h", "model_handler.h", - "model_info.cc", - "model_info.h", "noisy_metrics_recorder.cc", "noisy_metrics_recorder.h", "optimization_filter.cc", @@ -70,6 +121,8 @@ static_library("core") { "optimization_guide_enums.h", "optimization_guide_features.cc", "optimization_guide_features.h", + "optimization_guide_logger.cc", + "optimization_guide_logger.h", "optimization_guide_model_provider.h", "optimization_guide_navigation_data.cc", "optimization_guide_navigation_data.h", @@ -93,8 +146,6 @@ static_library("core") { "page_content_annotation_job.h", "page_content_annotations_common.cc", "page_content_annotations_common.h", - "prediction_model.cc", - "prediction_model.h", "prediction_model_fetcher.h", "prediction_model_fetcher_impl.cc", "prediction_model_fetcher_impl.h", @@ -108,10 +159,6 @@ static_library("core") { ] if (build_with_tflite_lib) { sources += [ - "base_model_executor.h", - "base_model_executor_helpers.h", - "bert_model_executor.cc", - "bert_model_executor.h", "bert_model_handler.cc", "bert_model_handler.h", "model_validator.cc", @@ -123,12 +170,20 @@ static_library("core") { "page_topics_model_executor.h", "page_visibility_model_executor.cc", "page_visibility_model_executor.h", - "tflite_model_executor.h", ] + if (build_with_internal_optimization_guide) { + sources += [ + "entity_annotator_native_library.cc", + "entity_annotator_native_library.h", + "page_entities_model_executor_impl.cc", + "page_entities_model_executor_impl.h", + ] + } } public_deps = [ ":entities", + ":model_executor", "//components/optimization_guide:machine_learning_tflite_buildflags", "//third_party/re2", ] @@ -146,7 +201,6 @@ static_library("core") { deps = [ ":bloomfilter", "//base", - "//components/data_reduction_proxy/core/browser", "//components/leveldb_proto", "//components/optimization_guide/proto:optimization_guide_proto", "//components/prefs", @@ -161,17 +215,9 @@ static_library("core") { "//ui/base:base", "//url:url", ] -} - -if (build_with_tflite_lib) { - static_library("machine_learning") { - sources = [ - "tflite_op_resolver.cc", - "tflite_op_resolver.h", - ] - deps = [ - "//third_party/tflite", - "//third_party/tflite:tflite_public_headers", + if (build_with_tflite_lib && build_with_internal_optimization_guide) { + data_deps = [ + "//components/optimization_guide/internal:optimization_guide_internal", ] } } @@ -244,13 +290,13 @@ source_set("unit_tests") { "batch_entity_metadata_task_unittest.cc", "bloom_filter_unittest.cc", "command_line_top_host_provider_unittest.cc", - "decision_tree_prediction_model_unittest.cc", "hint_cache_unittest.cc", "hints_component_util_unittest.cc", "hints_fetcher_unittest.cc", "hints_manager_unittest.cc", "hints_processing_util_unittest.cc", "insertion_ordered_set_unittest.cc", + "local_page_entities_metadata_provider_unittest.cc", "model_handler_unittest.cc", "noisy_metrics_recorder_unittest.cc", "optimization_filter_unittest.cc", @@ -264,7 +310,6 @@ source_set("unit_tests") { "optimization_metadata_unittest.cc", "page_content_annotation_job_unittest.cc", "prediction_model_fetcher_unittest.cc", - "prediction_model_unittest.cc", "store_update_data_unittest.cc", "url_pattern_with_wildcards_unittest.cc", ] @@ -277,6 +322,12 @@ source_set("unit_tests") { "page_visibility_model_executor_unittest.cc", "tflite_model_executor_unittest.cc", ] + if (build_with_internal_optimization_guide) { + sources += [ + "entity_annotator_native_library_unittest.cc", + "page_entities_model_executor_impl_unittest.cc", + ] + } } deps = [ @@ -286,8 +337,6 @@ source_set("unit_tests") { ":test_support", "//base", "//base/test:test_support", - "//components/data_reduction_proxy/core/browser", - "//components/data_reduction_proxy/core/common", "//components/leveldb_proto:test_support", "//components/optimization_guide/proto:optimization_guide_proto", "//components/prefs:test_support", @@ -319,3 +368,17 @@ if (is_android) { visibility = [ "//chrome/browser/optimization_guide/android:*" ] } } + +if (is_mac && build_with_internal_optimization_guide) { + # We need to copy the optimization guide shared library so that the + # bundle_data dependencies have a "copy" target type.Otherwise for + # "shared_library" target types it will try to link things into + # Chromium Framework when we want to keep it separate instead. + copy("optimization_guide_internal_library_copy") { + sources = [ "$root_out_dir/liboptimization_guide_internal.dylib" ] + outputs = [ "$root_out_dir/og_intermediates/{{source_file_part}}" ] + deps = [ + "//components/optimization_guide/internal:optimization_guide_internal", + ] + } +} diff --git a/chromium/components/optimization_guide/core/DEPS b/chromium/components/optimization_guide/core/DEPS index 3ab01bf2788..74dc5dfb565 100644 --- a/chromium/components/optimization_guide/core/DEPS +++ b/chromium/components/optimization_guide/core/DEPS @@ -1,7 +1,5 @@ include_rules = [ "+components/ukm/test_ukm_recorder.h", "+services/metrics/public/cpp", - "+third_party/abseil-cpp/absl/types/optional.h", - "+third_party/abseil-cpp/absl/status/status.h", "+ui/base/l10n", ] diff --git a/chromium/components/optimization_guide/core/base_model_executor.h b/chromium/components/optimization_guide/core/base_model_executor.h index 76bd8b96a46..78e064b4256 100644 --- a/chromium/components/optimization_guide/core/base_model_executor.h +++ b/chromium/components/optimization_guide/core/base_model_executor.h @@ -69,9 +69,9 @@ class BaseModelExecutor : public TFLiteModelExecutor<OutputType, InputTypes...>, } // InferenceDelegate: - absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, - InputTypes... input) override = 0; - OutputType Postprocess( + bool Preprocess(const std::vector<TfLiteTensor*>& input_tensors, + InputTypes... input) override = 0; + absl::optional<OutputType> Postprocess( const std::vector<const TfLiteTensor*>& output_tensors) override = 0; }; diff --git a/chromium/components/optimization_guide/core/base_model_executor_helpers.h b/chromium/components/optimization_guide/core/base_model_executor_helpers.h index 0b9717eee2c..986ce04589d 100644 --- a/chromium/components/optimization_guide/core/base_model_executor_helpers.h +++ b/chromium/components/optimization_guide/core/base_model_executor_helpers.h @@ -11,6 +11,7 @@ #include "base/check.h" #include "base/memory/raw_ptr.h" #include "components/optimization_guide/core/execution_status.h" +#include "third_party/abseil-cpp/absl/types/optional.h" #include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h" namespace optimization_guide { @@ -18,13 +19,13 @@ namespace optimization_guide { template <class OutputType, class... InputTypes> class InferenceDelegate { public: - // Preprocesses |args| into |input_tensors|. - virtual absl::Status Preprocess( - const std::vector<TfLiteTensor*>& input_tensors, - InputTypes... args) = 0; + // Preprocesses |args| into |input_tensors|. Returns true on success. + virtual bool Preprocess(const std::vector<TfLiteTensor*>& input_tensors, + InputTypes... args) = 0; - // Postprocesses |output_tensors| into the desired |OutputType|. - virtual OutputType Postprocess( + // Postprocesses |output_tensors| into the desired |OutputType|, returning + // absl::nullopt on error. + virtual absl::optional<OutputType> Postprocess( const std::vector<const TfLiteTensor*>& output_tensors) = 0; }; @@ -59,12 +60,24 @@ class GenericModelExecutionTask // BaseTaskApi: absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, InputTypes... args) override { - return delegate_->Preprocess(input_tensors, args...); + bool success = delegate_->Preprocess(input_tensors, args...); + if (success) { + return absl::OkStatus(); + } + return absl::InternalError( + "error during preprocessing. See stderr for more information if " + "available"); } tflite::support::StatusOr<OutputType> Postprocess( const std::vector<const TfLiteTensor*>& output_tensors, InputTypes... api_inputs) override { - return delegate_->Postprocess(output_tensors); + absl::optional<OutputType> output = delegate_->Postprocess(output_tensors); + if (!output) { + return absl::InternalError( + "error during postprocessing. See stderr for more infomation if " + "available"); + } + return *output; } private: diff --git a/chromium/components/optimization_guide/core/bert_model_executor.cc b/chromium/components/optimization_guide/core/bert_model_executor.cc index ab2b7306726..fcd4df0e96d 100644 --- a/chromium/components/optimization_guide/core/bert_model_executor.cc +++ b/chromium/components/optimization_guide/core/bert_model_executor.cc @@ -5,9 +5,9 @@ #include "components/optimization_guide/core/bert_model_executor.h" #include "base/trace_event/trace_event.h" -#include "components/optimization_guide/core/optimization_guide_util.h" +#include "components/optimization_guide/core/model_util.h" #include "components/optimization_guide/core/tflite_op_resolver.h" -#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h" +#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h" namespace optimization_guide { @@ -28,18 +28,21 @@ BertModelExecutor::Execute(ModelExecutionTask* execution_task, GetStringNameForOptimizationTarget(optimization_target_), "input_length", input.size()); *out_status = ExecutionStatus::kSuccess; - return static_cast<tflite::task::text::nlclassifier::BertNLClassifier*>( - execution_task) + return static_cast<tflite::task::text::BertNLClassifier*>(execution_task) ->Classify(input); } std::unique_ptr<BertModelExecutor::ModelExecutionTask> BertModelExecutor::BuildModelExecutionTask(base::MemoryMappedFile* model_file, ExecutionStatus* out_status) { + tflite::task::text::BertNLClassifierOptions options; + *options.mutable_base_options() + ->mutable_model_file() + ->mutable_file_content() = std::string( + reinterpret_cast<const char*>(model_file->data()), model_file->length()); auto maybe_nl_classifier = - tflite::task::text::nlclassifier::BertNLClassifier::CreateFromBuffer( - reinterpret_cast<const char*>(model_file->data()), - model_file->length(), std::make_unique<TFLiteOpResolver>()); + tflite::task::text::BertNLClassifier::CreateFromOptions( + std::move(options), std::make_unique<TFLiteOpResolver>()); if (maybe_nl_classifier.ok()) return std::move(maybe_nl_classifier.value()); *out_status = ExecutionStatus::kErrorModelFileNotValid; diff --git a/chromium/components/optimization_guide/core/bloom_filter_unittest.cc b/chromium/components/optimization_guide/core/bloom_filter_unittest.cc index a2551a522fe..e0cf1884b67 100644 --- a/chromium/components/optimization_guide/core/bloom_filter_unittest.cc +++ b/chromium/components/optimization_guide/core/bloom_filter_unittest.cc @@ -103,7 +103,7 @@ TEST(BloomFilterTest, EverythingMatches) { } // Disable this test in configurations that don't print CHECK failures. -#if !defined(OS_IOS) && !(defined(OFFICIAL_BUILD) && defined(NDEBUG)) +#if !BUILDFLAG(IS_IOS) && !(defined(OFFICIAL_BUILD) && defined(NDEBUG)) TEST(BloomFilterTest, ByteVectorTooSmall) { std::string data(1023, 0xff); EXPECT_DEATH( diff --git a/chromium/components/optimization_guide/core/decision_tree_prediction_model.cc b/chromium/components/optimization_guide/core/decision_tree_prediction_model.cc deleted file mode 100644 index f8f5906c072..00000000000 --- a/chromium/components/optimization_guide/core/decision_tree_prediction_model.cc +++ /dev/null @@ -1,237 +0,0 @@ -// Copyright 2020 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "components/optimization_guide/core/decision_tree_prediction_model.h" - -#include <utility> - -namespace optimization_guide { - -DecisionTreePredictionModel::DecisionTreePredictionModel( - const proto::PredictionModel& prediction_model) - : PredictionModel(prediction_model) {} - -DecisionTreePredictionModel::~DecisionTreePredictionModel() = default; - -bool DecisionTreePredictionModel::ValidatePredictionModel() const { - // Only the top-level ensemble or decision tree must have a threshold. Any - // submodels of an ensemble will have model weights but no threshold. - if (!model_.has_threshold()) - return false; - return ValidateModel(model_); -} - -bool DecisionTreePredictionModel::ValidateModel( - const proto::Model& model) const { - if (model.has_ensemble()) { - return ValidateEnsembleModel(model.ensemble()); - } - if (model.has_decision_tree()) { - return ValidateDecisionTree(model.decision_tree()); - } - return false; -} - -bool DecisionTreePredictionModel::ValidateEnsembleModel( - const proto::Ensemble& ensemble) const { - if (ensemble.members_size() == 0) - return false; - - for (const auto& member : ensemble.members()) { - if (!ValidateModel(member.submodel())) { - return false; - } - } - return true; -} - -bool DecisionTreePredictionModel::ValidateDecisionTree( - const proto::DecisionTree& tree) const { - if (tree.nodes_size() == 0) - return false; - return ValidateTreeNode(tree, tree.nodes(0), 0); -} - -bool DecisionTreePredictionModel::ValidateLeaf(const proto::Leaf& leaf) const { - return leaf.has_vector() && leaf.vector().value_size() == 1 && - leaf.vector().value(0).has_double_value(); -} - -bool DecisionTreePredictionModel::ValidateInequalityTest( - const proto::InequalityTest& inequality_test) const { - if (!inequality_test.has_threshold()) - return false; - if (!inequality_test.threshold().has_float_value()) - return false; - if (!inequality_test.has_feature_id()) - return false; - if (!inequality_test.feature_id().has_id()) - return false; - if (!inequality_test.has_type()) - return false; - return true; -} - -bool DecisionTreePredictionModel::ValidateTreeNode( - const proto::DecisionTree& tree, - const proto::TreeNode& node, - int node_index) const { - if (node.has_leaf()) - return ValidateLeaf(node.leaf()); - - if (!node.has_binary_node()) - return false; - - proto::BinaryNode binary_node = node.binary_node(); - if (!binary_node.has_inequality_left_child_test()) - return false; - - if (!ValidateInequalityTest(binary_node.inequality_left_child_test())) - return false; - - if (!binary_node.left_child_id().has_value()) - return false; - if (!binary_node.right_child_id().has_value()) - return false; - - if (binary_node.left_child_id().value() >= tree.nodes_size()) - return false; - if (binary_node.right_child_id().value() >= tree.nodes_size()) - return false; - - // Assure that no parent has an child index less than itself in order to - // prevent loops. - if (node_index >= binary_node.left_child_id().value()) - return false; - if (node_index >= binary_node.right_child_id().value()) - return false; - - if (!ValidateTreeNode(tree, tree.nodes(binary_node.left_child_id().value()), - binary_node.left_child_id().value())) { - return false; - } - if (!ValidateTreeNode(tree, tree.nodes(binary_node.right_child_id().value()), - binary_node.right_child_id().value())) { - return false; - } - return true; -} - -OptimizationTargetDecision DecisionTreePredictionModel::Predict( - const base::flat_map<std::string, float>& model_features, - double* prediction_score) { - *prediction_score = 0.0; - // TODO(mcrouse): Add metrics to record if the model evaluation fails. - if (!EvaluateModel(model_, model_features, prediction_score)) - return OptimizationTargetDecision::kUnknown; - if (*prediction_score > model_.threshold().value()) - return OptimizationTargetDecision::kPageLoadMatches; - return OptimizationTargetDecision::kPageLoadDoesNotMatch; -} - -bool DecisionTreePredictionModel::TraverseTree( - const proto::DecisionTree& tree, - const proto::TreeNode& node, - const base::flat_map<std::string, float>& model_features, - double* result) { - if (node.has_leaf()) { - *result = node.leaf().vector().value(0).double_value(); - return true; - } - - proto::BinaryNode binary_node = node.binary_node(); - float threshold = - binary_node.inequality_left_child_test().threshold().float_value(); - std::string feature_name = - binary_node.inequality_left_child_test().feature_id().id().value(); - auto it = model_features.find(feature_name); - if (it == model_features.end()) - return false; - switch (binary_node.inequality_left_child_test().type()) { - case proto::InequalityTest::LESS_OR_EQUAL: - if (it->second <= threshold) - return TraverseTree(tree, - tree.nodes(binary_node.left_child_id().value()), - model_features, result); - return TraverseTree(tree, - tree.nodes(binary_node.right_child_id().value()), - model_features, result); - case proto::InequalityTest::LESS_THAN: - if (it->second < threshold) - return TraverseTree(tree, - tree.nodes(binary_node.left_child_id().value()), - model_features, result); - return TraverseTree(tree, - tree.nodes(binary_node.right_child_id().value()), - model_features, result); - case proto::InequalityTest::GREATER_OR_EQUAL: - if (it->second >= threshold) - return TraverseTree(tree, - tree.nodes(binary_node.left_child_id().value()), - model_features, result); - return TraverseTree(tree, - tree.nodes(binary_node.right_child_id().value()), - model_features, result); - case proto::InequalityTest::GREATER_THAN: - if (it->second > threshold) - return TraverseTree(tree, - tree.nodes(binary_node.left_child_id().value()), - model_features, result); - return TraverseTree(tree, - tree.nodes(binary_node.right_child_id().value()), - model_features, result); - default: - return false; - } -} - -bool DecisionTreePredictionModel::EvaluateDecisionTree( - const proto::DecisionTree& tree, - const base::flat_map<std::string, float>& model_features, - double* result) { - if (TraverseTree(tree, tree.nodes(0), model_features, result)) { - *result *= tree.weight(); - return true; - } - return false; -} - -bool DecisionTreePredictionModel::EvaluateEnsembleModel( - const proto::Ensemble& ensemble, - const base::flat_map<std::string, float>& model_features, - double* result) { - if (ensemble.members_size() == 0) - return false; - - double score = 0.0; - for (const auto& member : ensemble.members()) { - if (!EvaluateModel(member.submodel(), model_features, &score)) { - *result = 0.0; - return false; - } - - *result += score; - } - *result = *result / ensemble.members_size(); - return true; -} - -bool DecisionTreePredictionModel::EvaluateModel( - const proto::Model& model, - const base::flat_map<std::string, float>& model_features, - double* result) { - DCHECK(result); - // Clear the result value. - *result = 0.0; - - if (model.has_ensemble()) { - return EvaluateEnsembleModel(model.ensemble(), model_features, result); - } - if (model.has_decision_tree()) { - return EvaluateDecisionTree(model.decision_tree(), model_features, result); - } - return false; -} - -} // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/decision_tree_prediction_model.h b/chromium/components/optimization_guide/core/decision_tree_prediction_model.h deleted file mode 100644 index a31ef93d1cd..00000000000 --- a/chromium/components/optimization_guide/core/decision_tree_prediction_model.h +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2020 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_DECISION_TREE_PREDICTION_MODEL_H_ -#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_DECISION_TREE_PREDICTION_MODEL_H_ - -#include <memory> -#include <string> - -#include "base/containers/flat_map.h" -#include "components/optimization_guide/core/prediction_model.h" -#include "components/optimization_guide/proto/models.pb.h" - -namespace optimization_guide { - -// A concrete PredictionModel capable of evaluating the decision tree model type -// supported by the optimization guide. -class DecisionTreePredictionModel : public PredictionModel { - public: - explicit DecisionTreePredictionModel( - const proto::PredictionModel& prediction_model); - - DecisionTreePredictionModel(const DecisionTreePredictionModel&) = delete; - DecisionTreePredictionModel& operator=(const DecisionTreePredictionModel&) = - delete; - - ~DecisionTreePredictionModel() override; - - // PredictionModel implementation: - OptimizationTargetDecision Predict( - const base::flat_map<std::string, float>& model_features, - double* prediction_score) override; - - private: - // Evaluates the provided model, either an ensemble or decision tree model, - // with the |model_features| and stores the output in |result|. Returns false - // if evaluation fails. - bool EvaluateModel(const proto::Model& model, - const base::flat_map<std::string, float>& model_features, - double* result); - - // Evaluates the decision tree model with the |model_features| and - // stores the output in |result|. Returns false if the evaluation fails. - bool EvaluateDecisionTree( - const proto::DecisionTree& tree, - const base::flat_map<std::string, float>& model_features, - double* result); - - // Evaluates an ensemble model with the |model_features| and - // stores the output in |result|. Returns false if the evaluation fails. - bool EvaluateEnsembleModel( - const proto::Ensemble& ensemble, - const base::flat_map<std::string, float>& model_features, - double* result); - - // Performs a depth first traversal the |tree| based on |model_features| - // and stores the value of the leaf in |result|. Returns false if the - // traversal or node evaluation fails. - bool TraverseTree(const proto::DecisionTree& tree, - const proto::TreeNode& node, - const base::flat_map<std::string, float>& model_features, - double* result); - - // PredictionModel implementation: - bool ValidatePredictionModel() const override; - - // Validates a model or submodel of an ensemble. Returns - // false if the model is invalid. - bool ValidateModel(const proto::Model& model) const; - - // Validates an ensemble model. Returns false if the ensemble - // if invalid. - bool ValidateEnsembleModel(const proto::Ensemble& ensemble) const; - - // Validates a decision tree model. Returns false if the - // decision tree model is invalid. - bool ValidateDecisionTree(const proto::DecisionTree& tree) const; - - // Validates a leaf. Returns false if the leaf is invalid. - bool ValidateLeaf(const proto::Leaf& leaf) const; - - // Validates an inequality test. Returns false if the - // inequality test is invalid. - bool ValidateInequalityTest( - const proto::InequalityTest& inequality_test) const; - - // Validates each node of a decision tree by traversing every - // node of the |tree|. Returns false if any part of the tree is invalid. - bool ValidateTreeNode(const proto::DecisionTree& tree, - const proto::TreeNode& node, - int node_index) const; -}; - -} // namespace optimization_guide - -#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_DECISION_TREE_PREDICTION_MODEL_H_ diff --git a/chromium/components/optimization_guide/core/decision_tree_prediction_model_unittest.cc b/chromium/components/optimization_guide/core/decision_tree_prediction_model_unittest.cc deleted file mode 100644 index 4a479d11038..00000000000 --- a/chromium/components/optimization_guide/core/decision_tree_prediction_model_unittest.cc +++ /dev/null @@ -1,434 +0,0 @@ -// Copyright 2020 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "components/optimization_guide/core/decision_tree_prediction_model.h" - -#include <memory> -#include <utility> - -#include "base/containers/flat_map.h" -#include "base/containers/flat_set.h" -#include "components/optimization_guide/core/prediction_model.h" -#include "components/optimization_guide/proto/models.pb.h" -#include "testing/gtest/include/gtest/gtest.h" - -namespace optimization_guide { - -proto::PredictionModel GetValidDecisionTreePredictionModel() { - proto::PredictionModel prediction_model; - prediction_model.mutable_model()->mutable_threshold()->set_value(5.0); - - proto::DecisionTree decision_tree_model = proto::DecisionTree(); - decision_tree_model.set_weight(2.0); - - proto::TreeNode* tree_node = decision_tree_model.add_nodes(); - tree_node->mutable_node_id()->set_value(0); - tree_node->mutable_binary_node()->mutable_left_child_id()->set_value(1); - tree_node->mutable_binary_node()->mutable_right_child_id()->set_value(2); - tree_node->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->mutable_feature_id() - ->mutable_id() - ->set_value("agg1"); - tree_node->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->set_type(proto::InequalityTest::LESS_OR_EQUAL); - tree_node->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->mutable_threshold() - ->set_float_value(1.0); - - tree_node = decision_tree_model.add_nodes(); - tree_node->mutable_node_id()->set_value(1); - tree_node->mutable_leaf()->mutable_vector()->add_value()->set_double_value( - 2.); - - tree_node = decision_tree_model.add_nodes(); - tree_node->mutable_node_id()->set_value(2); - tree_node->mutable_leaf()->mutable_vector()->add_value()->set_double_value( - 4.); - - *prediction_model.mutable_model()->mutable_decision_tree() = - decision_tree_model; - return prediction_model; -} - -proto::PredictionModel GetValidEnsemblePredictionModel() { - proto::PredictionModel prediction_model; - prediction_model.mutable_model()->mutable_threshold()->set_value(5.0); - proto::Ensemble ensemble = proto::Ensemble(); - *ensemble.add_members()->mutable_submodel() = - *GetValidDecisionTreePredictionModel().mutable_model(); - - *ensemble.add_members()->mutable_submodel() = - *GetValidDecisionTreePredictionModel().mutable_model(); - - *prediction_model.mutable_model()->mutable_ensemble() = ensemble; - return prediction_model; -} - -TEST(DecisionTreePredictionModel, ValidDecisionTreeModel) { - proto::PredictionModel prediction_model = - GetValidDecisionTreePredictionModel(); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(prediction_model); - EXPECT_TRUE(model); - - double prediction_score; - EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch, - model->Predict({{"agg1", 1.0}}, &prediction_score)); - EXPECT_EQ(4., prediction_score); - EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches, - model->Predict({{"agg1", 2.0}}, &prediction_score)); - EXPECT_EQ(8., prediction_score); -} - -TEST(DecisionTreePredictionModel, InequalityLessThan) { - proto::PredictionModel prediction_model = - GetValidDecisionTreePredictionModel(); - - prediction_model.mutable_model() - ->mutable_decision_tree() - ->mutable_nodes(0) - ->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->set_type(proto::InequalityTest::LESS_THAN); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(std::move(prediction_model)); - EXPECT_TRUE(model); - - double prediction_score; - EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch, - model->Predict({{"agg1", 0.5}}, &prediction_score)); - EXPECT_EQ(4., prediction_score); - EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches, - model->Predict({{"agg1", 2.0}}, &prediction_score)); - EXPECT_EQ(8., prediction_score); -} - -TEST(DecisionTreePredictionModel, InequalityGreaterOrEqual) { - proto::PredictionModel prediction_model = - GetValidDecisionTreePredictionModel(); - - prediction_model.mutable_model() - ->mutable_decision_tree() - ->mutable_nodes(0) - ->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->set_type(proto::InequalityTest::GREATER_OR_EQUAL); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(prediction_model); - EXPECT_TRUE(model); - - double prediction_score; - EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches, - model->Predict({{"agg1", 0.5}}, &prediction_score)); - EXPECT_EQ(8., prediction_score); - EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch, - model->Predict({{"agg1", 1.0}}, &prediction_score)); - EXPECT_EQ(4., prediction_score); -} - -TEST(DecisionTreePredictionModel, InequalityGreaterThan) { - proto::PredictionModel prediction_model = - GetValidDecisionTreePredictionModel(); - - prediction_model.mutable_model() - ->mutable_decision_tree() - ->mutable_nodes(0) - ->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->set_type(proto::InequalityTest::GREATER_THAN); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(std::move(prediction_model)); - EXPECT_TRUE(model); - - double prediction_score; - EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches, - model->Predict({{"agg1", 0.5}}, &prediction_score)); - EXPECT_EQ(8., prediction_score); - EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch, - model->Predict({{"agg1", 2.0}}, &prediction_score)); - EXPECT_EQ(4., prediction_score); -} - -TEST(DecisionTreePredictionModel, MissingInequalityTest) { - proto::PredictionModel prediction_model = - GetValidDecisionTreePredictionModel(); - - prediction_model.mutable_model() - ->mutable_decision_tree() - ->mutable_nodes(0) - ->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->Clear(); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(std::move(prediction_model)); - EXPECT_FALSE(model); -} - -TEST(DecisionTreePredictionModel, NoDecisionTreeThreshold) { - proto::PredictionModel prediction_model = - GetValidDecisionTreePredictionModel(); - - prediction_model.mutable_model()->clear_threshold(); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(prediction_model); - EXPECT_FALSE(model); -} - -TEST(DecisionTreePredictionModel, EmptyTree) { - proto::PredictionModel prediction_model = - GetValidDecisionTreePredictionModel(); - - prediction_model.mutable_model()->mutable_decision_tree()->clear_nodes(); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(std::move(prediction_model)); - EXPECT_FALSE(model); -} - -TEST(DecisionTreePredictionModel, ModelFeatureNotInFeatureMap) { - proto::PredictionModel prediction_model = - GetValidDecisionTreePredictionModel(); - - prediction_model.mutable_model()->mutable_decision_tree()->clear_nodes(); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(prediction_model); - EXPECT_FALSE(model); -} - -TEST(DecisionTreePredictionModel, DecisionTreeMissingLeaf) { - proto::PredictionModel prediction_model = - GetValidDecisionTreePredictionModel(); - - prediction_model.mutable_model() - ->mutable_decision_tree() - ->mutable_nodes(1) - ->mutable_leaf() - ->Clear(); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(prediction_model); - EXPECT_FALSE(model); -} - -TEST(DecisionTreePredictionModel, DecisionTreeLeftChildIndexInvalid) { - proto::PredictionModel prediction_model = - GetValidDecisionTreePredictionModel(); - - prediction_model.mutable_model() - ->mutable_decision_tree() - ->mutable_nodes(0) - ->mutable_binary_node() - ->mutable_left_child_id() - ->set_value(3); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(std::move(prediction_model)); - EXPECT_FALSE(model); -} - -TEST(DecisionTreePredictionModel, DecisionTreeRightChildIndexInvalid) { - proto::PredictionModel prediction_model = - GetValidDecisionTreePredictionModel(); - - prediction_model.mutable_model() - ->mutable_decision_tree() - ->mutable_nodes(0) - ->mutable_binary_node() - ->mutable_right_child_id() - ->set_value(3); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(prediction_model); - EXPECT_FALSE(model); -} - -TEST(DecisionTreePredictionModel, DecisionTreeWithLoopOnLeftChild) { - proto::PredictionModel prediction_model = - GetValidDecisionTreePredictionModel(); - - proto::TreeNode* tree_node = - prediction_model.mutable_model()->mutable_decision_tree()->mutable_nodes( - 1); - - tree_node->mutable_node_id()->set_value(0); - tree_node->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->mutable_feature_id() - ->mutable_id() - ->set_value("agg1"); - tree_node->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->set_type(proto::InequalityTest::LESS_OR_EQUAL); - tree_node->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->mutable_threshold() - ->set_float_value(1.0); - - tree_node->mutable_binary_node()->mutable_left_child_id()->set_value(0); - tree_node->mutable_binary_node()->mutable_right_child_id()->set_value(2); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(prediction_model); - EXPECT_FALSE(model); -} - -TEST(DecisionTreePredictionModel, DecisionTreeWithLoopOnRightChild) { - proto::PredictionModel prediction_model = - GetValidDecisionTreePredictionModel(); - - proto::TreeNode* tree_node = - prediction_model.mutable_model()->mutable_decision_tree()->mutable_nodes( - 1); - - tree_node->mutable_node_id()->set_value(0); - tree_node->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->mutable_feature_id() - ->mutable_id() - ->set_value("agg1"); - tree_node->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->set_type(proto::InequalityTest::LESS_OR_EQUAL); - tree_node->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->mutable_threshold() - ->set_float_value(1.0); - - tree_node->mutable_binary_node()->mutable_left_child_id()->set_value(2); - tree_node->mutable_binary_node()->mutable_right_child_id()->set_value(0); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(prediction_model); - EXPECT_FALSE(model); -} - -TEST(DecisionTreePredictionModel, ValidEnsembleModel) { - proto::PredictionModel prediction_model = GetValidEnsemblePredictionModel(); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(prediction_model); - EXPECT_TRUE(model); - - double prediction_score; - EXPECT_EQ(OptimizationTargetDecision::kPageLoadDoesNotMatch, - model->Predict({{"agg1", 1.0}}, &prediction_score)); - EXPECT_EQ(4., prediction_score); - EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches, - model->Predict({{"agg1", 2.0}}, &prediction_score)); - EXPECT_EQ(8., prediction_score); -} - -TEST(DecisionTreePredictionModel, EnsembleWithNoMembers) { - proto::PredictionModel prediction_model = GetValidEnsemblePredictionModel(); - prediction_model.mutable_model() - ->mutable_ensemble() - ->mutable_members() - ->Clear(); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(prediction_model); - EXPECT_FALSE(model); -} - -} // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/entity_annotator_native_library.cc b/chromium/components/optimization_guide/core/entity_annotator_native_library.cc new file mode 100644 index 00000000000..1382ef228fc --- /dev/null +++ b/chromium/components/optimization_guide/core/entity_annotator_native_library.cc @@ -0,0 +1,445 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/optimization_guide/core/entity_annotator_native_library.h" + +#include "base/base_paths.h" +#include "base/compiler_specific.h" +#include "base/logging.h" +#include "base/memory/ptr_util.h" +#include "base/path_service.h" +#include "build/build_config.h" +#include "components/optimization_guide/core/model_util.h" +#include "components/optimization_guide/core/optimization_guide_util.h" +#include "components/optimization_guide/proto/page_entities_model_metadata.pb.h" + +#if BUILDFLAG(IS_MAC) +#include "base/mac/bundle_locations.h" +#include "base/mac/foundation_util.h" +#endif + +// IMPORTANT: All functions in this file that call dlsym()'ed +// functions should be annotated with DISABLE_CFI_ICALL. + +namespace optimization_guide { + +namespace { + +const char kModelMetadataBaseName[] = "model_metadata.pb"; +const char kWordEmbeddingsBaseName[] = "word_embeddings"; +const char kNameTableBaseName[] = "entities_names"; +const char kMetadataTableBaseName[] = "entities_metadata"; +const char kNameFilterBaseName[] = "entities_names_filter"; +const char kPrefixFilterBaseName[] = "entities_prefixes_filter"; + +// Sets |field_to_set| with the full file path of |base_name|'s entry in +// |base_to_full_file_path|. Returns whether |base_name| is in +// |base_to_full_file_path|. +absl::optional<std::string> GetFilePathFromMap( + const std::string& base_name, + const base::flat_map<std::string, base::FilePath>& base_to_full_file_path) { + auto it = base_to_full_file_path.find(base_name); + return it == base_to_full_file_path.end() + ? absl::nullopt + : absl::make_optional(FilePathToString(it->second)); +} + +// Returns the expected base name for |slice|. Will be of the form +// |slice|-|base_name|. +std::string GetSliceBaseName(const std::string& slice, + const std::string& base_name) { + return slice + "-" + base_name; +} + +} // namespace + +EntityAnnotatorNativeLibrary::EntityAnnotatorNativeLibrary( + base::NativeLibrary native_library) + : native_library_(std::move(native_library)) { + LoadFunctions(); +} +EntityAnnotatorNativeLibrary::~EntityAnnotatorNativeLibrary() = default; + +// static +std::unique_ptr<EntityAnnotatorNativeLibrary> +EntityAnnotatorNativeLibrary::Create() { + base::FilePath base_dir; +#if BUILDFLAG(IS_MAC) + if (base::mac::AmIBundled()) { + base_dir = base::mac::FrameworkBundlePath().Append("Libraries"); + } else { +#endif + if (!base::PathService::Get(base::DIR_MODULE, &base_dir)) { + LOG(ERROR) << "Error getting app dir"; + return nullptr; + } +#if BUILDFLAG(IS_MAC) + } +#endif + + base::NativeLibraryLoadError error; + base::NativeLibrary native_library = base::LoadNativeLibrary( + base_dir.AppendASCII( + base::GetNativeLibraryName("optimization_guide_internal")), + &error); + if (!native_library) { + LOG(ERROR) << "Failed to initialize optimization guide internal: " + << error.ToString(); + return nullptr; + } + + std::unique_ptr<EntityAnnotatorNativeLibrary> + entity_annotator_native_library = + base::WrapUnique<EntityAnnotatorNativeLibrary>( + new EntityAnnotatorNativeLibrary(std::move(native_library))); + if (entity_annotator_native_library->IsValid()) { + return entity_annotator_native_library; + } + LOG(ERROR) << "Could not find all required functions for optimization guide " + "internal library"; + return nullptr; +} + +DISABLE_CFI_ICALL +void EntityAnnotatorNativeLibrary::LoadFunctions() { + get_max_supported_feature_flag_func_ = + reinterpret_cast<GetMaxSupportedFeatureFlagFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityAnnotatorGetMaxSupportedFeatureFlag")); + + create_from_options_func_ = reinterpret_cast<CreateFromOptionsFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityAnnotatorCreateFromOptions")); + get_creation_error_func_ = reinterpret_cast<GetCreationErrorFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, "OptimizationGuideEntityAnnotatorGetCreationError")); + delete_func_ = + reinterpret_cast<DeleteFunc>(base::GetFunctionPointerFromNativeLibrary( + native_library_, "OptimizationGuideEntityAnnotatorDelete")); + + annotate_job_create_func_ = reinterpret_cast<AnnotateJobCreateFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityAnnotatorAnnotateJobCreate")); + annotate_job_delete_func_ = reinterpret_cast<AnnotateJobDeleteFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityAnnotatorAnnotateJobDelete")); + run_annotate_job_func_ = reinterpret_cast<RunAnnotateJobFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, "OptimizationGuideEntityAnnotatorRunAnnotateJob")); + annotate_get_output_metadata_at_index_func_ = reinterpret_cast< + AnnotateGetOutputMetadataAtIndexFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityAnnotatorAnnotateGetOutputMetadataAtIndex")); + annotate_get_output_metadata_score_at_index_func_ = + reinterpret_cast<AnnotateGetOutputMetadataScoreAtIndexFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityAnnotatorAnnotateGetOutputMetadataScoreAt" + "Index")); + + entity_metadata_job_create_func_ = + reinterpret_cast<EntityMetadataJobCreateFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityAnnotatorEntityMetadataJobCreate")); + entity_metadata_job_delete_func_ = + reinterpret_cast<EntityMetadataJobDeleteFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityAnnotatorEntityMetadataJobDelete")); + run_entity_metadata_job_func_ = reinterpret_cast<RunEntityMetadataJobFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityAnnotatorRunEntityMetadataJob")); + + options_create_func_ = reinterpret_cast<OptionsCreateFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, "OptimizationGuideEntityAnnotatorOptionsCreate")); + options_set_model_file_path_func_ = + reinterpret_cast<OptionsSetModelFilePathFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityAnnotatorOptionsSetModelFilePath")); + options_set_model_metadata_file_path_func_ = reinterpret_cast< + OptionsSetModelMetadataFilePathFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityAnnotatorOptionsSetModelMetadataFilePath")); + options_set_word_embeddings_file_path_func_ = reinterpret_cast< + OptionsSetWordEmbeddingsFilePathFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityAnnotatorOptionsSetWordEmbeddingsFilePath")); + options_add_model_slice_func_ = reinterpret_cast<OptionsAddModelSliceFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityAnnotatorOptionsAddModelSlice")); + options_delete_func_ = reinterpret_cast<OptionsDeleteFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, "OptimizationGuideEntityAnnotatorOptionsDelete")); + + entity_metadata_get_entity_id_func_ = + reinterpret_cast<EntityMetadataGetEntityIdFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, "OptimizationGuideEntityMetadataGetEntityID")); + entity_metadata_get_human_readable_name_func_ = + reinterpret_cast<EntityMetadataGetHumanReadableNameFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityMetadataGetHumanReadableName")); + entity_metadata_get_human_readable_categories_count_func_ = reinterpret_cast< + EntityMetadataGetHumanReadableCategoriesCountFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityMetadataGetHumanReadableCategoriesCount")); + entity_metadata_get_human_readable_category_name_at_index_func_ = + reinterpret_cast<EntityMetadataGetHumanReadableCategoryNameAtIndexFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityMetadataGetHumanReadableCategoryNameAtInd" + "ex")); + entity_metadata_get_human_readable_category_score_at_index_func_ = + reinterpret_cast<EntityMetadataGetHumanReadableCategoryScoreAtIndexFunc>( + base::GetFunctionPointerFromNativeLibrary( + native_library_, + "OptimizationGuideEntityMetadataGetHumanReadableCategoryScoreAtIn" + "dex")); +} + +DISABLE_CFI_ICALL +bool EntityAnnotatorNativeLibrary::IsValid() const { + return get_max_supported_feature_flag_func_ && create_from_options_func_ && + get_creation_error_func_ && delete_func_ && + annotate_job_create_func_ && annotate_job_delete_func_ && + run_annotate_job_func_ && + annotate_get_output_metadata_at_index_func_ && + annotate_get_output_metadata_score_at_index_func_ && + entity_metadata_job_create_func_ && entity_metadata_job_delete_func_ && + run_entity_metadata_job_func_ && options_create_func_ && + options_set_model_file_path_func_ && + options_set_model_metadata_file_path_func_ && + options_set_word_embeddings_file_path_func_ && + options_add_model_slice_func_ && options_delete_func_ && + entity_metadata_get_entity_id_func_ && + entity_metadata_get_human_readable_name_func_ && + entity_metadata_get_human_readable_categories_count_func_ && + entity_metadata_get_human_readable_category_name_at_index_func_ && + entity_metadata_get_human_readable_category_score_at_index_func_; +} + +DISABLE_CFI_ICALL +int32_t EntityAnnotatorNativeLibrary::GetMaxSupportedFeatureFlag() { + DCHECK(IsValid()); + if (!IsValid()) { + return -1; + } + + return get_max_supported_feature_flag_func_(); +} + +DISABLE_CFI_ICALL +void* EntityAnnotatorNativeLibrary::CreateEntityAnnotator( + const ModelInfo& model_info) { + DCHECK(IsValid()); + if (!IsValid()) { + return nullptr; + } + + void* options = options_create_func_(); + if (!PopulateEntityAnnotatorOptionsFromModelInfo(options, model_info)) { + options_delete_func_(options); + return nullptr; + } + + void* entity_annotator = create_from_options_func_(options); + const char* creation_error = get_creation_error_func_(entity_annotator); + if (creation_error) { + LOG(ERROR) << "Failed to create entity annotator: " << creation_error; + DeleteEntityAnnotator(entity_annotator); + entity_annotator = nullptr; + } + options_delete_func_(options); + return entity_annotator; +} + +DISABLE_CFI_ICALL +bool EntityAnnotatorNativeLibrary::PopulateEntityAnnotatorOptionsFromModelInfo( + void* options, + const ModelInfo& model_info) { + // We don't know which files are intended for use if we don't have model + // metadata, so return early. + if (!model_info.GetModelMetadata()) { + return false; + } + + // // Validate the model metadata. + absl::optional<proto::PageEntitiesModelMetadata> entities_model_metadata = + ParsedAnyMetadata<proto::PageEntitiesModelMetadata>( + model_info.GetModelMetadata().value()); + if (!entities_model_metadata) { + return false; + } + if (entities_model_metadata->slice_size() == 0) { + return false; + } + + // Build the entity annotator options. + options_set_model_file_path_func_( + options, FilePathToString(model_info.GetModelFilePath()).c_str()); + + // Attach the additional files required by the model. + base::flat_map<std::string, base::FilePath> base_to_full_file_path; + for (const auto& model_file : model_info.GetAdditionalFiles()) { + base_to_full_file_path.insert( + {FilePathToString(model_file.BaseName()), model_file}); + } + absl::optional<std::string> model_metadata_file_path = + GetFilePathFromMap(kModelMetadataBaseName, base_to_full_file_path); + if (!model_metadata_file_path) { + return false; + } + options_set_model_metadata_file_path_func_(options, + model_metadata_file_path->c_str()); + absl::optional<std::string> word_embeddings_file_path = + GetFilePathFromMap(kWordEmbeddingsBaseName, base_to_full_file_path); + if (!word_embeddings_file_path) { + return false; + } + options_set_word_embeddings_file_path_func_( + options, word_embeddings_file_path->c_str()); + + base::flat_set<std::string> slices(entities_model_metadata->slice().begin(), + entities_model_metadata->slice().end()); + for (const auto& slice_id : slices) { + absl::optional<std::string> name_filter_path = + GetFilePathFromMap(GetSliceBaseName(slice_id, kNameFilterBaseName), + base_to_full_file_path); + if (!name_filter_path) { + return false; + } + absl::optional<std::string> name_table_path = GetFilePathFromMap( + GetSliceBaseName(slice_id, kNameTableBaseName), base_to_full_file_path); + if (!name_table_path) { + return false; + } + absl::optional<std::string> prefix_filter_path = + GetFilePathFromMap(GetSliceBaseName(slice_id, kPrefixFilterBaseName), + base_to_full_file_path); + if (!prefix_filter_path) { + return false; + } + absl::optional<std::string> metadata_table_path = + GetFilePathFromMap(GetSliceBaseName(slice_id, kMetadataTableBaseName), + base_to_full_file_path); + if (!metadata_table_path) { + return false; + } + options_add_model_slice_func_( + options, slice_id.c_str(), name_filter_path->c_str(), + name_table_path->c_str(), prefix_filter_path->c_str(), + metadata_table_path->c_str()); + } + + return true; +} + +DISABLE_CFI_ICALL +void EntityAnnotatorNativeLibrary::DeleteEntityAnnotator( + void* entity_annotator) { + DCHECK(IsValid()); + if (!IsValid()) { + return; + } + + delete_func_(reinterpret_cast<void*>(entity_annotator)); +} + +DISABLE_CFI_ICALL +absl::optional<std::vector<ScoredEntityMetadata>> +EntityAnnotatorNativeLibrary::AnnotateText(void* annotator, + const std::string& text) { + DCHECK(IsValid()); + if (!IsValid()) { + return absl::nullopt; + } + + if (!annotator) { + return absl::nullopt; + } + + void* job = annotate_job_create_func_(reinterpret_cast<void*>(annotator)); + int32_t output_metadata_count = run_annotate_job_func_(job, text.c_str()); + if (output_metadata_count <= 0) { + return absl::nullopt; + } + std::vector<ScoredEntityMetadata> scored_md; + scored_md.reserve(output_metadata_count); + for (int32_t i = 0; i < output_metadata_count; i++) { + ScoredEntityMetadata md; + md.score = annotate_get_output_metadata_score_at_index_func_(job, i); + md.metadata = GetEntityMetadataFromOptimizationGuideEntityMetadata( + annotate_get_output_metadata_at_index_func_(job, i)); + scored_md.emplace_back(md); + } + annotate_job_delete_func_(job); + return scored_md; +} + +DISABLE_CFI_ICALL +absl::optional<EntityMetadata> +EntityAnnotatorNativeLibrary::GetEntityMetadataForEntityId( + void* annotator, + const std::string& entity_id) { + DCHECK(IsValid()); + if (!IsValid()) { + return absl::nullopt; + } + if (!annotator) { + return absl::nullopt; + } + + void* job = + entity_metadata_job_create_func_(reinterpret_cast<void*>(annotator)); + const void* entity_metadata = + run_entity_metadata_job_func_(job, entity_id.c_str()); + if (!entity_metadata) { + return absl::nullopt; + } + EntityMetadata md = + GetEntityMetadataFromOptimizationGuideEntityMetadata(entity_metadata); + entity_metadata_job_delete_func_(job); + return md; +} + +DISABLE_CFI_ICALL +EntityMetadata EntityAnnotatorNativeLibrary:: + GetEntityMetadataFromOptimizationGuideEntityMetadata( + const void* og_entity_metadata) { + EntityMetadata entity_metadata; + entity_metadata.entity_id = + entity_metadata_get_entity_id_func_(og_entity_metadata); + entity_metadata.human_readable_name = + entity_metadata_get_human_readable_name_func_(og_entity_metadata); + + int32_t human_readable_categories_count = + entity_metadata_get_human_readable_categories_count_func_( + og_entity_metadata); + for (int32_t i = 0; i < human_readable_categories_count; i++) { + std::string category_name = + entity_metadata_get_human_readable_category_name_at_index_func_( + og_entity_metadata, i); + float category_score = + entity_metadata_get_human_readable_category_score_at_index_func_( + og_entity_metadata, i); + entity_metadata.human_readable_categories[category_name] = category_score; + } + return entity_metadata; +} + +} // namespace optimization_guide
\ No newline at end of file diff --git a/chromium/components/optimization_guide/core/entity_annotator_native_library.h b/chromium/components/optimization_guide/core/entity_annotator_native_library.h new file mode 100644 index 00000000000..a1fc4973269 --- /dev/null +++ b/chromium/components/optimization_guide/core/entity_annotator_native_library.h @@ -0,0 +1,143 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_ENTITY_ANNOTATOR_NATIVE_LIBRARY_H_ +#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_ENTITY_ANNOTATOR_NATIVE_LIBRARY_H_ + +#include <memory> +#include <vector> + +#include "base/native_library.h" +#include "components/optimization_guide/core/entity_metadata.h" +#include "components/optimization_guide/core/model_info.h" + +namespace optimization_guide { + +// Handles interactions with the native library that contains logic for the +// entity annotator. +class EntityAnnotatorNativeLibrary { + public: + // Creates an EntityAnnotatorNativeLibrary, which loads a native library and + // relevant functions required. Will return nullptr if fails. + static std::unique_ptr<EntityAnnotatorNativeLibrary> Create(); + + EntityAnnotatorNativeLibrary(const EntityAnnotatorNativeLibrary&) = delete; + EntityAnnotatorNativeLibrary& operator=(const EntityAnnotatorNativeLibrary&) = + delete; + ~EntityAnnotatorNativeLibrary(); + + // Returns whether this instance is valid (i.e. all necessary functions have + // been loaded.) + bool IsValid() const; + + // Gets the max supported feature from this native library. + int32_t GetMaxSupportedFeatureFlag(); + + // Creates an entity annotator from |model_info|. + void* CreateEntityAnnotator(const ModelInfo& model_info); + + // Deletes |entity_annotator|. + void DeleteEntityAnnotator(void* entity_annotator); + + // Uses |annotator| to annotate entities present in |text|. + absl::optional<std::vector<ScoredEntityMetadata>> AnnotateText( + void* annotator, + const std::string& text); + + // Returns entity metadata from |annotator| for |entity_id|. + absl::optional<EntityMetadata> GetEntityMetadataForEntityId( + void* annotator, + const std::string& entity_id); + + private: + EntityAnnotatorNativeLibrary(base::NativeLibrary native_library); + + // Loads the functions exposed by the native library. + void LoadFunctions(); + + // Populates |options| based on |model_info|. Returns false if |model_info| + // cannot construct a valid options object. + bool PopulateEntityAnnotatorOptionsFromModelInfo(void* options, + const ModelInfo& model_info); + + // Returns an entity metadata from the C-API representation. + EntityMetadata GetEntityMetadataFromOptimizationGuideEntityMetadata( + const void* og_entity_metadata); + + base::NativeLibrary native_library_; + + // Functions exposed by native library. + using GetMaxSupportedFeatureFlagFunc = int32_t (*)(); + GetMaxSupportedFeatureFlagFunc get_max_supported_feature_flag_func_ = nullptr; + + using CreateFromOptionsFunc = void* (*)(const void*); + CreateFromOptionsFunc create_from_options_func_ = nullptr; + using GetCreationErrorFunc = const char* (*)(const void*); + GetCreationErrorFunc get_creation_error_func_ = nullptr; + using DeleteFunc = void (*)(void*); + DeleteFunc delete_func_ = nullptr; + + using AnnotateJobCreateFunc = void* (*)(void*); + AnnotateJobCreateFunc annotate_job_create_func_ = nullptr; + using AnnotateJobDeleteFunc = void (*)(void*); + AnnotateJobDeleteFunc annotate_job_delete_func_ = nullptr; + using RunAnnotateJobFunc = int32_t (*)(void*, const char*); + RunAnnotateJobFunc run_annotate_job_func_ = nullptr; + using AnnotateGetOutputMetadataAtIndexFunc = const void* (*)(void*, int32_t); + AnnotateGetOutputMetadataAtIndexFunc + annotate_get_output_metadata_at_index_func_ = nullptr; + using AnnotateGetOutputMetadataScoreAtIndexFunc = float (*)(void*, int32_t); + AnnotateGetOutputMetadataScoreAtIndexFunc + annotate_get_output_metadata_score_at_index_func_ = nullptr; + + using EntityMetadataJobCreateFunc = void* (*)(void*); + EntityMetadataJobCreateFunc entity_metadata_job_create_func_ = nullptr; + using EntityMetadataJobDeleteFunc = void (*)(void*); + EntityMetadataJobDeleteFunc entity_metadata_job_delete_func_ = nullptr; + using RunEntityMetadataJobFunc = const void* (*)(void*, const char*); + RunEntityMetadataJobFunc run_entity_metadata_job_func_ = nullptr; + + using OptionsCreateFunc = void* (*)(); + OptionsCreateFunc options_create_func_ = nullptr; + using OptionsSetModelFilePathFunc = void (*)(void*, const char*); + OptionsSetModelFilePathFunc options_set_model_file_path_func_ = nullptr; + using OptionsSetModelMetadataFilePathFunc = void (*)(void*, const char*); + OptionsSetModelMetadataFilePathFunc + options_set_model_metadata_file_path_func_ = nullptr; + using OptionsSetWordEmbeddingsFilePathFunc = void (*)(void*, const char*); + OptionsSetWordEmbeddingsFilePathFunc + options_set_word_embeddings_file_path_func_ = nullptr; + using OptionsAddModelSliceFunc = void (*)(void*, + const char*, + const char*, + const char*, + const char*, + const char*); + OptionsAddModelSliceFunc options_add_model_slice_func_ = nullptr; + using OptionsDeleteFunc = void (*)(void*); + OptionsDeleteFunc options_delete_func_ = nullptr; + + using EntityMetadataGetEntityIdFunc = const char* (*)(const void*); + EntityMetadataGetEntityIdFunc entity_metadata_get_entity_id_func_ = nullptr; + using EntityMetadataGetHumanReadableNameFunc = const char* (*)(const void*); + EntityMetadataGetHumanReadableNameFunc + entity_metadata_get_human_readable_name_func_ = nullptr; + using EntityMetadataGetHumanReadableCategoriesCountFunc = + int32_t (*)(const void*); + EntityMetadataGetHumanReadableCategoriesCountFunc + entity_metadata_get_human_readable_categories_count_func_ = nullptr; + using EntityMetadataGetHumanReadableCategoryNameAtIndexFunc = + const char* (*)(const void*, int32_t); + EntityMetadataGetHumanReadableCategoryNameAtIndexFunc + entity_metadata_get_human_readable_category_name_at_index_func_ = nullptr; + using EntityMetadataGetHumanReadableCategoryScoreAtIndexFunc = + float (*)(const void*, int32_t); + EntityMetadataGetHumanReadableCategoryScoreAtIndexFunc + entity_metadata_get_human_readable_category_score_at_index_func_ = + nullptr; +}; + +} // namespace optimization_guide + +#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_ENTITY_ANNOTATOR_NATIVE_LIBRARY_H_
\ No newline at end of file diff --git a/chromium/components/optimization_guide/core/entity_annotator_native_library_unittest.cc b/chromium/components/optimization_guide/core/entity_annotator_native_library_unittest.cc new file mode 100644 index 00000000000..3155e9bf10c --- /dev/null +++ b/chromium/components/optimization_guide/core/entity_annotator_native_library_unittest.cc @@ -0,0 +1,22 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/optimization_guide/core/entity_annotator_native_library.h" + +#include "testing/gtest/include/gtest/gtest.h" + +namespace optimization_guide { +namespace { + +using EntityAnnotatorNativeLibraryTest = ::testing::Test; + +TEST_F(EntityAnnotatorNativeLibraryTest, CanCreateValidLibrary) { + std::unique_ptr<EntityAnnotatorNativeLibrary> lib = + EntityAnnotatorNativeLibrary::Create(); + ASSERT_TRUE(lib); + EXPECT_TRUE(lib->IsValid()); +} + +} // namespace +} // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/entity_metadata.cc b/chromium/components/optimization_guide/core/entity_metadata.cc index 6ca1930d1fd..2f8301579f3 100644 --- a/chromium/components/optimization_guide/core/entity_metadata.cc +++ b/chromium/components/optimization_guide/core/entity_metadata.cc @@ -4,6 +4,7 @@ #include "components/optimization_guide/core/entity_metadata.h" +#include <ostream> #include <string> #include <vector> diff --git a/chromium/components/optimization_guide/core/entity_metadata_provider.h b/chromium/components/optimization_guide/core/entity_metadata_provider.h index 54ec81d7e01..14f7def09a8 100644 --- a/chromium/components/optimization_guide/core/entity_metadata_provider.h +++ b/chromium/components/optimization_guide/core/entity_metadata_provider.h @@ -6,6 +6,7 @@ #define COMPONENTS_OPTIMIZATION_GUIDE_CORE_ENTITY_METADATA_PROVIDER_H_ #include "components/optimization_guide/core/entity_metadata.h" +#include "third_party/abseil-cpp/absl/types/optional.h" namespace optimization_guide { diff --git a/chromium/components/optimization_guide/core/hints_fetcher.cc b/chromium/components/optimization_guide/core/hints_fetcher.cc index 3909ae5c3f8..7d58512dc42 100644 --- a/chromium/components/optimization_guide/core/hints_fetcher.cc +++ b/chromium/components/optimization_guide/core/hints_fetcher.cc @@ -11,9 +11,11 @@ #include "base/feature_list.h" #include "base/metrics/histogram_functions.h" #include "base/metrics/histogram_macros.h" +#include "base/strings/strcat.h" #include "base/time/default_clock.h" #include "components/optimization_guide/core/hints_processing_util.h" #include "components/optimization_guide/core/optimization_guide_features.h" +#include "components/optimization_guide/core/optimization_guide_logger.h" #include "components/optimization_guide/core/optimization_guide_prefs.h" #include "components/optimization_guide/core/optimization_guide_switches.h" #include "components/optimization_guide/core/optimization_guide_util.h" @@ -27,7 +29,6 @@ #include "net/http/http_response_headers.h" #include "net/http/http_status_code.h" #include "net/traffic_annotation/network_traffic_annotation.h" -#include "services/network/public/cpp/network_connection_tracker.h" #include "services/network/public/cpp/resource_request.h" #include "services/network/public/cpp/shared_url_loader_factory.h" #include "services/network/public/cpp/simple_url_loader.h" @@ -62,23 +63,9 @@ std::string GetStringNameForRequestContext( return std::string(); } -// Returns the subset of URLs from |urls| for which the URL is considered -// valid and can be included in a hints fetch. -std::vector<GURL> GetValidURLsForFetching(const std::vector<GURL>& urls) { - std::vector<GURL> valid_urls; - for (const auto& url : urls) { - if (valid_urls.size() >= - features::MaxUrlsForOptimizationGuideServiceHintsFetch()) { - break; - } - if (IsValidURLForURLKeyedHint(url)) - valid_urls.push_back(url); - } - return valid_urls; -} - void RecordRequestStatusHistogram(proto::RequestContext request_context, HintsFetcherRequestStatus status) { + DCHECK_NE(status, HintsFetcherRequestStatus::kDeprecatedNetworkOffline); base::UmaHistogramEnumeration( "OptimizationGuide.HintsFetcher.RequestStatus." + GetStringNameForRequestContext(request_context), @@ -91,14 +78,14 @@ HintsFetcher::HintsFetcher( scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, const GURL& optimization_guide_service_url, PrefService* pref_service, - network::NetworkConnectionTracker* network_connection_tracker) + OptimizationGuideLogger* optimization_guide_logger) : optimization_guide_service_url_(net::AppendOrReplaceQueryParameter( optimization_guide_service_url, "key", features::GetOptimizationGuideServiceAPIKey())), pref_service_(pref_service), - network_connection_tracker_(network_connection_tracker), - time_clock_(base::DefaultClock::GetInstance()) { + time_clock_(base::DefaultClock::GetInstance()), + optimization_guide_logger_(optimization_guide_logger) { url_loader_factory_ = std::move(url_loader_factory); // Allow non-https scheme only when it is overridden in command line. This is // needed for iOS EG2 tests which don't support HTTPS embedded test servers @@ -190,16 +177,13 @@ bool HintsFetcher::FetchOptimizationGuideServiceHints( HintsFetchedCallback hints_fetched_callback) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_GT(optimization_types.size(), 0u); - - if (network_connection_tracker_->IsOffline()) { - RecordRequestStatusHistogram(request_context, - HintsFetcherRequestStatus::kNetworkOffline); - std::move(hints_fetched_callback).Run(absl::nullopt); - return false; - } + request_context_ = request_context; if (active_url_loader_) { - RecordRequestStatusHistogram(request_context, + OPTIMIZATION_GUIDE_LOG( + optimization_guide_logger_, + "No hints fetched: HintsFetcher busy in another fetch"); + RecordRequestStatusHistogram(request_context_, HintsFetcherRequestStatus::kFetcherBusy); std::move(hints_fetched_callback).Run(absl::nullopt); return false; @@ -207,10 +191,12 @@ bool HintsFetcher::FetchOptimizationGuideServiceHints( std::vector<std::string> filtered_hosts = GetSizeLimitedHostsDueForHintsRefresh(hosts); - std::vector<GURL> valid_urls = GetValidURLsForFetching(urls); + std::vector<GURL> valid_urls = GetSizeLimitedURLsForFetching(urls); if (filtered_hosts.empty() && valid_urls.empty()) { + OPTIMIZATION_GUIDE_LOG(optimization_guide_logger_, + "No hints fetched: No hosts/URLs"); RecordRequestStatusHistogram( - request_context, HintsFetcherRequestStatus::kNoHostsOrURLsToFetch); + request_context_, HintsFetcherRequestStatus::kNoHostsOrURLsToFetch); std::move(hints_fetched_callback).Run(absl::nullopt); return false; } @@ -221,15 +207,16 @@ bool HintsFetcher::FetchOptimizationGuideServiceHints( valid_urls.size()); if (optimization_types.empty()) { + OPTIMIZATION_GUIDE_LOG(optimization_guide_logger_, + "No hints fetched: No supported optimization types"); RecordRequestStatusHistogram( - request_context, + request_context_, HintsFetcherRequestStatus::kNoSupportedOptimizationTypes); std::move(hints_fetched_callback).Run(absl::nullopt); return false; } hints_fetch_start_time_ = base::TimeTicks::Now(); - request_context_ = request_context; proto::GetHintsRequest get_hints_request; get_hints_request.add_supported_key_representations(proto::HOST); @@ -445,6 +432,38 @@ void HintsFetcher::OnURLLoadComplete( HandleResponse(response_body ? *response_body : "", net_error, response_code); } +// Returns the subset of URLs from |urls| for which the URL is considered +// valid and can be included in a hints fetch. +std::vector<GURL> HintsFetcher::GetSizeLimitedURLsForFetching( + const std::vector<GURL>& urls) const { + std::vector<GURL> valid_urls; + for (size_t i = 0; i < urls.size(); i++) { + if (valid_urls.size() >= + features::MaxUrlsForOptimizationGuideServiceHintsFetch()) { + base::UmaHistogramCounts100( + "OptimizationGuide.HintsFetcher.GetHintsRequest.DroppedUrls." + + GetStringNameForRequestContext(request_context_), + urls.size() - i); + OPTIMIZATION_GUIDE_LOG( + optimization_guide_logger_, + base::StrCat({"Skipped adding URL due to limit, context:", + GetStringNameForRequestContext(request_context_), + " URL:", urls[i].possibly_invalid_spec()})); + break; + } + if (IsValidURLForURLKeyedHint(urls[i])) { + valid_urls.push_back(urls[i]); + } else { + OPTIMIZATION_GUIDE_LOG( + optimization_guide_logger_, + base::StrCat({"Skipped adding invalid URL, context:", + GetStringNameForRequestContext(request_context_), + " URL:", urls[i].possibly_invalid_spec()})); + } + } + return valid_urls; +} + std::vector<std::string> HintsFetcher::GetSizeLimitedHostsDueForHintsRefresh( const std::vector<std::string>& hosts) const { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); @@ -455,7 +474,17 @@ std::vector<std::string> HintsFetcher::GetSizeLimitedHostsDueForHintsRefresh( std::vector<std::string> target_hosts; target_hosts.reserve(hosts.size()); - for (const auto& host : hosts) { + for (size_t i = 0; i < hosts.size(); i++) { + if (target_hosts.size() >= + features::MaxHostsForOptimizationGuideServiceHintsFetch()) { + base::UmaHistogramCounts100( + "OptimizationGuide.HintsFetcher.GetHintsRequest.DroppedHosts." + + GetStringNameForRequestContext(request_context_), + hosts.size() - i); + break; + } + + std::string host = hosts[i]; // Skip over localhosts, IP addresses, and invalid hosts. if (net::HostStringIsLocalhost(host)) continue; @@ -463,6 +492,9 @@ std::vector<std::string> HintsFetcher::GetSizeLimitedHostsDueForHintsRefresh( std::string canonicalized_host(net::CanonicalizeHost(host, &host_info)); if (host_info.IsIPAddress() || !net::IsCanonicalizedHostCompliant(canonicalized_host)) { + OPTIMIZATION_GUIDE_LOG( + optimization_guide_logger_, + base::StrCat({"Skipped adding invalid host:", host})); continue; } @@ -477,12 +509,12 @@ std::vector<std::string> HintsFetcher::GetSizeLimitedHostsDueForHintsRefresh( (host_valid_time - features::GetHostHintsFetchRefreshDuration() <= time_clock_->Now()); } - if (host_hints_due_for_refresh) + if (host_hints_due_for_refresh) { target_hosts.push_back(host); - - if (target_hosts.size() >= - features::MaxHostsForOptimizationGuideServiceHintsFetch()) { - break; + } else { + OPTIMIZATION_GUIDE_LOG( + optimization_guide_logger_, + base::StrCat({"Skipped refreshing hints for host:", host})); } } DCHECK_GE(features::MaxHostsForOptimizationGuideServiceHintsFetch(), diff --git a/chromium/components/optimization_guide/core/hints_fetcher.h b/chromium/components/optimization_guide/core/hints_fetcher.h index 95c2000513c..0ef0a13dea7 100644 --- a/chromium/components/optimization_guide/core/hints_fetcher.h +++ b/chromium/components/optimization_guide/core/hints_fetcher.h @@ -20,10 +20,10 @@ #include "third_party/abseil-cpp/absl/types/optional.h" #include "url/gurl.h" +class OptimizationGuideLogger; class PrefService; namespace network { -class NetworkConnectionTracker; class SharedURLLoaderFactory; class SimpleURLLoader; } // namespace network @@ -41,8 +41,8 @@ enum class HintsFetcherRequestStatus { kSuccess, // Fetch request was sent but no response received. kResponseError, - // Fetch request not sent because of offline network status. - kNetworkOffline, + // DEPRECATED: Fetch request not sent because of offline network status. + kDeprecatedNetworkOffline, // Fetch request not sent because fetcher was busy with another request. kFetcherBusy, // Fetch request not sent because the host and URL lists were empty. @@ -71,7 +71,7 @@ class HintsFetcher { scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, const GURL& optimization_guide_service_url, PrefService* pref_service, - network::NetworkConnectionTracker* network_connection_tracker); + OptimizationGuideLogger* optimization_guide_logger); HintsFetcher(const HintsFetcher&) = delete; HintsFetcher& operator=(const HintsFetcher&) = delete; @@ -142,6 +142,11 @@ class HintsFetcher { // in the pref. void UpdateHostsSuccessfullyFetched(base::TimeDelta valid_duration); + // Returns the subset of URLs from |urls| for which the URL is considered + // valid and can be included in a hints fetch. + std::vector<GURL> GetSizeLimitedURLsForFetching( + const std::vector<GURL>& urls) const; + // Returns the subset of hosts from |hosts| for which the hints should be // refreshed. The count of returned hosts is limited to // features::MaxHostsForOptimizationGuideServiceHintsFetch(). @@ -165,10 +170,6 @@ class HintsFetcher { // A reference to the PrefService for this profile. Not owned. raw_ptr<PrefService> pref_service_ = nullptr; - // Listens to changes around the network connection. Not owned. Guaranteed to - // outlive |this|. - raw_ptr<network::NetworkConnectionTracker> network_connection_tracker_; - // Holds the hosts being requested by the hints fetcher. std::vector<std::string> hosts_fetched_; @@ -182,6 +183,9 @@ class HintsFetcher { // retrieving hints from the remote Optimization Guide Service. base::TimeTicks hints_fetch_start_time_; + // Owned by OptimizationGuideKeyedService and outlives |this|. + raw_ptr<OptimizationGuideLogger> optimization_guide_logger_; + SEQUENCE_CHECKER(sequence_checker_); }; diff --git a/chromium/components/optimization_guide/core/hints_fetcher_factory.cc b/chromium/components/optimization_guide/core/hints_fetcher_factory.cc index cd63684f63a..b6e454fa336 100644 --- a/chromium/components/optimization_guide/core/hints_fetcher_factory.cc +++ b/chromium/components/optimization_guide/core/hints_fetcher_factory.cc @@ -5,7 +5,6 @@ #include "components/optimization_guide/core/hints_fetcher.h" #include "components/prefs/pref_service.h" -#include "services/network/public/cpp/network_connection_tracker.h" #include "services/network/public/cpp/shared_url_loader_factory.h" namespace optimization_guide { @@ -13,19 +12,18 @@ namespace optimization_guide { HintsFetcherFactory::HintsFetcherFactory( scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, const GURL& optimization_guide_service_url, - PrefService* pref_service, - network::NetworkConnectionTracker* network_connection_tracker) + PrefService* pref_service) : url_loader_factory_(url_loader_factory), optimization_guide_service_url_(optimization_guide_service_url), - pref_service_(pref_service), - network_connection_tracker_(network_connection_tracker) {} + pref_service_(pref_service) {} HintsFetcherFactory::~HintsFetcherFactory() = default; -std::unique_ptr<HintsFetcher> HintsFetcherFactory::BuildInstance() { +std::unique_ptr<HintsFetcher> HintsFetcherFactory::BuildInstance( + OptimizationGuideLogger* optimization_guide_logger) { return std::make_unique<HintsFetcher>( url_loader_factory_, optimization_guide_service_url_, pref_service_, - network_connection_tracker_); + optimization_guide_logger); } void HintsFetcherFactory::OverrideOptimizationGuideServiceUrlForTesting( diff --git a/chromium/components/optimization_guide/core/hints_fetcher_factory.h b/chromium/components/optimization_guide/core/hints_fetcher_factory.h index 67eb35c4c97..f8c8a7ffd3c 100644 --- a/chromium/components/optimization_guide/core/hints_fetcher_factory.h +++ b/chromium/components/optimization_guide/core/hints_fetcher_factory.h @@ -11,10 +11,10 @@ #include "base/memory/scoped_refptr.h" #include "url/gurl.h" +class OptimizationGuideLogger; class PrefService; namespace network { -class NetworkConnectionTracker; class SharedURLLoaderFactory; } // namespace network @@ -29,15 +29,15 @@ class HintsFetcherFactory { HintsFetcherFactory( scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, const GURL& optimization_guide_service_url, - PrefService* pref_service, - network::NetworkConnectionTracker* network_connection_tracker); + PrefService* pref_service); HintsFetcherFactory(const HintsFetcherFactory&) = delete; HintsFetcherFactory& operator=(const HintsFetcherFactory&) = delete; virtual ~HintsFetcherFactory(); // Creates a new instance of HintsFetcher. Virtualized for testing so that the // testing code can override this to provide a mocked instance. - virtual std::unique_ptr<HintsFetcher> BuildInstance(); + virtual std::unique_ptr<HintsFetcher> BuildInstance( + OptimizationGuideLogger* optimization_guide_logger); // Override the optimization guide hints server URL. Used for testing. void OverrideOptimizationGuideServiceUrlForTesting( @@ -53,10 +53,6 @@ class HintsFetcherFactory { // A reference to the PrefService for this profile. Not owned. raw_ptr<PrefService> pref_service_ = nullptr; - - // A reference to the object that listens for changes in network connection. - // Not owned. Guaranteed to outlive |this|. - raw_ptr<network::NetworkConnectionTracker> network_connection_tracker_; }; } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/hints_fetcher_unittest.cc b/chromium/components/optimization_guide/core/hints_fetcher_unittest.cc index 58e9c1470e2..b7c2c5e647b 100644 --- a/chromium/components/optimization_guide/core/hints_fetcher_unittest.cc +++ b/chromium/components/optimization_guide/core/hints_fetcher_unittest.cc @@ -25,7 +25,6 @@ #include "net/base/url_util.h" #include "services/network/public/cpp/shared_url_loader_factory.h" #include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h" -#include "services/network/test/test_network_connection_tracker.h" #include "services/network/test/test_url_loader_factory.h" #include "testing/gtest/include/gtest/gtest.h" #include "third_party/abseil-cpp/absl/types/optional.h" @@ -56,8 +55,7 @@ class HintsFetcherTest : public testing::Test, hints_fetcher_ = std::make_unique<HintsFetcher>( shared_url_loader_factory_, GURL(optimization_guide_service_url), - pref_service_.get(), - network::TestNetworkConnectionTracker::GetInstance()); + pref_service_.get(), /*optimization_guide_logger=*/nullptr); hints_fetcher_->SetTimeClockForTesting(task_environment_.GetMockClock()); } @@ -72,19 +70,7 @@ class HintsFetcherTest : public testing::Test, hints_fetched_ = true; } - bool hints_fetched() { return hints_fetched_; } - - void SetConnectionOffline() { - network_tracker_ = network::TestNetworkConnectionTracker::GetInstance(); - network_tracker_->SetConnectionType( - network::mojom::ConnectionType::CONNECTION_NONE); - } - - void SetConnectionOnline() { - network_tracker_ = network::TestNetworkConnectionTracker::GetInstance(); - network_tracker_->SetConnectionType( - network::mojom::ConnectionType::CONNECTION_4G); - } + bool hints_fetched() const { return hints_fetched_; } // Updates the pref so that hints for each of the host in |hosts| are set to // expire at |host_invalid_time|. @@ -177,7 +163,6 @@ class HintsFetcherTest : public testing::Test, std::unique_ptr<TestingPrefServiceSimple> pref_service_; scoped_refptr<network::SharedURLLoaderFactory> shared_url_loader_factory_; network::TestURLLoaderFactory test_url_loader_factory_; - raw_ptr<network::TestNetworkConnectionTracker> network_tracker_; std::string last_request_body_; }; @@ -355,35 +340,6 @@ TEST_P(HintsFetcherTest, FetchReturnBadResponse) { HintsFetcherRequestStatus::kResponseError, 1); } -TEST_P(HintsFetcherTest, FetchAttemptWhenNetworkOffline) { - base::HistogramTester histogram_tester; - - SetConnectionOffline(); - std::string response_content; - EXPECT_FALSE(FetchHints({"foo.com"}, {} /* urls */)); - EXPECT_FALSE(hints_fetched()); - - // Make sure histograms are recorded correctly on bad response. - histogram_tester.ExpectTotalCount( - "OptimizationGuide.HintsFetcher.GetHintsRequest.FetchLatency", 0); - histogram_tester.ExpectUniqueSample( - "OptimizationGuide.HintsFetcher.RequestStatus.BatchUpdateActiveTabs", - HintsFetcherRequestStatus::kNetworkOffline, 1); - - SetConnectionOnline(); - EXPECT_TRUE(FetchHints({"foo.com"}, {} /* urls */)); - VerifyHasPendingFetchRequests(); - EXPECT_TRUE(SimulateResponse(response_content, net::HTTP_OK)); - EXPECT_TRUE(hints_fetched()); - - histogram_tester.ExpectTotalCount( - "OptimizationGuide.HintsFetcher.GetHintsRequest.FetchLatency", 1); - histogram_tester.ExpectTotalCount( - "OptimizationGuide.HintsFetcher.GetHintsRequest.FetchLatency." - "BatchUpdateActiveTabs", - 1); -} - TEST_P(HintsFetcherTest, HintsFetchSuccessfulHostsRecorded) { std::vector<std::string> hosts{"host1.com", "host2.com"}; std::string response_content; @@ -606,6 +562,8 @@ TEST_P(HintsFetcherTest, HintsFetcherSuccessfullyFetchedHostsFull) { } TEST_P(HintsFetcherTest, MaxHostsForOptimizationGuideServiceHintsFetch) { + base::HistogramTester histogram_tester; + std::string response_content; std::vector<std::string> all_hosts; @@ -640,6 +598,12 @@ TEST_P(HintsFetcherTest, MaxHostsForOptimizationGuideServiceHintsFetch) { EXPECT_TRUE( WasHostCoveredByFetch("host" + base::NumberToString(i) + ".com")); } + + // extra1.com and extra2.com should have been considered "dropped". + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.HintsFetcher.GetHintsRequest.DroppedHosts." + "BatchUpdateActiveTabs", + 2, 1); } TEST_P(HintsFetcherTest, MaxUrlsForOptimizationGuideServiceHintsFetch) { @@ -677,6 +641,12 @@ TEST_P(HintsFetcherTest, MaxUrlsForOptimizationGuideServiceHintsFetch) { EXPECT_EQ(last_request.urls(i).url(), "https://url" + base::NumberToString(i) + ".com/"); } + + // notfetched.com and notfetched-2.com should have been considered "dropped". + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.HintsFetcher.GetHintsRequest.DroppedUrls." + "BatchUpdateActiveTabs", + 2, 1); } TEST_P(HintsFetcherTest, OnlyURLsToFetch) { @@ -697,6 +667,11 @@ TEST_P(HintsFetcherTest, OnlyURLsToFetch) { histogram_tester.ExpectUniqueSample( "OptimizationGuide.HintsFetcher.RequestStatus.BatchUpdateActiveTabs", static_cast<int>(HintsFetcherRequestStatus::kSuccess), 1); + // Nothing was dropped so this shouldn't be recorded. + histogram_tester.ExpectTotalCount( + "OptimizationGuide.HintsFetcher.GetHintsRequest.DroppedHosts", 0); + histogram_tester.ExpectTotalCount( + "OptimizationGuide.HintsFetcher.GetHintsRequest.DroppedUrls", 0); } TEST_P(HintsFetcherTest, NoHostsOrURLsToFetch) { diff --git a/chromium/components/optimization_guide/core/hints_manager.cc b/chromium/components/optimization_guide/core/hints_manager.cc index 7ae4f3a4fe0..3270962e174 100644 --- a/chromium/components/optimization_guide/core/hints_manager.cc +++ b/chromium/components/optimization_guide/core/hints_manager.cc @@ -17,6 +17,8 @@ #include "base/metrics/histogram_macros_local.h" #include "base/notreached.h" #include "base/rand_util.h" +#include "base/strings/strcat.h" +#include "base/strings/string_number_conversions.h" #include "base/task/post_task.h" #include "base/task/sequenced_task_runner.h" #include "base/task/task_runner_util.h" @@ -33,6 +35,7 @@ #include "components/optimization_guide/core/optimization_guide_constants.h" #include "components/optimization_guide/core/optimization_guide_enums.h" #include "components/optimization_guide/core/optimization_guide_features.h" +#include "components/optimization_guide/core/optimization_guide_logger.h" #include "components/optimization_guide/core/optimization_guide_navigation_data.h" #include "components/optimization_guide/core/optimization_guide_permissions_util.h" #include "components/optimization_guide/core/optimization_guide_prefs.h" @@ -50,7 +53,6 @@ #include "services/metrics/public/cpp/ukm_recorder.h" #include "services/metrics/public/cpp/ukm_source.h" #include "services/metrics/public/cpp/ukm_source_id.h" -#include "services/network/public/cpp/network_connection_tracker.h" #include "services/network/public/cpp/shared_url_loader_factory.h" namespace optimization_guide { @@ -204,23 +206,31 @@ bool ShouldIgnoreNewlyRegisteredOptimizationType( class ScopedCanApplyOptimizationLogger { public: - ScopedCanApplyOptimizationLogger(proto::OptimizationType opt_type, GURL url) + ScopedCanApplyOptimizationLogger( + proto::OptimizationType opt_type, + GURL url, + OptimizationGuideLogger* optimization_guide_logger) : decision_(OptimizationGuideDecision::kUnknown), type_decision_(OptimizationTypeDecision::kUnknown), opt_type_(opt_type), has_metadata_(false), - url_(url) {} + url_(url), + optimization_guide_logger_(optimization_guide_logger) {} ~ScopedCanApplyOptimizationLogger() { if (!switches::IsDebugLogsEnabled()) return; DCHECK_NE(type_decision_, OptimizationTypeDecision::kUnknown); - DVLOG(0) << "OptimizationGuide: CanApplyOptimization: " - << GetStringNameForOptimizationType(opt_type_) - << "\nqueried on: " << url_ << "\nDecision: " - << GetStringForOptimizationGuideDecision(decision_) - << "\nTypeDecision: " << static_cast<int>(type_decision_) - << "\nHas Metadata: " << has_metadata_; + OPTIMIZATION_GUIDE_LOG( + optimization_guide_logger_, + base::StrCat( + {"OptimizationGuide: CanApplyOptimization: ", + GetStringNameForOptimizationType(opt_type_), + "\nqueried on: ", url_.possibly_invalid_spec(), + "\nDecision: ", GetStringForOptimizationGuideDecision(decision_), + "\nTypeDecision: ", + base::NumberToString(static_cast<int>(type_decision_)), + "\nHas Metadata: ", (has_metadata_ ? "True" : "False")})); } void set_has_metadata() { has_metadata_ = true; } @@ -238,6 +248,9 @@ class ScopedCanApplyOptimizationLogger { proto::OptimizationType opt_type_; bool has_metadata_; GURL url_; + + // Not owned. Guaranteed to outlive |this| scoped object. + raw_ptr<OptimizationGuideLogger> optimization_guide_logger_; }; // Reads component file and parses it into a Configuration proto. Should not be @@ -264,23 +277,33 @@ void MaybeLogGetHintRequestInfo( const base::flat_set<proto::OptimizationType>& registered_optimization_types, const std::vector<GURL>& urls_to_fetch, - const std::vector<std::string>& hosts_to_fetch) { + const std::vector<std::string>& hosts_to_fetch, + OptimizationGuideLogger* optimization_guide_logger) { if (!switches::IsDebugLogsEnabled()) return; - DVLOG(0) << "OptimizationGuide: Starting fetch for request context " - << proto::RequestContext_Name(request_context); - DVLOG(0) << "OptimizationGuide: Registered Optimization Types: "; + OPTIMIZATION_GUIDE_LOG( + optimization_guide_logger, + base::StrCat({"OptimizationGuide: Starting fetch for request context ", + proto::RequestContext_Name(request_context)})); + OPTIMIZATION_GUIDE_LOG(optimization_guide_logger, + "OptimizationGuide: Registered Optimization Types: "); for (const auto& optimization_type : registered_optimization_types) { - DVLOG(0) << "OptimizationGuide: Optimization Type: " - << proto::OptimizationType_Name(optimization_type); + OPTIMIZATION_GUIDE_LOG( + optimization_guide_logger, + base::StrCat({"OptimizationGuide: Optimization Type: ", + proto::OptimizationType_Name(optimization_type)})); } - DVLOG(0) << "OptimizationGuide: URLs and Hosts: "; + OPTIMIZATION_GUIDE_LOG(optimization_guide_logger, + "OptimizationGuide: URLs and Hosts: "); for (const auto& url : urls_to_fetch) { - DVLOG(0) << "OptimizationGuide: URL: " << url; + OPTIMIZATION_GUIDE_LOG(optimization_guide_logger, + base::StrCat({"OptimizationGuide: URL: ", + url.possibly_invalid_spec()})); } for (const auto& host : hosts_to_fetch) { - DVLOG(0) << "OptimizationGuide: Host: " << host; + OPTIMIZATION_GUIDE_LOG(optimization_guide_logger, + base::StrCat({"OptimizationGuide: Host: ", host})); } } @@ -294,8 +317,8 @@ HintsManager::HintsManager( TopHostProvider* top_host_provider, TabUrlProvider* tab_url_provider, scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, - network::NetworkConnectionTracker* network_connection_tracker, - std::unique_ptr<PushNotificationManager> push_notification_manager) + std::unique_ptr<PushNotificationManager> push_notification_manager, + OptimizationGuideLogger* optimization_guide_logger) : is_off_the_record_(is_off_the_record), application_locale_(application_locale), pref_service_(pref_service), @@ -308,11 +331,11 @@ HintsManager::HintsManager( hints_fetcher_factory_(std::make_unique<HintsFetcherFactory>( url_loader_factory, features::GetOptimizationGuideServiceGetHintsURL(), - pref_service, - network_connection_tracker)), + pref_service)), top_host_provider_(top_host_provider), tab_url_provider_(tab_url_provider), push_notification_manager_(std::move(push_notification_manager)), + optimization_guide_logger_(optimization_guide_logger), clock_(base::DefaultClock::GetInstance()), background_task_runner_(base::ThreadPool::CreateSequencedTaskRunner( {base::MayBlock(), base::TaskPriority::BEST_EFFORT})) { @@ -687,14 +710,14 @@ void HintsManager::FetchHintsForActiveTabs() { top_hosts.insert(top_hosts.begin(), url.host()); } } - MaybeLogGetHintRequestInfo(proto::CONTEXT_BATCH_UPDATE_ACTIVE_TABS, - registered_optimization_types_, - active_tab_urls_to_refresh, top_hosts); + MaybeLogGetHintRequestInfo( + proto::CONTEXT_BATCH_UPDATE_ACTIVE_TABS, registered_optimization_types_, + active_tab_urls_to_refresh, top_hosts, optimization_guide_logger_); if (!active_tabs_batch_update_hints_fetcher_) { DCHECK(hints_fetcher_factory_); active_tabs_batch_update_hints_fetcher_ = - hints_fetcher_factory_->BuildInstance(); + hints_fetcher_factory_->BuildInstance(optimization_guide_logger_); } active_tabs_batch_update_hints_fetcher_->FetchOptimizationGuideServiceHints( top_hosts, active_tab_urls_to_refresh, registered_optimization_types_, @@ -710,8 +733,12 @@ void HintsManager::OnHintsForActiveTabsFetched( const base::flat_set<GURL>& urls_fetched, absl::optional<std::unique_ptr<proto::GetHintsResponse>> get_hints_response) { - if (!get_hints_response) + if (!get_hints_response) { + if (switches::IsDebugLogsEnabled()) { + DVLOG(0) << "OptimizationGuide: OnHintsForActiveTabsFetched failed"; + } return; + } hint_cache_->UpdateFetchedHints( std::move(*get_hints_response), @@ -719,8 +746,11 @@ void HintsManager::OnHintsForActiveTabsFetched( hosts_fetched, urls_fetched, base::BindOnce(&HintsManager::OnFetchedActiveTabsHintsStored, weak_ptr_factory_.GetWeakPtr())); - if (switches::IsDebugLogsEnabled()) - DVLOG(0) << "OptimizationGuide: OnHintsForActiveTabsFetched complete"; + if (switches::IsDebugLogsEnabled()) { + OPTIMIZATION_GUIDE_LOG( + optimization_guide_logger_, + "OptimizationGuide: OnHintsForActiveTabsFetched complete"); + } } void HintsManager::OnPageNavigationHintsFetched( @@ -735,6 +765,10 @@ void HintsManager::OnPageNavigationHintsFetched( } if (!get_hints_response.has_value() || !get_hints_response.value()) { + if (switches::IsDebugLogsEnabled()) { + DVLOG(0) << "OptimizationGuide: OnPageNavigationHintsFetched failed"; + } + if (navigation_url) { PrepareToInvokeRegisteredCallbacks(*navigation_url); } @@ -748,6 +782,10 @@ void HintsManager::OnPageNavigationHintsFetched( base::BindOnce(&HintsManager::OnFetchedPageNavigationHintsStored, weak_ptr_factory_.GetWeakPtr(), navigation_data_weak_ptr, navigation_url, page_navigation_hosts_requested)); + + if (switches::IsDebugLogsEnabled()) { + DVLOG(0) << "OptimizationGuide: OnPageNavigationHintsFetched complete"; + } } void HintsManager::OnFetchedActiveTabsHintsStored() { @@ -844,7 +882,8 @@ void HintsManager::FetchHintsForURLs(const std::vector<GURL>& urls, return; MaybeLogGetHintRequestInfo(request_context, registered_optimization_types_, - target_urls.vector(), target_hosts.vector()); + target_urls.vector(), target_hosts.vector(), + optimization_guide_logger_); std::pair<int32_t, HintsFetcher*> request_id_and_fetcher = CreateAndTrackBatchUpdateHintsFetcher(); @@ -898,8 +937,10 @@ void HintsManager::RegisterOptimizationTypes( registered_optimization_types_.insert(optimization_type); if (switches::IsDebugLogsEnabled()) { - DVLOG(0) << "OptimizationGuide: Registered new OptimizationType: " - << proto::OptimizationType_Name(optimization_type); + OPTIMIZATION_GUIDE_LOG( + optimization_guide_logger_, + base::StrCat({"OptimizationGuide: Registered new OptimizationType: ", + proto::OptimizationType_Name(optimization_type)})); } absl::optional<double> value = previously_registered_opt_types->FindBoolKey( @@ -1030,7 +1071,8 @@ void HintsManager::CanApplyOptimizationOnDemand( } MaybeLogGetHintRequestInfo(request_context, registered_optimization_types_, - urls_to_fetch.vector(), hosts_to_fetch.vector()); + urls_to_fetch.vector(), hosts_to_fetch.vector(), + optimization_guide_logger_); // Fetch the data for the entries we don't have all information for. std::pair<int32_t, HintsFetcher*> request_id_and_fetcher = @@ -1058,6 +1100,10 @@ void HintsManager::OnBatchUpdateHintsFetched( CleanUpBatchUpdateHintsFetcher(request_id); if (!get_hints_response.has_value() || !get_hints_response.value()) { + if (switches::IsDebugLogsEnabled()) { + DVLOG(0) << "OptimizationGuide: OnBatchUpdateHintsFetched for " + << proto::RequestContext_Name(request_context) << " failed"; + } OnBatchUpdateHintsStored(urls_with_pending_callback, optimization_types, callback); return; @@ -1074,8 +1120,11 @@ void HintsManager::OnBatchUpdateHintsFetched( optimization_types, callback)); if (switches::IsDebugLogsEnabled()) { - DVLOG(0) << "OptimizationGuide: OnBatchUpdateHintsFetched for " - << proto::RequestContext_Name(request_context) << " complete"; + OPTIMIZATION_GUIDE_LOG( + optimization_guide_logger_, + base::StrCat({"OptimizationGuide: OnBatchUpdateHintsFetched for ", + proto::RequestContext_Name(request_context), + " complete"})); } } @@ -1100,7 +1149,7 @@ std::pair<int32_t, HintsFetcher*> HintsManager::CreateAndTrackBatchUpdateHintsFetcher() { DCHECK(hints_fetcher_factory_); std::unique_ptr<HintsFetcher> hints_fetcher = - hints_fetcher_factory_->BuildInstance(); + hints_fetcher_factory_->BuildInstance(optimization_guide_logger_); HintsFetcher* hints_fetcher_ptr = hints_fetcher.get(); batch_update_hints_fetchers_.Put(batch_update_hints_fetcher_request_id_++, std::move(hints_fetcher)); @@ -1161,8 +1210,8 @@ OptimizationTypeDecision HintsManager::CanApplyOptimization( OptimizationMetadata* optimization_metadata) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - ScopedCanApplyOptimizationLogger scoped_logger(optimization_type, - navigation_url); + ScopedCanApplyOptimizationLogger scoped_logger( + optimization_type, navigation_url, optimization_guide_logger_); // Clear out optimization metadata if provided. if (optimization_metadata) *optimization_metadata = {}; @@ -1432,14 +1481,15 @@ void HintsManager::MaybeFetchHintsForNavigation( DCHECK(hints_fetcher_factory_); auto it = page_navigation_hints_fetchers_.Put( - url, hints_fetcher_factory_->BuildInstance()); + url, hints_fetcher_factory_->BuildInstance(optimization_guide_logger_)); UMA_HISTOGRAM_COUNTS_100( "OptimizationGuide.HintsManager.ConcurrentPageNavigationFetches", page_navigation_hints_fetchers_.size()); MaybeLogGetHintRequestInfo(proto::CONTEXT_PAGE_NAVIGATION, - registered_optimization_types_, urls, hosts); + registered_optimization_types_, urls, hosts, + optimization_guide_logger_); bool fetch_attempted = it->second->FetchOptimizationGuideServiceHints( hosts, urls, registered_optimization_types_, proto::CONTEXT_PAGE_NAVIGATION, application_locale_, diff --git a/chromium/components/optimization_guide/core/hints_manager.h b/chromium/components/optimization_guide/core/hints_manager.h index 59b0782d191..f6e39c1cbb1 100644 --- a/chromium/components/optimization_guide/core/hints_manager.h +++ b/chromium/components/optimization_guide/core/hints_manager.h @@ -27,21 +27,21 @@ #include "components/optimization_guide/proto/hints.pb.h" #include "third_party/abseil-cpp/absl/types/optional.h" +class OptimizationGuideLogger; class OptimizationGuideNavigationData; class OptimizationGuideTestAppInterfaceWrapper; class PrefService; namespace network { class SharedURLLoaderFactory; -class NetworkConnectionTracker; } // namespace network namespace optimization_guide { class HintCache; class HintsFetcherFactory; class OptimizationFilter; -class OptimizationMetadata; class OptimizationGuideStore; +class OptimizationMetadata; enum class OptimizationTypeDecision; class StoreUpdateData; class TabUrlProvider; @@ -58,8 +58,8 @@ class HintsManager : public OptimizationHintsComponentObserver, TopHostProvider* top_host_provider, TabUrlProvider* tab_url_provider, scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, - network::NetworkConnectionTracker* network_connection_tracker, - std::unique_ptr<PushNotificationManager> push_notification_manager); + std::unique_ptr<PushNotificationManager> push_notification_manager, + OptimizationGuideLogger* optimization_guide_logger); ~HintsManager() override; @@ -473,6 +473,11 @@ class HintsManager : public OptimizationHintsComponentObserver, // what to do through the implemented Delegate above. std::unique_ptr<PushNotificationManager> push_notification_manager_; + // The logger that plumbs the debug logs to the optimization guide + // internals page. Not owned. Guaranteed to outlive |this|, since the logger + // and |this| are owned by the optimization guide keyed service. + raw_ptr<OptimizationGuideLogger> optimization_guide_logger_; + // The clock used to schedule fetching from the remote Optimization Guide // Service. raw_ptr<const base::Clock> clock_; diff --git a/chromium/components/optimization_guide/core/hints_manager_unittest.cc b/chromium/components/optimization_guide/core/hints_manager_unittest.cc index 5f4fd3b8383..2b2ddb9d398 100644 --- a/chromium/components/optimization_guide/core/hints_manager_unittest.cc +++ b/chromium/components/optimization_guide/core/hints_manager_unittest.cc @@ -15,7 +15,6 @@ #include "base/test/metrics/histogram_tester.h" #include "base/test/scoped_feature_list.h" #include "base/test/task_environment.h" -#include "components/data_reduction_proxy/core/common/data_reduction_proxy_pref_names.h" #include "components/optimization_guide/core/bloom_filter.h" #include "components/optimization_guide/core/hint_cache.h" #include "components/optimization_guide/core/hints_component_util.h" @@ -38,10 +37,8 @@ #include "components/variations/scoped_variations_ids_provider.h" #include "services/metrics/public/cpp/ukm_builders.h" #include "services/metrics/public/cpp/ukm_source.h" -#include "services/network/public/cpp/network_connection_tracker.h" #include "services/network/public/cpp/shared_url_loader_factory.h" #include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h" -#include "services/network/test/test_network_connection_tracker.h" #include "services/network/test/test_url_loader_factory.h" #include "testing/gtest/include/gtest/gtest.h" @@ -190,12 +187,12 @@ class TestHintsFetcher : public HintsFetcher { scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, GURL optimization_guide_service_url, PrefService* pref_service, - network::NetworkConnectionTracker* network_connection_tracker, - const std::vector<HintsFetcherEndState>& fetch_states) + const std::vector<HintsFetcherEndState>& fetch_states, + OptimizationGuideLogger* optimization_guide_logger) : HintsFetcher(url_loader_factory, optimization_guide_service_url, pref_service, - network_connection_tracker), + optimization_guide_logger), fetch_states_(fetch_states) { DCHECK(!fetch_states_.empty()); } @@ -266,18 +263,17 @@ class TestHintsFetcherFactory : public HintsFetcherFactory { scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, GURL optimization_guide_service_url, PrefService* pref_service, - const std::vector<HintsFetcherEndState>& fetch_states, - network::NetworkConnectionTracker* network_connection_tracker) + const std::vector<HintsFetcherEndState>& fetch_states) : HintsFetcherFactory(url_loader_factory, optimization_guide_service_url, - pref_service, - network_connection_tracker), + pref_service), fetch_states_(fetch_states) {} - std::unique_ptr<HintsFetcher> BuildInstance() override { + std::unique_ptr<HintsFetcher> BuildInstance( + OptimizationGuideLogger* optimization_guide_logger) override { return std::make_unique<TestHintsFetcher>( url_loader_factory_, optimization_guide_service_url_, pref_service_, - network_connection_tracker_, fetch_states_); + fetch_states_, optimization_guide_logger); } private: @@ -313,8 +309,6 @@ class HintsManagerTest : public ProtoDatabaseProviderTestBase { pref_service_ = std::make_unique<sync_preferences::TestingPrefServiceSyncable>(); prefs::RegisterProfilePrefs(pref_service_->registry()); - pref_service_->registry()->RegisterBooleanPref( - data_reduction_proxy::prefs::kDataSaverEnabled, false); unified_consent::UnifiedConsentService::RegisterPrefs( pref_service_->registry()); @@ -332,8 +326,8 @@ class HintsManagerTest : public ProtoDatabaseProviderTestBase { /*is_off_the_record=*/false, /*application_locale=*/"en-US", pref_service(), hint_store_->AsWeakPtr(), top_host_provider, tab_url_provider_.get(), url_loader_factory_, - network::TestNetworkConnectionTracker::GetInstance(), - /*push_notification_manager=*/nullptr); + /*push_notification_manager=*/nullptr, + /*optimization_guide_logger=*/nullptr); hints_manager_->SetClockForTesting(task_environment_.GetMockClock()); // Run until hint cache is initialized and the HintsManager is ready to @@ -414,7 +408,7 @@ class HintsManagerTest : public ProtoDatabaseProviderTestBase { const std::vector<HintsFetcherEndState>& fetch_states) { return std::make_unique<TestHintsFetcherFactory>( url_loader_factory_, GURL("https://hintsserver.com"), pref_service(), - fetch_states, network::TestNetworkConnectionTracker::GetInstance()); + fetch_states); } void MoveClockForwardBy(base::TimeDelta time_delta) { @@ -441,16 +435,6 @@ class HintsManagerTest : public ProtoDatabaseProviderTestBase { std::move(callback)); } - void SetConnectionOffline() { - network::TestNetworkConnectionTracker::GetInstance()->SetConnectionType( - network::mojom::ConnectionType::CONNECTION_NONE); - } - - void SetConnectionOnline() { - network::TestNetworkConnectionTracker::GetInstance()->SetConnectionType( - network::mojom::ConnectionType::CONNECTION_4G); - } - HintsManager* hints_manager() const { return hints_manager_.get(); } int32_t num_batch_update_hints_fetches_initiated() const { @@ -1402,8 +1386,6 @@ TEST_F(HintsManagerTest, CanApplyOptimizationAndPopulatesAnyMetadata) { TEST_F(HintsManagerTest, CanApplyOptimizationNoMatchingPageHint) { InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); auto navigation_data = CreateTestNavigationData(GURL("https://somedomain.org/nomatch"), {}); base::RunLoop run_loop; @@ -2074,8 +2056,6 @@ TEST_F(HintsManagerFetchingTest, /*is_allowlist=*/true, &config); ProcessHints(config, "1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); auto navigation_data = CreateTestNavigationData(url_without_hints(), {proto::LITE_PAGE_REDIRECT}); base::HistogramTester histogram_tester; @@ -2096,8 +2076,6 @@ TEST_F(HintsManagerFetchingTest, HintsFetchedAtNavigationTime) { hints_manager()->RegisterOptimizationTypes({proto::DEFER_ALL_SCRIPT}); InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); auto navigation_data = CreateTestNavigationData(url_without_hints(), {proto::DEFER_ALL_SCRIPT}); base::HistogramTester histogram_tester; @@ -2120,8 +2098,6 @@ TEST_F(HintsManagerFetchingTest, hints_manager()->RegisterOptimizationTypes({proto::DEFER_ALL_SCRIPT}); InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); auto navigation_data = CreateTestNavigationData(url_without_hints(), {proto::DEFER_ALL_SCRIPT}); hints_manager()->SetHintsFetcherFactoryForTesting( @@ -2149,8 +2125,6 @@ TEST_F(HintsManagerFetchingTest, BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithURLHints})); - // Set to online so fetch is activated. - SetConnectionOnline(); auto navigation_data = CreateTestNavigationData(url_with_hints(), {proto::DEFER_ALL_SCRIPT}); base::HistogramTester histogram_tester; @@ -2188,8 +2162,6 @@ TEST_F(HintsManagerFetchingTest, switches::kDisableCheckingUserPermissionsForTesting); hints_manager()->RegisterOptimizationTypes({proto::DEFER_ALL_SCRIPT}); - // Set to online so fetch is activated. - SetConnectionOnline(); auto navigation_data = CreateTestNavigationData(example_url, {proto::DEFER_ALL_SCRIPT}); base::HistogramTester histogram_tester; @@ -2221,9 +2193,6 @@ TEST_F(HintsManagerFetchingTest, URLHintsNotFetchedAtNavigationTime) { BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithURLHints})); - // Set to online so fetch is activated. - SetConnectionOnline(); - { base::HistogramTester histogram_tester; auto navigation_data = @@ -2282,9 +2251,6 @@ TEST_F(HintsManagerFetchingTest, URLWithNoHintsNotRefetchedAtNavigationTime) { BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithHostHints})); - // Set to online so fetch is activated. - SetConnectionOnline(); - base::HistogramTester histogram_tester; { auto navigation_data = CreateTestNavigationData(url_without_hints(), @@ -2330,8 +2296,6 @@ TEST_F(HintsManagerFetchingTest, CanApplyOptimizationCalledMidFetch) { hints_manager()->RegisterOptimizationTypes({proto::DEFER_ALL_SCRIPT}); InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); auto navigation_data = CreateTestNavigationData(url_without_hints(), {proto::DEFER_ALL_SCRIPT}); CallOnNavigationStartOrRedirect(navigation_data.get(), base::DoNothing()); @@ -2355,8 +2319,6 @@ TEST_F(HintsManagerFetchingTest, BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithNoHints})); - // Set to online so fetch is activated. - SetConnectionOnline(); auto navigation_data = CreateTestNavigationData(url_without_hints(), {proto::DEFER_ALL_SCRIPT}); CallOnNavigationStartOrRedirect(navigation_data.get(), base::DoNothing()); @@ -2381,8 +2343,6 @@ TEST_F(HintsManagerFetchingTest, hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory({HintsFetcherEndState::kFetchFailed})); - // Set to online so fetch is activated. - SetConnectionOnline(); auto navigation_data = CreateTestNavigationData(url_without_hints(), {proto::DEFER_ALL_SCRIPT}); CallOnNavigationStartOrRedirect(navigation_data.get(), base::DoNothing()); @@ -2408,8 +2368,6 @@ TEST_F(HintsManagerFetchingTest, BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithURLHints})); - // Set to online so fetch is activated. - SetConnectionOnline(); auto navigation_data = CreateTestNavigationData(url_with_url_keyed_hint(), {proto::DEFER_ALL_SCRIPT}); // Make sure URL-keyed hint is fetched and processed. @@ -2437,9 +2395,6 @@ TEST_F(HintsManagerFetchingTest, InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - // Make sure both URL-Keyed and host-keyed hints are processed and cached. hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( @@ -2467,9 +2422,6 @@ TEST_F(HintsManagerFetchingTest, InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - // Make sure both URL-Keyed and host-keyed hints are processed and cached. hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( @@ -2497,9 +2449,6 @@ TEST_F(HintsManagerFetchingTest, InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithNoHints})); @@ -2528,9 +2477,6 @@ TEST_F(HintsManagerFetchingTest, InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - // Attempt to fetch a hint but call CanApplyOptimization right away to // simulate being mid-fetch. auto navigation_data = CreateTestNavigationData( @@ -2557,9 +2503,6 @@ TEST_F(HintsManagerFetchingTest, InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - // Attempt to fetch a hint but initiate the next navigation right away to // simulate being mid-fetch. auto navigation_data = @@ -2608,8 +2551,6 @@ TEST_F(HintsManagerFetchingTest, InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithNoHints})); @@ -2650,8 +2591,6 @@ TEST_F(HintsManagerFetchingTest, InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithURLHints})); @@ -2691,9 +2630,6 @@ TEST_F(HintsManagerFetchingTest, hints_manager()->RegisterOptimizationTypes({proto::COMPRESS_PUBLIC_IMAGES}); InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithURLHints})); @@ -2723,9 +2659,6 @@ TEST_F(HintsManagerFetchingTest, hints_manager()->RegisterOptimizationTypes({proto::COMPRESS_PUBLIC_IMAGES}); InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithURLHints})); @@ -2763,9 +2696,6 @@ TEST_F( hints_manager()->RegisterOptimizationTypes({proto::RESOURCE_LOADING}); InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithURLHints})); @@ -2795,9 +2725,6 @@ TEST_F(HintsManagerFetchingTest, InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory({HintsFetcherEndState::kFetchFailed})); auto navigation_data = CreateTestNavigationData( @@ -2826,9 +2753,6 @@ TEST_F(HintsManagerFetchingTest, InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithURLHints})); @@ -2861,9 +2785,6 @@ TEST_F(HintsManagerFetchingTest, InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithURLHints})); @@ -2895,9 +2816,6 @@ TEST_F(HintsManagerFetchingTest, InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithNoHints})); @@ -2927,9 +2845,6 @@ TEST_F(HintsManagerFetchingTest, InitializeWithDefaultConfig("1.0.0.0"); - // Set to offline so fetch is NOT activated. - SetConnectionOffline(); - GURL url_that_redirected("https://urlthatredirected.com"); auto navigation_data_redirect = CreateTestNavigationData( url_that_redirected, {proto::COMPRESS_PUBLIC_IMAGES}); @@ -2958,9 +2873,6 @@ TEST_F(HintsManagerFetchingTest, InitializeWithDefaultConfig("1.0.0.0"); - // Set to offline so fetch is NOT activated. - SetConnectionOffline(); - hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithNoHints})); @@ -3053,9 +2965,6 @@ TEST_F(HintsManagerFetchingTest, InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithURLHints})); @@ -3099,9 +3008,6 @@ TEST_F(HintsManagerFetchingTest, NewOptTypeRegisteredClearsHintCache) { GURL url("https://host.com/fetched_hint_host"); - // Set to online so fetch is activated. - SetConnectionOnline(); - hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithHostHints})); @@ -3129,10 +3035,6 @@ TEST_F(HintsManagerFetchingTest, NewOptTypeRegisteredClearsHintCache) { base::RunLoop run_loop; - // Set to offline so fetch is NOT activated, so the cache state is known and - // empty. - SetConnectionOffline(); - base::HistogramTester histogram_tester; navigation_data = CreateTestNavigationData(url, {proto::DEFER_ALL_SCRIPT}); @@ -3159,9 +3061,6 @@ TEST_F(HintsManagerFetchingTest, hints_manager()->RegisterOptimizationTypes({proto::COMPRESS_PUBLIC_IMAGES}); InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithURLHints})); @@ -3196,9 +3095,6 @@ TEST_F(HintsManagerFetchingTest, BatchUpdateCalledMoreThanMaxConcurrent) { hints_manager()->RegisterOptimizationTypes({proto::COMPRESS_PUBLIC_IMAGES}); InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithURLHints})); @@ -3244,9 +3140,6 @@ TEST_F( {proto::NOSCRIPT, proto::COMPRESS_PUBLIC_IMAGES}); InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory( {HintsFetcherEndState::kFetchSuccessWithURLHints})); @@ -3285,9 +3178,6 @@ TEST_F(HintsManagerFetchingTest, {proto::NOSCRIPT, proto::COMPRESS_PUBLIC_IMAGES}); InitializeWithDefaultConfig("1.0.0.0"); - // Set to online so fetch is activated. - SetConnectionOnline(); - hints_manager()->SetHintsFetcherFactoryForTesting( BuildTestHintsFetcherFactory({HintsFetcherEndState::kFetchFailed})); std::unique_ptr<base::RunLoop> run_loop = std::make_unique<base::RunLoop>(); diff --git a/chromium/components/optimization_guide/core/local_page_entities_metadata_provider.cc b/chromium/components/optimization_guide/core/local_page_entities_metadata_provider.cc new file mode 100644 index 00000000000..18a32f9a573 --- /dev/null +++ b/chromium/components/optimization_guide/core/local_page_entities_metadata_provider.cc @@ -0,0 +1,93 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/optimization_guide/core/local_page_entities_metadata_provider.h" + +#include "components/optimization_guide/core/entity_metadata.h" + +namespace optimization_guide { + +namespace { + +// The amount of data to build up in memory before converting to a sorted on- +// disk file. +constexpr size_t kDatabaseWriteBufferSizeBytes = 128 * 1024; + +} // namespace + +LocalPageEntitiesMetadataProvider::LocalPageEntitiesMetadataProvider() = + default; +LocalPageEntitiesMetadataProvider::~LocalPageEntitiesMetadataProvider() = + default; + +void LocalPageEntitiesMetadataProvider::Initialize( + leveldb_proto::ProtoDatabaseProvider* database_provider, + const base::FilePath& database_dir, + scoped_refptr<base::SequencedTaskRunner> background_task_runner) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + background_task_runner_ = std::move(background_task_runner); + database_ = database_provider->GetDB<proto::EntityMetadataStorage>( + leveldb_proto::ProtoDbType::PAGE_ENTITY_METADATA_STORE, database_dir, + background_task_runner_); + + leveldb_env::Options options = leveldb_proto::CreateSimpleOptions(); + options.write_buffer_size = kDatabaseWriteBufferSizeBytes; + database_->Init( + options, + base::BindOnce(&LocalPageEntitiesMetadataProvider::OnDatabaseInitialized, + weak_ptr_factory_.GetWeakPtr())); +} + +void LocalPageEntitiesMetadataProvider::InitializeForTesting( + std::unique_ptr<leveldb_proto::ProtoDatabase<proto::EntityMetadataStorage>> + database, + scoped_refptr<base::SequencedTaskRunner> background_task_runner) { + database_ = std::move(database); + background_task_runner_ = std::move(background_task_runner); +} + +void LocalPageEntitiesMetadataProvider::OnDatabaseInitialized( + leveldb_proto::Enums::InitStatus status) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (status != leveldb_proto::Enums::InitStatus::kOK) { + database_.reset(); + return; + } +} + +void LocalPageEntitiesMetadataProvider::GetMetadataForEntityId( + const std::string& entity_id, + EntityMetadataRetrievedCallback callback) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!database_) { + std::move(callback).Run(absl::nullopt); + return; + } + + database_->GetEntry( + entity_id, base::BindOnce(&LocalPageEntitiesMetadataProvider::OnGotEntry, + weak_ptr_factory_.GetWeakPtr(), entity_id, + std::move(callback))); +} + +void LocalPageEntitiesMetadataProvider::OnGotEntry( + const std::string& entity_id, + EntityMetadataRetrievedCallback callback, + bool success, + std::unique_ptr<proto::EntityMetadataStorage> entry) { + if (!success || !entry) { + std::move(callback).Run(absl::nullopt); + return; + } + + EntityMetadata md; + md.entity_id = entity_id; + md.human_readable_name = entry->entity_name(); + + std::move(callback).Run(md); +} + +} // namespace optimization_guide
\ No newline at end of file diff --git a/chromium/components/optimization_guide/core/local_page_entities_metadata_provider.h b/chromium/components/optimization_guide/core/local_page_entities_metadata_provider.h new file mode 100644 index 00000000000..11070c8961a --- /dev/null +++ b/chromium/components/optimization_guide/core/local_page_entities_metadata_provider.h @@ -0,0 +1,67 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_LOCAL_PAGE_ENTITIES_METADATA_PROVIDER_H_ +#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_LOCAL_PAGE_ENTITIES_METADATA_PROVIDER_H_ + +#include "base/callback.h" +#include "base/memory/weak_ptr.h" +#include "base/sequence_checker.h" +#include "components/leveldb_proto/public/proto_database.h" +#include "components/leveldb_proto/public/proto_database_provider.h" +#include "components/optimization_guide/core/entity_metadata_provider.h" +#include "components/optimization_guide/proto/page_entities_metadata.pb.h" +#include "third_party/abseil-cpp/absl/types/optional.h" + +namespace optimization_guide { + +// Provides EntityMetadata given an entity id by looking up entries in a local +// database on-disk. +class LocalPageEntitiesMetadataProvider : public EntityMetadataProvider { + public: + LocalPageEntitiesMetadataProvider(); + ~LocalPageEntitiesMetadataProvider() override; + LocalPageEntitiesMetadataProvider(const LocalPageEntitiesMetadataProvider&) = + delete; + LocalPageEntitiesMetadataProvider& operator=( + const LocalPageEntitiesMetadataProvider&) = delete; + + // Initializes this class, setting |database_| and |background_task_runner_|. + void Initialize( + leveldb_proto::ProtoDatabaseProvider* database_provider, + const base::FilePath& database_dir, + scoped_refptr<base::SequencedTaskRunner> background_task_runner); + + // Directly sets |database_| and |background_task_runner_| for tests. + void InitializeForTesting( + std::unique_ptr< + leveldb_proto::ProtoDatabase<proto::EntityMetadataStorage>> database, + scoped_refptr<base::SequencedTaskRunner> background_task_runner); + + // EntityMetadataProvider: + void GetMetadataForEntityId( + const std::string& entity_id, + EntityMetadataRetrievedCallback callback) override; + + private: + void OnDatabaseInitialized(leveldb_proto::Enums::InitStatus status); + void OnGotEntry(const std::string& entity_id, + EntityMetadataRetrievedCallback callback, + bool success, + std::unique_ptr<proto::EntityMetadataStorage> entry); + + std::unique_ptr<leveldb_proto::ProtoDatabase<proto::EntityMetadataStorage>> + database_; + + scoped_refptr<base::SequencedTaskRunner> background_task_runner_; + + SEQUENCE_CHECKER(sequence_checker_); + + base::WeakPtrFactory<LocalPageEntitiesMetadataProvider> weak_ptr_factory_{ + this}; +}; + +} // namespace optimization_guide + +#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_LOCAL_PAGE_ENTITIES_METADATA_PROVIDER_H_
\ No newline at end of file diff --git a/chromium/components/optimization_guide/core/local_page_entities_metadata_provider_unittest.cc b/chromium/components/optimization_guide/core/local_page_entities_metadata_provider_unittest.cc new file mode 100644 index 00000000000..7a24cbe2430 --- /dev/null +++ b/chromium/components/optimization_guide/core/local_page_entities_metadata_provider_unittest.cc @@ -0,0 +1,134 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/optimization_guide/core/local_page_entities_metadata_provider.h" + +#include "base/test/task_environment.h" +#include "components/leveldb_proto/testing/fake_db.h" +#include "components/optimization_guide/core/optimization_guide_features.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace optimization_guide { + +class LocalPageEntitiesMetadataProviderTest : public testing::Test { + public: + LocalPageEntitiesMetadataProviderTest() = default; + ~LocalPageEntitiesMetadataProviderTest() override = default; + + void SetUp() override { + auto db = std::make_unique< + leveldb_proto::test::FakeDB<proto::EntityMetadataStorage>>(&db_store_); + db_ = db.get(); + + provider_ = std::make_unique<LocalPageEntitiesMetadataProvider>(); + provider_->InitializeForTesting( + std::move(db), task_environment_.GetMainThreadTaskRunner()); + } + + LocalPageEntitiesMetadataProvider* provider() { return provider_.get(); } + + leveldb_proto::test::FakeDB<proto::EntityMetadataStorage>* db() { + return db_; + } + + std::map<std::string, proto::EntityMetadataStorage>* store() { + return &db_store_; + } + + private: + base::test::TaskEnvironment task_environment_; + std::unique_ptr<LocalPageEntitiesMetadataProvider> provider_; + leveldb_proto::test::FakeDB<proto::EntityMetadataStorage>* db_; + std::map<std::string, proto::EntityMetadataStorage> db_store_; +}; + +TEST_F(LocalPageEntitiesMetadataProviderTest, NonInitReturnsNullOpt) { + LocalPageEntitiesMetadataProvider provider; + + absl::optional<EntityMetadata> md; + bool callback_ran = false; + provider.GetMetadataForEntityId( + "entity_id", + base::BindOnce( + [](bool* callback_ran_flag, absl::optional<EntityMetadata>* md_out, + const absl::optional<EntityMetadata>& md_in) { + *callback_ran_flag = true; + *md_out = md_in; + }, + &callback_ran, &md)); + + ASSERT_TRUE(callback_ran); + EXPECT_EQ(absl::nullopt, md); +} + +TEST_F(LocalPageEntitiesMetadataProviderTest, EmptyStoreReturnsNullOpt) { + absl::optional<EntityMetadata> md; + bool callback_ran = false; + provider()->GetMetadataForEntityId( + "entity_id", + base::BindOnce( + [](bool* callback_ran_flag, absl::optional<EntityMetadata>* md_out, + const absl::optional<EntityMetadata>& md_in) { + *callback_ran_flag = true; + *md_out = md_in; + }, + &callback_ran, &md)); + + db()->GetCallback(/*success=*/true); + + ASSERT_TRUE(callback_ran); + EXPECT_EQ(absl::nullopt, md); +} + +TEST_F(LocalPageEntitiesMetadataProviderTest, PopulatedSuccess) { + proto::EntityMetadataStorage stored_proto; + stored_proto.set_entity_name("chip"); + store()->emplace("chocolate", stored_proto); + + EntityMetadata want_md; + want_md.entity_id = "chocolate"; + want_md.human_readable_name = "chip"; + + absl::optional<EntityMetadata> md; + bool callback_ran = false; + provider()->GetMetadataForEntityId( + "chocolate", + base::BindOnce( + [](bool* callback_ran_flag, absl::optional<EntityMetadata>* md_out, + const absl::optional<EntityMetadata>& md_in) { + *callback_ran_flag = true; + *md_out = md_in; + }, + &callback_ran, &md)); + + db()->GetCallback(/*success=*/true); + + ASSERT_TRUE(callback_ran); + EXPECT_EQ(absl::make_optional(want_md), md); +} + +TEST_F(LocalPageEntitiesMetadataProviderTest, PopulatedFailure) { + proto::EntityMetadataStorage stored_proto; + stored_proto.set_entity_name("chip"); + store()->emplace("chocolate", stored_proto); + + absl::optional<EntityMetadata> md; + bool callback_ran = false; + provider()->GetMetadataForEntityId( + "chocolate", + base::BindOnce( + [](bool* callback_ran_flag, absl::optional<EntityMetadata>* md_out, + const absl::optional<EntityMetadata>& md_in) { + *callback_ran_flag = true; + *md_out = md_in; + }, + &callback_ran, &md)); + + db()->GetCallback(/*success=*/false); + + ASSERT_TRUE(callback_ran); + EXPECT_EQ(absl::nullopt, md); +} + +} // namespace optimization_guide
\ No newline at end of file diff --git a/chromium/components/optimization_guide/core/model_enums.h b/chromium/components/optimization_guide/core/model_enums.h new file mode 100644 index 00000000000..1e6c495c753 --- /dev/null +++ b/chromium/components/optimization_guide/core/model_enums.h @@ -0,0 +1,57 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_ENUMS_H_ +#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_ENUMS_H_ + +namespace optimization_guide { + +// The types of decisions that can be made for an optimization target. +// +// Keep in sync with OptimizationGuideOptimizationTargetDecision in enums.xml. +enum class OptimizationTargetDecision { + kUnknown = 0, + // The page load does not match the optimization target. + kPageLoadDoesNotMatch = 1, + // The page load matches the optimization target. + kPageLoadMatches = 2, + // The model needed to make the target decision was not available on the + // client. + kModelNotAvailableOnClient = 3, + // The page load is part of a model prediction holdback where all decisions + // will return |OptimizationGuideDecision::kFalse| in an attempt to not taint + // the data for understanding the production recall of the model. + kModelPredictionHoldback = 4, + // The OptimizationGuideDecider was not initialized yet. + kDeciderNotInitialized = 5, + + // Add new values above this line. + kMaxValue = kDeciderNotInitialized, +}; + +// The statuses for a prediction model in the prediction manager when requested +// to be evaluated. +// +// Keep in sync with OptimizationGuidePredictionManagerModelStatus in enums.xml. +enum class PredictionManagerModelStatus { + kUnknown = 0, + // The model is loaded and available for use. + kModelAvailable = 1, + // The store is initialized but does not contain a model for the optimization + // target. + kStoreAvailableNoModelForTarget = 2, + // The store is initialized and contains a model for the optimization target + // but it is not loaded in memory. + kStoreAvailableModelNotLoaded = 3, + // The store is not initialized and it is unknown if it contains a model for + // the optimization target. + kStoreUnavailableModelUnknown = 4, + + // Add new values above this line. + kMaxValue = kStoreUnavailableModelUnknown, +}; + +} // namespace optimization_guide + +#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_OPTIMIZATION_GUIDE_ENUMS_H_ diff --git a/chromium/components/optimization_guide/core/model_executor.h b/chromium/components/optimization_guide/core/model_executor.h index 542c604c9db..305b6a9e24b 100644 --- a/chromium/components/optimization_guide/core/model_executor.h +++ b/chromium/components/optimization_guide/core/model_executor.h @@ -17,7 +17,7 @@ namespace optimization_guide { // This class handles the execution, loading, unloading, and associated metrics -// of machine learning models in Optimization Guide on a background thread. This +// of machine learning models in Optimization Guide on a specified thread. This // class is meant to be used and owned by an instance of |ModelHandler|. A // ModelExecutor must be passed to a ModelHandler's constructor, this design // allows the implementer of a ModelExecutor to define how the model is built @@ -25,21 +25,21 @@ namespace optimization_guide { // base_model_executor_helpers.h in this directory for helpful derived classes. // // Lifetime: This class can be constructed on any thread but cannot do anything -// useful until |InitializeAndMoveToBackgroundThread| is called. After that +// useful until |InitializeAndMoveToExecutionThread| is called. After that // method is called, all subsequent calls to this class must be made through the -// |background_task_runner| that was passed to initialize. Furthermore, all -// WeakPointers of this class must only be dereferenced on the background thread -// as well. This in turn means that this class must be destroyed on the -// background thread as well. +// |execution_task_runner| that was passed to initialize. Furthermore, all +// WeakPointers of this class must only be dereferenced on the +// |execution_task_runner| thread as well. This in turn means that this class +// must be destroyed on the |execution_task_runner| thread as well. template <class OutputType, class... InputTypes> class ModelExecutor { public: ModelExecutor() = default; virtual ~ModelExecutor() = default; - virtual void InitializeAndMoveToBackgroundThread( + virtual void InitializeAndMoveToExecutionThread( proto::OptimizationTarget optimization_target, - scoped_refptr<base::SequencedTaskRunner> background_task_runner, + scoped_refptr<base::SequencedTaskRunner> execution_task_runner, scoped_refptr<base::SequencedTaskRunner> reply_task_runner) = 0; virtual void UpdateModelFile(const base::FilePath& file_path) = 0; @@ -53,21 +53,21 @@ class ModelExecutor { using ExecutionCallback = base::OnceCallback<void(const absl::optional<OutputType>&)>; - virtual void SendForExecution(ExecutionCallback ui_callback_on_complete, + virtual void SendForExecution(ExecutionCallback callback_on_complete, base::TimeTicks start_time, InputTypes... args) = 0; - // IMPORTANT: These WeakPointers must only be dereferenced on the background - // thread. - base::WeakPtr<ModelExecutor> GetBackgroundWeakPtr() { - return background_weak_ptr_factory_.GetWeakPtr(); + // IMPORTANT: These WeakPointers must only be dereferenced on the + // |execution_task_runner| thread. + base::WeakPtr<ModelExecutor> GetWeakPtrForExecutionThread() { + return weak_ptr_factory_.GetWeakPtr(); } ModelExecutor(const ModelExecutor&) = delete; ModelExecutor& operator=(const ModelExecutor&) = delete; private: - base::WeakPtrFactory<ModelExecutor> background_weak_ptr_factory_{this}; + base::WeakPtrFactory<ModelExecutor> weak_ptr_factory_{this}; }; } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/model_handler.h b/chromium/components/optimization_guide/core/model_handler.h index 010734bab35..ff1c332a4a7 100644 --- a/chromium/components/optimization_guide/core/model_handler.h +++ b/chromium/components/optimization_guide/core/model_handler.h @@ -16,6 +16,7 @@ #include "base/threading/sequenced_task_runner_handle.h" #include "base/time/time.h" #include "components/optimization_guide/core/model_executor.h" +#include "components/optimization_guide/core/model_util.h" #include "components/optimization_guide/core/optimization_guide_model_provider.h" #include "components/optimization_guide/core/optimization_guide_util.h" #include "components/optimization_guide/core/optimization_target_model_observer.h" @@ -24,32 +25,33 @@ namespace optimization_guide { -// This class owns and handles the execution of models on the UI thread. Derived -// classes must provide an implementation of |ModelExecutor| -// (see above) which is then owned by |this|. The passed executor will be called -// and destroyed on a background thread, which is all handled by this class. +// This class owns and handles the execution of models on the UI thread. +// Derived classes must provide an implementation of |ModelExecutor| +// which is then owned by |this|. The passed executor will be called +// and destroyed on the thread specified by |model_executor_task_runner|, +// which is all handled by this class. template <class OutputType, class... InputTypes> class ModelHandler : public OptimizationTargetModelObserver { public: - ModelHandler(OptimizationGuideModelProvider* model_provider, - scoped_refptr<base::SequencedTaskRunner> background_task_runner, - std::unique_ptr<ModelExecutor<OutputType, InputTypes...>> - background_executor, - proto::OptimizationTarget optimization_target, - const absl::optional<proto::Any>& model_metadata) + ModelHandler( + OptimizationGuideModelProvider* model_provider, + scoped_refptr<base::SequencedTaskRunner> model_executor_task_runner, + std::unique_ptr<ModelExecutor<OutputType, InputTypes...>> model_executor, + proto::OptimizationTarget optimization_target, + const absl::optional<proto::Any>& model_metadata) : model_provider_(model_provider), optimization_target_(optimization_target), - background_executor_(std::move(background_executor)), - background_task_runner_(background_task_runner) { + model_executor_(std::move(model_executor)), + model_executor_task_runner_(model_executor_task_runner) { DCHECK(model_provider_); - DCHECK(background_executor_); + DCHECK(model_executor_); DCHECK_NE(optimization_target_, proto::OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN); model_provider_->AddObserverForOptimizationTargetModel( optimization_target_, model_metadata, this); - background_executor_->InitializeAndMoveToBackgroundThread( - optimization_target_, background_task_runner_, + model_executor_->InitializeAndMoveToExecutionThread( + optimization_target_, model_executor_task_runner_, base::SequencedTaskRunnerHandle::Get()); } ~ModelHandler() override { @@ -58,10 +60,10 @@ class ModelHandler : public OptimizationTargetModelObserver { model_provider_->RemoveObserverForOptimizationTargetModel( optimization_target_, this); - // |background_executor_|'s WeakPtrs are used on the background thread, so + // |model_executor_|'s WeakPtrs are used on the model thread, so // that is also where the class must be destroyed. - background_task_runner_->DeleteSoon(FROM_HERE, - std::move(background_executor_)); + model_executor_task_runner_->DeleteSoon(FROM_HERE, + std::move(model_executor_)); } ModelHandler(const ModelHandler&) = delete; ModelHandler& operator=(const ModelHandler&) = delete; @@ -79,32 +81,33 @@ class ModelHandler : public OptimizationTargetModelObserver { ExecutionCallback on_complete_callback = base::BindOnce(&ModelHandler::OnExecutionCompleted, std::move(callback), optimization_target_, now); - background_task_runner_->PostTask( + model_executor_task_runner_->PostTask( FROM_HERE, base::BindOnce( &ModelExecutor<OutputType, InputTypes...>::SendForExecution, - background_executor_->GetBackgroundWeakPtr(), + model_executor_->GetWeakPtrForExecutionThread(), std::move(on_complete_callback), now, input...)); } void SetShouldUnloadModelOnComplete(bool should_auto_unload) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - background_task_runner_->PostTask( + model_executor_task_runner_->PostTask( FROM_HERE, base::BindOnce( &ModelExecutor<OutputType, InputTypes...>::SetShouldUnloadModelOnComplete, - background_executor_->GetBackgroundWeakPtr(), should_auto_unload)); + model_executor_->GetWeakPtrForExecutionThread(), + should_auto_unload)); } // Requests that the model executor unload the model from memory, if it is // currently loaded. void UnloadModel() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - background_task_runner_->PostTask( + model_executor_task_runner_->PostTask( FROM_HERE, base::BindOnce(&ModelExecutor<OutputType, InputTypes...>::UnloadModel, - background_executor_->GetBackgroundWeakPtr())); + model_executor_->GetWeakPtrForExecutionThread())); } // OptimizationTargetModelObserver: @@ -118,16 +121,16 @@ class ModelHandler : public OptimizationTargetModelObserver { model_info_ = model_info; model_available_ = true; - background_task_runner_->PostTask( + model_executor_task_runner_->PostTask( FROM_HERE, base::BindOnce( &ModelExecutor<OutputType, InputTypes...>::UpdateModelFile, - background_executor_->GetBackgroundWeakPtr(), + model_executor_->GetWeakPtrForExecutionThread(), model_info.GetModelFilePath())); // Run any observing callbacks after the model file is posted to the - // background thread so that any model execution requests are posted to the - // background thread after the model update. + // model executor thread so that any model execution requests are posted to + // the model executor thread after the model update. on_model_updated_callbacks_.Notify(); } @@ -168,7 +171,7 @@ class ModelHandler : public OptimizationTargetModelObserver { } private: - // This is called by |background_executor_|. This method does not have to be + // This is called by |model_executor_|. This method does not have to be // static, but because it is stateless we've made it static so that we don't // have to have this class support WeakPointers. static void OnExecutionCompleted( @@ -198,15 +201,14 @@ class ModelHandler : public OptimizationTargetModelObserver { const proto::OptimizationTarget optimization_target_; - // The owned background executor. - std::unique_ptr<ModelExecutor<OutputType, InputTypes...>> - background_executor_; + // The owned model executor. + std::unique_ptr<ModelExecutor<OutputType, InputTypes...>> model_executor_; - // The background task runner. Note that whenever a task is posted here, the - // task takes a reference to the TaskRunner (in a cyclic dependency) so + // The model executor task runner. Note that whenever a task is posted here, + // the task takes a reference to the TaskRunner (in a cyclic dependency) so // |base::Unretained| is not safe anywhere in this class or the - // |background_executor_|. - scoped_refptr<base::SequencedTaskRunner> background_task_runner_; + // |model_executor_|. + scoped_refptr<base::SequencedTaskRunner> model_executor_task_runner_; // Set in |OnModelUpdated|. absl::optional<ModelInfo> model_info_ GUARDED_BY_CONTEXT(sequence_checker_); diff --git a/chromium/components/optimization_guide/core/model_info.cc b/chromium/components/optimization_guide/core/model_info.cc index 077b95dc9b2..db16275f297 100644 --- a/chromium/components/optimization_guide/core/model_info.cc +++ b/chromium/components/optimization_guide/core/model_info.cc @@ -6,7 +6,9 @@ #include "base/memory/ptr_util.h" #include "base/notreached.h" -#include "components/optimization_guide/core/optimization_guide_util.h" +#include "base/strings/utf_string_conversions.h" +#include "build/build_config.h" +#include "components/optimization_guide/core/model_util.h" namespace optimization_guide { diff --git a/chromium/components/optimization_guide/core/model_util.cc b/chromium/components/optimization_guide/core/model_util.cc new file mode 100644 index 00000000000..936d0e4c4cb --- /dev/null +++ b/chromium/components/optimization_guide/core/model_util.cc @@ -0,0 +1,86 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/optimization_guide/core/model_util.h" + +#include "base/base64.h" +#include "base/containers/flat_set.h" +#include "base/notreached.h" +#include "base/strings/utf_string_conversions.h" +#include "build/build_config.h" +#include "net/base/url_util.h" +#include "url/url_canon.h" + +namespace optimization_guide { + +// These names are persisted to histograms, so don't change them. +std::string GetStringNameForOptimizationTarget( + optimization_guide::proto::OptimizationTarget optimization_target) { + switch (optimization_target) { + case proto::OPTIMIZATION_TARGET_UNKNOWN: + return "Unknown"; + case proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD: + return "PainfulPageLoad"; + case proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION: + return "LanguageDetection"; + case proto::OPTIMIZATION_TARGET_PAGE_TOPICS: + return "PageTopics"; + case proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB: + return "SegmentationNewTab"; + case proto::OPTIMIZATION_TARGET_SEGMENTATION_SHARE: + return "SegmentationShare"; + case proto::OPTIMIZATION_TARGET_SEGMENTATION_VOICE: + return "SegmentationVoice"; + case proto::OPTIMIZATION_TARGET_MODEL_VALIDATION: + return "ModelValidation"; + case proto::OPTIMIZATION_TARGET_PAGE_ENTITIES: + return "PageEntities"; + case proto::OPTIMIZATION_TARGET_NOTIFICATION_PERMISSION_PREDICTIONS: + return "NotificationPermissions"; + case proto::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY: + return "SegmentationDummyFeature"; + case proto::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID: + return "SegmentationChromeStartAndroid"; + case proto::OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES: + return "SegmentationQueryTiles"; + case proto::OPTIMIZATION_TARGET_PAGE_VISIBILITY: + return "PageVisibility"; + case proto::OPTIMIZATION_TARGET_AUTOFILL_ASSISTANT: + return "AutofillAssistant"; + case proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2: + return "PageTopicsV2"; + case proto::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT: + return "SegmentationChromeLowUserEngagement"; + // Whenever a new value is added, make sure to add it to the OptTarget + // variant list in + // //tools/metrics/histograms/metadata/optimization/histograms.xml. + } + NOTREACHED(); + return std::string(); +} + +absl::optional<base::FilePath> StringToFilePath(const std::string& str_path) { + if (str_path.empty()) + return absl::nullopt; + +#if BUILDFLAG(IS_WIN) + return base::FilePath(base::UTF8ToWide(str_path)); +#else + return base::FilePath(str_path); +#endif +} + +std::string FilePathToString(const base::FilePath& file_path) { +#if BUILDFLAG(IS_WIN) + return base::WideToUTF8(file_path.value()); +#else + return file_path.value(); +#endif +} + +base::FilePath GetBaseFileNameForModels() { + return base::FilePath(FILE_PATH_LITERAL("model.tflite")); +} + +} // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/model_util.h b/chromium/components/optimization_guide/core/model_util.h new file mode 100644 index 00000000000..dbcf1d476db --- /dev/null +++ b/chromium/components/optimization_guide/core/model_util.h @@ -0,0 +1,37 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_UTIL_H_ +#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_UTIL_H_ + +#include <string> + +#include "base/files/file_path.h" +#include "components/optimization_guide/proto/models.pb.h" +#include "third_party/abseil-cpp/absl/types/optional.h" + +namespace optimization_guide { + +// Returns the string than can be used to record histograms for the optimization +// target. If adding a histogram to use the string or adding an optimization +// target, update the OptimizationGuide.OptimizationTargets histogram suffixes +// in histograms.xml. +std::string GetStringNameForOptimizationTarget( + proto::OptimizationTarget optimization_target); + +// Returns the file path represented by the given string, handling platform +// differences in the conversion. nullopt is only returned iff the passed string +// is empty. +absl::optional<base::FilePath> StringToFilePath(const std::string& str_path); + +// Returns a string representation of the given |file_path|, handling platform +// differences in the conversion. +std::string FilePathToString(const base::FilePath& file_path); + +// Returns the base file name to use for storing all prediction models. +base::FilePath GetBaseFileNameForModels(); + +} // namespace optimization_guide + +#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_UTIL_H_ diff --git a/chromium/components/optimization_guide/core/model_validator.cc b/chromium/components/optimization_guide/core/model_validator.cc index b936f7470f5..bb09c0c8e80 100644 --- a/chromium/components/optimization_guide/core/model_validator.cc +++ b/chromium/components/optimization_guide/core/model_validator.cc @@ -53,18 +53,22 @@ ModelValidatorExecutor::ModelValidatorExecutor() = default; ModelValidatorExecutor::~ModelValidatorExecutor() = default; -absl::Status ModelValidatorExecutor::Preprocess( +bool ModelValidatorExecutor::Preprocess( const std::vector<TfLiteTensor*>& input_tensors, const std::vector<float>& input) { // Return error so that actual model execution does not happen. - return absl::Status(absl::StatusCode::kUnimplemented, - "Model execution not supported"); + return false; } -float ModelValidatorExecutor::Postprocess( +absl::optional<float> ModelValidatorExecutor::Postprocess( const std::vector<const TfLiteTensor*>& output_tensors) { std::vector<float> data; - tflite::task::core::PopulateVector<float>(output_tensors[0], &data); + absl::Status status = + tflite::task::core::PopulateVector<float>(output_tensors[0], &data); + if (!status.ok()) { + NOTREACHED(); + return absl::nullopt; + } return data[0]; } diff --git a/chromium/components/optimization_guide/core/model_validator.h b/chromium/components/optimization_guide/core/model_validator.h index 4e48f4f4a2f..fc25cf6bd94 100644 --- a/chromium/components/optimization_guide/core/model_validator.h +++ b/chromium/components/optimization_guide/core/model_validator.h @@ -53,9 +53,9 @@ class ModelValidatorExecutor protected: // BaseModelExecutor: - absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, - const std::vector<float>& input) override; - float Postprocess( + bool Preprocess(const std::vector<TfLiteTensor*>& input_tensors, + const std::vector<float>& input) override; + absl::optional<float> Postprocess( const std::vector<const TfLiteTensor*>& output_tensors) override; }; diff --git a/chromium/components/optimization_guide/core/model_validator_unittest.cc b/chromium/components/optimization_guide/core/model_validator_unittest.cc index 48a66a623eb..8792f0a9798 100644 --- a/chromium/components/optimization_guide/core/model_validator_unittest.cc +++ b/chromium/components/optimization_guide/core/model_validator_unittest.cc @@ -15,6 +15,7 @@ #include "base/test/metrics/histogram_tester.h" #include "base/test/task_environment.h" #include "build/build_config.h" +#include "components/optimization_guide/core/model_util.h" #include "components/optimization_guide/core/model_validator.h" #include "components/optimization_guide/core/optimization_guide_switches.h" #include "components/optimization_guide/core/optimization_guide_util.h" @@ -125,8 +126,14 @@ TEST_F(ModelValidatorExecutorTest, ValidModel) { proto::OptimizationTarget::OPTIMIZATION_TARGET_MODEL_VALIDATION), ExecutionStatus::kErrorUnknown, 1); + histogram_tester().ExpectUniqueSample( + "OptimizationGuide.ModelExecutor.ModelLoadedSuccessfully." + + GetStringNameForOptimizationTarget( + proto::OptimizationTarget::OPTIMIZATION_TARGET_MODEL_VALIDATION), + true, 1); + histogram_tester().ExpectTotalCount( - "OptimizationGuide.ModelExecutor.ModelLoadingDuration." + + "OptimizationGuide.ModelExecutor.ModelLoadingDuration2." + GetStringNameForOptimizationTarget( proto::OptimizationTarget::OPTIMIZATION_TARGET_MODEL_VALIDATION), 1); @@ -147,8 +154,15 @@ TEST_F(ModelValidatorExecutorTest, DISABLED_InvalidModel) { GetStringNameForOptimizationTarget( proto::OptimizationTarget::OPTIMIZATION_TARGET_MODEL_VALIDATION), ExecutionStatus::kErrorModelFileNotValid, 1); + + histogram_tester().ExpectUniqueSample( + "OptimizationGuide.ModelExecutor.ModelLoadedSuccessfully." + + GetStringNameForOptimizationTarget( + proto::OptimizationTarget::OPTIMIZATION_TARGET_MODEL_VALIDATION), + false, 1); + histogram_tester().ExpectTotalCount( - "OptimizationGuide.ModelExecutor.ModelLoadingDuration." + + "OptimizationGuide.ModelExecutor.ModelLoadingDuration2." + GetStringNameForOptimizationTarget( proto::OptimizationTarget::OPTIMIZATION_TARGET_MODEL_VALIDATION), 1); diff --git a/chromium/components/optimization_guide/core/optimization_guide_constants.cc b/chromium/components/optimization_guide/core/optimization_guide_constants.cc index a8102e2dcac..c6b1318da25 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_constants.cc +++ b/chromium/components/optimization_guide/core/optimization_guide_constants.cc @@ -27,4 +27,7 @@ const base::FilePath::CharType kOptimizationGuidePredictionModelAndFeaturesStore[] = FILE_PATH_LITERAL("optimization_guide_model_and_features_store"); +const base::FilePath::CharType kPageEntitiesMetadataStore[] = + FILE_PATH_LITERAL("page_content_annotations_page_entities_metadata_store"); + } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/optimization_guide_constants.h b/chromium/components/optimization_guide/core/optimization_guide_constants.h index 8b94c027625..d095761b879 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_constants.h +++ b/chromium/components/optimization_guide/core/optimization_guide_constants.h @@ -33,6 +33,9 @@ extern const base::FilePath::CharType kOptimizationGuideHintStore[]; extern const base::FilePath::CharType kOptimizationGuidePredictionModelAndFeaturesStore[]; +// The folder where the page entities metadata store will be stored on disk. +extern const base::FilePath::CharType kPageEntitiesMetadataStore[]; + } // namespace optimization_guide #endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_OPTIMIZATION_GUIDE_CONSTANTS_H_ diff --git a/chromium/components/optimization_guide/core/optimization_guide_enums.h b/chromium/components/optimization_guide/core/optimization_guide_enums.h index 88d24bcf537..a7e817c2b5e 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_enums.h +++ b/chromium/components/optimization_guide/core/optimization_guide_enums.h @@ -49,29 +49,6 @@ enum class OptimizationTypeDecision { kMaxValue = kHintFetchStartedButNotAvailableInTime, }; -// The types of decisions that can be made for an optimization target. -// -// Keep in sync with OptimizationGuideOptimizationTargetDecision in enums.xml. -enum class OptimizationTargetDecision { - kUnknown, - // The page load does not match the optimization target. - kPageLoadDoesNotMatch, - // The page load matches the optimization target. - kPageLoadMatches, - // The model needed to make the target decision was not available on the - // client. - kModelNotAvailableOnClient, - // The page load is part of a model prediction holdback where all decisions - // will return |OptimizationGuideDecision::kFalse| in an attempt to not taint - // the data for understanding the production recall of the model. - kModelPredictionHoldback, - // The OptimizationGuideDecider was not initialized yet. - kDeciderNotInitialized, - - // Add new values above this line. - kMaxValue = kDeciderNotInitialized, -}; - // The statuses for racing a hints fetch with the current navigation based // on the availability of hints for both the current host and URL. // @@ -101,28 +78,6 @@ enum class RaceNavigationFetchAttemptStatus { kDeprecatedRaceNavigationFetchNotAttemptedTooManyConcurrentFetches, }; -// The statuses for a prediction model in the prediction manager when requested -// to be evaluated. -// -// Keep in sync with OptimizationGuidePredictionManagerModelStatus in enums.xml. -enum class PredictionManagerModelStatus { - kUnknown, - // The model is loaded and available for use. - kModelAvailable, - // The store is initialized but does not contain a model for the optimization - // target. - kStoreAvailableNoModelForTarget, - // The store is initialized and contains a model for the optimization target - // but it is not loaded in memory. - kStoreAvailableModelNotLoaded, - // The store is not initialized and it is unknown if it contains a model for - // the optimization target. - kStoreUnavailableModelUnknown, - - // Add new values above this line. - kMaxValue = kStoreUnavailableModelUnknown, -}; - // The statuses for a download file containing a prediction model when verified // and processed. // diff --git a/chromium/components/optimization_guide/core/optimization_guide_features.cc b/chromium/components/optimization_guide/core/optimization_guide_features.cc index 8de77e545fd..6c3a200263f 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_features.cc +++ b/chromium/components/optimization_guide/core/optimization_guide_features.cc @@ -25,15 +25,49 @@ namespace optimization_guide { namespace features { +namespace { + +// Returns whether |locale| is a supported locale for |feature|. +// +// This matches |locale| with the "supported_locales" feature param value in +// |feature|, which is expected to be a comma-separated list of locales. A +// feature param containing "en,es-ES,zh-TW" restricts the feature to English +// language users from any locale and Spanish language users from the Spain +// es-ES locale. A feature param containing "" is unrestricted by locale and any +// user may load it. +bool IsSupportedLocaleForFeature(const std::string locale, + const base::Feature& feature) { + if (!base::FeatureList::IsEnabled(feature)) { + return false; + } + + std::string value = + base::GetFieldTrialParamValueByFeature(feature, "supported_locales"); + std::vector<std::string> supported_locales = base::SplitString( + value, ",", base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY); + // An empty allowlist admits any locale. + if (supported_locales.empty()) { + return true; + } + + // Otherwise, the locale or the + // primary language subtag must match an element of the allowlist. + std::string locale_language = l10n_util::GetLanguage(locale); + return base::Contains(supported_locales, locale) || + base::Contains(supported_locales, locale_language); +} + +} // namespace + // Enables the syncing of the Optimization Hints component, which provides // hints for what optimizations can be applied on a page load. const base::Feature kOptimizationHints { "OptimizationHints", -#if defined(OS_IOS) +#if BUILDFLAG(IS_IOS) base::FEATURE_DISABLED_BY_DEFAULT -#else // !defined(OS_IOS) +#else // !BUILDFLAG(IS_IOS) base::FEATURE_ENABLED_BY_DEFAULT -#endif // defined(OS_IOS) +#endif // BUILDFLAG(IS_IOS) }; // Feature flag that contains a feature param that specifies the field trials @@ -47,11 +81,11 @@ const base::Feature kRemoteOptimizationGuideFetching{ const base::Feature kRemoteOptimizationGuideFetchingAnonymousDataConsent { "OptimizationHintsFetchingAnonymousDataConsent", -#if defined(OS_ANDROID) +#if BUILDFLAG(IS_ANDROID) base::FEATURE_ENABLED_BY_DEFAULT -#else // !defined(OS_ANDROID) +#else // !BUILDFLAG(IS_ANDROID) base::FEATURE_DISABLED_BY_DEFAULT -#endif // defined(OS_ANDROID) +#endif // BUILDFLAG(IS_ANDROID) }; // Enables performance info in the context menu and fetching from a remote @@ -78,6 +112,17 @@ const base::Feature kOptimizationGuideModelDownloading { const base::Feature kPageContentAnnotations{"PageContentAnnotations", base::FEATURE_DISABLED_BY_DEFAULT}; +// Enables the page entities model to be annotated on every page load. +const base::Feature kPageEntitiesPageContentAnnotations{ + "PageEntitiesPageContentAnnotations", base::FEATURE_DISABLED_BY_DEFAULT}; +// Enables the page visibility model to be annotated on every page load. +const base::Feature kPageVisibilityPageContentAnnotations{ + "PageVisibilityPageContentAnnotations", base::FEATURE_DISABLED_BY_DEFAULT}; + +// This feature flag enables resetting the entities model on shutdown. +const base::Feature kPageEntitiesModelResetOnShutdown{ + "PageEntitiesModelResetOnShutdown", base::FEATURE_DISABLED_BY_DEFAULT}; + // Enables push notification of hints. const base::Feature kPushNotifications{"OptimizationGuidePushNotifications", base::FEATURE_DISABLED_BY_DEFAULT}; @@ -96,6 +141,12 @@ const base::Feature kPageTopicsBatchAnnotations{ const base::Feature kPageVisibilityBatchAnnotations{ "PageVisibilityBatchAnnotations", base::FEATURE_ENABLED_BY_DEFAULT}; +const base::Feature kUseLocalPageEntitiesMetadataProvider{ + "UseLocalPageEntitiesMetadataProvider", base::FEATURE_DISABLED_BY_DEFAULT}; + +const base::Feature kBatchAnnotationsValidation{ + "BatchAnnotationsValidation", base::FEATURE_DISABLED_BY_DEFAULT}; + // The default value here is a bit of a guess. // TODO(crbug/1163244): This should be tuned once metrics are available. base::TimeDelta PageTextExtractionOutstandingRequestsGracePeriod() { @@ -264,14 +315,15 @@ base::TimeDelta StoredHostModelFeaturesFreshnessDuration() { "max_store_duration_for_host_model_features_in_days", 7)); } -base::TimeDelta StoredModelsInactiveDuration() { +base::TimeDelta StoredModelsValidDuration() { // TODO(crbug.com/1234054) This field should not be changed without VERY - // careful consideration. Any model that is on device and expires will be - // removed and triggered to refetch so any feature relying on the model could - // have a period of time without a valid model. + // careful consideration. This is the default duration for models that do not + // specify retention, so changing this can cause models to be removed and + // refetch would only apply to newer models. Any feature relying on the model + // would have a period of time without a valid model, and would need to push a + // new version. return base::Days(GetFieldTrialParamByFeatureAsInt( - kOptimizationTargetPrediction, "inactive_duration_for_models_in_days", - 30)); + kOptimizationTargetPrediction, "valid_duration_for_models_in_days", 30)); } base::TimeDelta URLKeyedHintValidCacheDuration() { @@ -334,6 +386,11 @@ base::TimeDelta PredictionModelFetchRetryDelay() { kOptimizationTargetPrediction, "fetch_retry_minutes", 2)); } +base::TimeDelta PredictionModelFetchStartupDelay() { + return base::Milliseconds(GetFieldTrialParamByFeatureAsInt( + kOptimizationTargetPrediction, "fetch_startup_delay_ms", 2000)); +} + base::TimeDelta PredictionModelFetchInterval() { return base::Hours(GetFieldTrialParamByFeatureAsInt( kOptimizationTargetPrediction, "fetch_interval_hours", 24)); @@ -396,75 +453,20 @@ bool ShouldExtractRelatedSearches() { return kContentAnnotationsExtractRelatedSearchesParam.Get(); } -std::vector<optimization_guide::proto::OptimizationTarget> -GetPageContentModelsToExecute(const std::string& locale) { - if (!IsPageContentAnnotationEnabled()) - return {}; - - // Use an updated parameter name that supports locale filtering. That way, - // older clients that don't know how to interpret locale filtering ignore the - // new parameter name and keep looking for the old one. - std::string value = base::GetFieldTrialParamValueByFeature( - kPageContentAnnotations, "models_to_execute_v2"); - if (value.empty()) { - // If the updated parameter is empty, try getting the older parameter name - // that doesn't support locale-specific models. That way, older parameter - // configurations still work. We don't do a union because that's confusing. - value = base::GetFieldTrialParamValueByFeature(kPageContentAnnotations, - "models_to_execute"); - } - if (value.empty()) { - // If neither the newer or older parameter is set, run the page topics model - // by default. - return {optimization_guide::proto::OPTIMIZATION_TARGET_PAGE_TOPICS}; - } - - // The parameter value delimits models by commas, and per-model locale - // restrictions by colon. For example: - // FOO_MODEL:en:es-ES,BAR_MODEL,BAZ_MODEL:zh-TW - // - FOO_MODEL is restricted to English language users from any locale, and - // Spanish language users from the Spain es-ES locale. - // - BAR_MODEL is unrestricted by locale, and any user may load it. - // - BAZ_MODEL is restricted to zh-TW only, so zh-CN users won't load it. - // - // First split by comma to handle one model at a time. - std::vector<std::string> model_target_strings = base::SplitString( - value, ",", base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY); - - std::string locale_language = l10n_util::GetLanguage(locale); +bool ShouldExecutePageEntitiesModelOnPageContent(const std::string& locale) { + return base::FeatureList::IsEnabled(kPageEntitiesPageContentAnnotations) && + IsSupportedLocaleForFeature(locale, + kPageEntitiesPageContentAnnotations); +} - optimization_guide::InsertionOrderedSet< - optimization_guide::proto::OptimizationTarget> - model_targets; - for (const auto& model_target_string : model_target_strings) { - // Split by colon to extract the model name and allowlist, early continuing - // for invalid values. - std::vector<std::string> model_name_and_allowed_locales = - base::SplitString(model_target_string, ":", base::TRIM_WHITESPACE, - base::SPLIT_WANT_NONEMPTY); - if (model_name_and_allowed_locales.empty()) - continue; - std::string model_name = model_name_and_allowed_locales[0]; - std::vector<std::string> allowlist; - for (size_t i = 1; i < model_name_and_allowed_locales.size(); ++i) { - allowlist.push_back(model_name_and_allowed_locales[i]); - } - - optimization_guide::proto::OptimizationTarget model_target; - if (!optimization_guide::proto::OptimizationTarget_Parse(model_name, - &model_target)) { - continue; - } - - // An empty allowlist admits any locale. Otherwise, the locale or the - // primary language subtag must match an element of the allowlist. - if (allowlist.empty() || base::Contains(allowlist, locale) || - base::Contains(allowlist, locale_language)) { - model_targets.insert(model_target); - } - } +bool ShouldResetPageEntitiesModelOnShutdown() { + return base::FeatureList::IsEnabled(kPageEntitiesModelResetOnShutdown); +} - return model_targets.vector(); +bool ShouldExecutePageVisibilityModelOnPageContent(const std::string& locale) { + return base::FeatureList::IsEnabled(kPageVisibilityPageContentAnnotations) && + IsSupportedLocaleForFeature(locale, + kPageVisibilityPageContentAnnotations); } bool RemotePageEntitiesEnabled() { @@ -501,7 +503,7 @@ bool ShouldMetadataValidationFetchHostKeyed() { bool ShouldDeferStartupActiveTabsHintsFetch() { return GetFieldTrialParamByFeatureAsBool( kOptimizationHints, "defer_startup_active_tabs_hints_fetch", -#if defined(OS_ANDROID) +#if BUILDFLAG(IS_ANDROID) true #else false @@ -517,5 +519,37 @@ bool PageVisibilityBatchAnnotationsEnabled() { return base::FeatureList::IsEnabled(kPageVisibilityBatchAnnotations); } +bool UseLocalPageEntitiesMetadataProvider() { + return base::FeatureList::IsEnabled(kUseLocalPageEntitiesMetadataProvider); +} + +size_t AnnotateVisitBatchSize() { + return std::max( + 1, GetFieldTrialParamByFeatureAsInt(kPageContentAnnotations, + "annotate_visit_batch_size", 1)); +} + +bool BatchAnnotationsValidationEnabled() { + return base::FeatureList::IsEnabled(kBatchAnnotationsValidation); +} + +base::TimeDelta BatchAnnotationValidationStartupDelay() { + return base::Seconds( + std::max(1, GetFieldTrialParamByFeatureAsInt(kBatchAnnotationsValidation, + "startup_delay", 30))); +} + +size_t BatchAnnotationsValidationBatchSize() { + int batch_size = GetFieldTrialParamByFeatureAsInt(kBatchAnnotationsValidation, + "batch_size", 25); + return std::max(1, batch_size); +} + +size_t MaxVisitAnnotationCacheSize() { + int batch_size = GetFieldTrialParamByFeatureAsInt( + kPageContentAnnotations, "max_visit_annotation_cache_size", 50); + return std::max(1, batch_size); +} + } // namespace features } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/optimization_guide_features.h b/chromium/components/optimization_guide/core/optimization_guide_features.h index 9789dbc0b42..e69bd80e439 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_features.h +++ b/chromium/components/optimization_guide/core/optimization_guide_features.h @@ -29,11 +29,15 @@ extern const base::Feature kContextMenuPerformanceInfoAndRemoteHintFetching; extern const base::Feature kOptimizationTargetPrediction; extern const base::Feature kOptimizationGuideModelDownloading; extern const base::Feature kPageContentAnnotations; +extern const base::Feature kPageEntitiesPageContentAnnotations; +extern const base::Feature kPageVisibilityPageContentAnnotations; extern const base::Feature kPageTextExtraction; extern const base::Feature kPushNotifications; extern const base::Feature kOptimizationGuideMetadataValidation; extern const base::Feature kPageTopicsBatchAnnotations; extern const base::Feature kPageVisibilityBatchAnnotations; +extern const base::Feature kUseLocalPageEntitiesMetadataProvider; +extern const base::Feature kBatchAnnotationsValidation; // The grace period duration for how long to give outstanding page text dump // requests to respond after DidFinishLoad. @@ -133,7 +137,7 @@ base::TimeDelta StoredHostModelFeaturesFreshnessDuration(); // The maximum duration for which models can remain in the // OptimizationGuideStore without being loaded. -base::TimeDelta StoredModelsInactiveDuration(); +base::TimeDelta StoredModelsValidDuration(); // The amount of time URL-keyed hints within the hint cache will be // allowed to be used and not be purged. @@ -177,6 +181,10 @@ int PredictionModelFetchRandomMaxDelaySecs(); // models. base::TimeDelta PredictionModelFetchRetryDelay(); +// Returns the time to wait after browser start before fetching prediciton +// models. +base::TimeDelta PredictionModelFetchStartupDelay(); + // Returns the time to wait after a successful fetch of prediction models to // refresh models. base::TimeDelta PredictionModelFetchInterval(); @@ -214,15 +222,16 @@ size_t MaxContentAnnotationRequestsCached(); // as part of page content annotations. bool ShouldExtractRelatedSearches(); -// Returns an ordered vector of models to execute on the page content for each -// page load. It is guaranteed that an optimization target will only be present -// at most once in the returned vector. However, it is not guaranteed that it -// will only contain models that the current PageContentAnnotationsService -// supports, so it is up to the caller to ensure that it can execute the -// specified models. `locale` is used for implement client-side locale filtering -// for models that only work for some locales. -std::vector<optimization_guide::proto::OptimizationTarget> -GetPageContentModelsToExecute(const std::string& locale); +// Returns whether the page entities model should be executed on page content +// for a user using |locale| as their browser language. +bool ShouldExecutePageEntitiesModelOnPageContent(const std::string& locale); + +// Returns whether the page entities model should be reset on shutdown. +bool ShouldResetPageEntitiesModelOnShutdown(); + +// Returns whether the page visibility model should be executed on page content +// for a user using |locale| as their browser language. +bool ShouldExecutePageVisibilityModelOnPageContent(const std::string& locale); // Returns whether page entities should be retrieved from the remote // Optimization Guide service. @@ -249,6 +258,27 @@ bool PageTopicsBatchAnnotationsEnabled(); // Returns if Page Visibility Batch Annotations are enabled. bool PageVisibilityBatchAnnotationsEnabled(); +// Whether to use the leveldb-based page entities metadata provider. +bool UseLocalPageEntitiesMetadataProvider(); + +// The number of visits batch before running the page content annotation +// models. A size of 1 is equivalent to annotating one page load at time +// immediately after requested. +size_t AnnotateVisitBatchSize(); + +// Whether the batch annotation validation feature is enabled. +bool BatchAnnotationsValidationEnabled(); + +// The time period between browser start and running a running batch annotation +// validation. +base::TimeDelta BatchAnnotationValidationStartupDelay(); + +// The size of batches to run for validation. +size_t BatchAnnotationsValidationBatchSize(); + +// The maximum size of the visit annotation cache. +size_t MaxVisitAnnotationCacheSize(); + } // namespace features } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/optimization_guide_features_unittest.cc b/chromium/components/optimization_guide/core/optimization_guide_features_unittest.cc index 152c621bfe3..48433bff52b 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_features_unittest.cc +++ b/chromium/components/optimization_guide/core/optimization_guide_features_unittest.cc @@ -68,73 +68,72 @@ TEST(OptimizationGuideFeaturesTest, ValidPageContentRAPPORMetrics) { EXPECT_EQ(.2, features::NoiseProbabilityForRAPPORMetrics()); } -TEST(OptimizationGuideFeaturesTest, GetPageContentModelsToExecute) { +TEST(OptimizationGuideFeaturesTest, + ShouldExecutePageEntitiesModelOnPageContentDisabled) { base::test::ScopedFeatureList scoped_feature_list; - scoped_feature_list.InitAndEnableFeatureWithParameters( - features::kPageContentAnnotations, - {{"models_to_execute_v2", - "OPTIMIZATION_TARGET_PAGE_TOPICS,OPTIMIZATION_TARGET_PAGE_ENTITIES"}}); + scoped_feature_list.InitAndDisableFeature( + features::kPageEntitiesPageContentAnnotations); - auto models = features::GetPageContentModelsToExecute("en-US"); - ASSERT_EQ(2U, models.size()); - ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_TOPICS, models[0]); - ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, models[1]); + EXPECT_FALSE(features::ShouldExecutePageEntitiesModelOnPageContent("en-US")); } TEST(OptimizationGuideFeaturesTest, - GetPageContentModelsToExecuteOldParameterName) { + ShouldExecutePageEntitiesModelOnPageContentEmptyAllowlist) { + base::test::ScopedFeatureList scoped_feature_list; + + scoped_feature_list.InitAndEnableFeature( + features::kPageEntitiesPageContentAnnotations); + + EXPECT_TRUE(features::ShouldExecutePageEntitiesModelOnPageContent("en-US")); +} + +TEST(OptimizationGuideFeaturesTest, + ShouldExecutePageEntitiesModelOnPageContentWithAllowlist) { base::test::ScopedFeatureList scoped_feature_list; scoped_feature_list.InitAndEnableFeatureWithParameters( - features::kPageContentAnnotations, - {{"models_to_execute", - "OPTIMIZATION_TARGET_PAGE_TOPICS,OPTIMIZATION_TARGET_PAGE_ENTITIES"}}); + features::kPageEntitiesPageContentAnnotations, + {{"supported_locales", "en,zh-TW"}}); - auto models = features::GetPageContentModelsToExecute("en-US"); - ASSERT_EQ(2U, models.size()); - ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_TOPICS, models[0]); - ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, models[1]); + EXPECT_TRUE(features::ShouldExecutePageEntitiesModelOnPageContent("en-US")); + EXPECT_FALSE(features::ShouldExecutePageEntitiesModelOnPageContent("")); + EXPECT_FALSE(features::ShouldExecutePageEntitiesModelOnPageContent("zh-CN")); } -TEST(OptimizationGuideFeaturesTest, GetPageContentModelsToExecuteLocales) { +TEST(OptimizationGuideFeaturesTest, + ShouldExecutePageVisibilityModelOnPageContentDisabled) { + base::test::ScopedFeatureList scoped_feature_list; + + scoped_feature_list.InitAndDisableFeature( + features::kPageVisibilityPageContentAnnotations); + + EXPECT_FALSE( + features::ShouldExecutePageVisibilityModelOnPageContent("en-US")); +} + +TEST(OptimizationGuideFeaturesTest, + ShouldExecutePageVisibilityModelOnPageContentEmptyAllowlist) { + base::test::ScopedFeatureList scoped_feature_list; + + scoped_feature_list.InitAndEnableFeature( + features::kPageVisibilityPageContentAnnotations); + + EXPECT_TRUE(features::ShouldExecutePageVisibilityModelOnPageContent("en-US")); +} + +TEST(OptimizationGuideFeaturesTest, + ShouldExecutePageVisibilityModelOnPageContentWithAllowlist) { base::test::ScopedFeatureList scoped_feature_list; scoped_feature_list.InitAndEnableFeatureWithParameters( - features::kPageContentAnnotations, - {{ - "models_to_execute_v2", - // This string is meant to test language filtering, locale filtering, - // and tolerance of whitespaces, as well as extra delimiters. - "OPTIMIZATION_TARGET_PAGE_TOPICS:en:es-ES , OPTIMIZATION_TARGET_PAGE_" - "ENTITIES,,OPTIMIZATION_TARGET_PAGE_VISIBILITY:zh-TW:", - }}); - - { - auto models = features::GetPageContentModelsToExecute("en-US"); - ASSERT_EQ(2U, models.size()); - ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_TOPICS, models[0]); - ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, models[1]); - } - - { - auto models = features::GetPageContentModelsToExecute(""); - ASSERT_EQ(1U, models.size()); - ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, models[0]); - } - - { - auto models = features::GetPageContentModelsToExecute("zh-CN"); - ASSERT_EQ(1U, models.size()); - ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, models[0]); - } - - { - auto models = features::GetPageContentModelsToExecute("zh-TW"); - ASSERT_EQ(2U, models.size()); - ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, models[0]); - ASSERT_EQ(proto::OPTIMIZATION_TARGET_PAGE_VISIBILITY, models[1]); - } + features::kPageVisibilityPageContentAnnotations, + {{"supported_locales", "en,zh-TW"}}); + + EXPECT_TRUE(features::ShouldExecutePageVisibilityModelOnPageContent("en-US")); + EXPECT_FALSE(features::ShouldExecutePageVisibilityModelOnPageContent("")); + EXPECT_FALSE( + features::ShouldExecutePageVisibilityModelOnPageContent("zh-CN")); } } // namespace diff --git a/chromium/components/optimization_guide/core/optimization_guide_logger.cc b/chromium/components/optimization_guide/core/optimization_guide_logger.cc new file mode 100644 index 00000000000..20cdbaf4d1a --- /dev/null +++ b/chromium/components/optimization_guide/core/optimization_guide_logger.cc @@ -0,0 +1,32 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/optimization_guide/core/optimization_guide_logger.h" + +OptimizationGuideLogger::OptimizationGuideLogger() = default; + +OptimizationGuideLogger::~OptimizationGuideLogger() = default; + +void OptimizationGuideLogger::AddObserver( + OptimizationGuideLogger::Observer* observer) { + observers_.AddObserver(observer); +} + +void OptimizationGuideLogger::RemoveObserver( + OptimizationGuideLogger::Observer* observer) { + observers_.RemoveObserver(observer); +} + +void OptimizationGuideLogger::OnLogMessageAdded(base::Time event_time, + const std::string& source_file, + int source_line, + const std::string& message) { + DCHECK(!observers_.empty()); + for (Observer& obs : observers_) + obs.OnLogMessageAdded(event_time, source_file, source_line, message); +} + +bool OptimizationGuideLogger::ShouldEnableDebugLogs() const { + return !observers_.empty(); +} diff --git a/chromium/components/optimization_guide/core/optimization_guide_logger.h b/chromium/components/optimization_guide/core/optimization_guide_logger.h new file mode 100644 index 00000000000..d9a2f169ba1 --- /dev/null +++ b/chromium/components/optimization_guide/core/optimization_guide_logger.h @@ -0,0 +1,45 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_OPTIMIZATION_GUIDE_LOGGER_H_ +#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_OPTIMIZATION_GUIDE_LOGGER_H_ + +#include <string> + +#include "base/observer_list.h" +#include "base/observer_list_types.h" +#include "base/time/time.h" + +// Interface to record the debug logs and send it to be shown in the +// optimization guide internals page. +class OptimizationGuideLogger { + public: + class Observer : public base::CheckedObserver { + public: + virtual void OnLogMessageAdded(base::Time event_time, + const std::string& source_file, + int source_line, + const std::string& message) = 0; + }; + OptimizationGuideLogger(); + ~OptimizationGuideLogger(); + + OptimizationGuideLogger(const OptimizationGuideLogger&) = delete; + OptimizationGuideLogger& operator=(const OptimizationGuideLogger&) = delete; + + void AddObserver(OptimizationGuideLogger::Observer* observer); + void RemoveObserver(OptimizationGuideLogger::Observer* observer); + void OnLogMessageAdded(base::Time event_time, + const std::string& source_file, + int source_line, + const std::string& message); + + // Whether debug logs should allowed to be recorded. + bool ShouldEnableDebugLogs() const; + + private: + base::ObserverList<OptimizationGuideLogger::Observer> observers_; +}; + +#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_OPTIMIZATION_GUIDE_LOGGER_H_ diff --git a/chromium/components/optimization_guide/core/optimization_guide_permissions_util.cc b/chromium/components/optimization_guide/core/optimization_guide_permissions_util.cc index d7a71e2785e..fec08802020 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_permissions_util.cc +++ b/chromium/components/optimization_guide/core/optimization_guide_permissions_util.cc @@ -7,21 +7,12 @@ #include <memory> #include "base/feature_list.h" -#include "components/data_reduction_proxy/core/browser/data_reduction_proxy_settings.h" #include "components/optimization_guide/core/optimization_guide_features.h" #include "components/optimization_guide/core/optimization_guide_switches.h" #include "components/unified_consent/url_keyed_data_collection_consent_helper.h" namespace { -bool IsUserDataSaverEnabledAndAllowedToFetchFromRemoteService( - bool is_off_the_record, - PrefService* pref_service) { - // Check if they are a data saver user. - return data_reduction_proxy::DataReductionProxySettings:: - IsDataSaverEnabledByUser(is_off_the_record, pref_service); -} - bool IsUserConsentedToAnonymousDataCollectionAndAllowedToFetchFromRemoteService( PrefService* pref_service) { if (!optimization_guide::features:: @@ -55,10 +46,6 @@ bool IsUserPermittedToFetchFromRemoteOptimizationGuide( if (features::IsRemoteFetchingExplicitlyAllowedForPerformanceInfo()) return true; - if (IsUserDataSaverEnabledAndAllowedToFetchFromRemoteService( - is_off_the_record, pref_service)) - return true; - return IsUserConsentedToAnonymousDataCollectionAndAllowedToFetchFromRemoteService( pref_service); } diff --git a/chromium/components/optimization_guide/core/optimization_guide_permissions_util_unittest.cc b/chromium/components/optimization_guide/core/optimization_guide_permissions_util_unittest.cc index 9e879748e84..015955ccc26 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_permissions_util_unittest.cc +++ b/chromium/components/optimization_guide/core/optimization_guide_permissions_util_unittest.cc @@ -7,8 +7,6 @@ #include "base/command_line.h" #include "base/test/scoped_feature_list.h" #include "base/test/task_environment.h" -#include "components/data_reduction_proxy/core/browser/data_reduction_proxy_settings.h" -#include "components/data_reduction_proxy/core/common/data_reduction_proxy_pref_names.h" #include "components/optimization_guide/core/optimization_guide_features.h" #include "components/sync_preferences/testing_pref_service_syncable.h" #include "components/unified_consent/pref_names.h" @@ -22,13 +20,6 @@ class OptimizationGuidePermissionsUtilTest : public testing::Test { void SetUp() override { unified_consent::UnifiedConsentService::RegisterPrefs( pref_service_.registry()); - pref_service_.registry()->RegisterBooleanPref( - data_reduction_proxy::prefs::kDataSaverEnabled, false); - } - - void SetDataSaverEnabled(bool enabled) { - data_reduction_proxy::DataReductionProxySettings:: - SetDataSaverEnabledForTesting(&pref_service_, enabled); } void SetUrlKeyedAnonymizedDataCollectionEnabled(bool enabled) { @@ -45,53 +36,38 @@ class OptimizationGuidePermissionsUtilTest : public testing::Test { }; TEST_F(OptimizationGuidePermissionsUtilTest, - IsUserPermittedToFetchHintsNonDataSaverUser) { + IsUserPermittedToFetchHintsDefaultUser) { base::test::ScopedFeatureList scoped_feature_list; scoped_feature_list.InitAndEnableFeature( {optimization_guide::features::kRemoteOptimizationGuideFetching}); - SetDataSaverEnabled(false); EXPECT_FALSE(IsUserPermittedToFetchFromRemoteOptimizationGuide( /*is_off_the_record=*/false, pref_service())); } -TEST_F(OptimizationGuidePermissionsUtilTest, - IsUserPermittedToFetchHintsDataSaverUser) { - base::test::ScopedFeatureList scoped_feature_list; - scoped_feature_list.InitAndEnableFeature( - {optimization_guide::features::kRemoteOptimizationGuideFetching}); - SetDataSaverEnabled(true); - - EXPECT_TRUE(IsUserPermittedToFetchFromRemoteOptimizationGuide( - /*is_off_the_record=*/false, pref_service())); -} - TEST_F( OptimizationGuidePermissionsUtilTest, - IsUserPermittedToFetchHintsNonDataSaverUserAnonymousDataCollectionEnabledFeatureEnabled) { + IsUserPermittedToFetchHintsDefaultUserAnonymousDataCollectionEnabledFeatureEnabled) { base::test::ScopedFeatureList scoped_feature_list; scoped_feature_list.InitWithFeatures( {optimization_guide::features::kRemoteOptimizationGuideFetching, optimization_guide::features:: kRemoteOptimizationGuideFetchingAnonymousDataConsent}, {}); - SetDataSaverEnabled(false); SetUrlKeyedAnonymizedDataCollectionEnabled(true); EXPECT_TRUE(IsUserPermittedToFetchFromRemoteOptimizationGuide( /*is_off_the_record=*/false, pref_service())); } -TEST_F( - OptimizationGuidePermissionsUtilTest, - IsUserPermittedToFetchHintsNonDataSaverUserAnonymousDataCollectionDisabled) { +TEST_F(OptimizationGuidePermissionsUtilTest, + IsUserPermittedToFetchHintsDefaultUserAnonymousDataCollectionDisabled) { base::test::ScopedFeatureList scoped_feature_list; scoped_feature_list.InitWithFeatures( {optimization_guide::features::kRemoteOptimizationGuideFetching, optimization_guide::features:: kRemoteOptimizationGuideFetchingAnonymousDataConsent}, {}); - SetDataSaverEnabled(false); SetUrlKeyedAnonymizedDataCollectionEnabled(false); EXPECT_FALSE(IsUserPermittedToFetchFromRemoteOptimizationGuide( @@ -100,13 +76,12 @@ TEST_F( TEST_F( OptimizationGuidePermissionsUtilTest, - IsUserPermittedToFetchHintsNonDataSaverUserAnonymousDataCollectionEnabledFeatureNotEnabled) { + IsUserPermittedToFetchHintsDefaultUserAnonymousDataCollectionEnabledFeatureNotEnabled) { base::test::ScopedFeatureList scoped_feature_list; scoped_feature_list.InitWithFeatures( {optimization_guide::features::kRemoteOptimizationGuideFetching}, {optimization_guide::features:: kRemoteOptimizationGuideFetchingAnonymousDataConsent}); - SetDataSaverEnabled(false); SetUrlKeyedAnonymizedDataCollectionEnabled(true); EXPECT_FALSE(IsUserPermittedToFetchFromRemoteOptimizationGuide( @@ -118,7 +93,6 @@ TEST_F(OptimizationGuidePermissionsUtilTest, base::test::ScopedFeatureList scoped_feature_list; scoped_feature_list.InitWithFeatures( {}, {optimization_guide::features::kRemoteOptimizationGuideFetching}); - SetDataSaverEnabled(true); SetUrlKeyedAnonymizedDataCollectionEnabled(true); EXPECT_FALSE(IsUserPermittedToFetchFromRemoteOptimizationGuide( @@ -133,7 +107,6 @@ TEST_F(OptimizationGuidePermissionsUtilTest, optimization_guide::features:: kContextMenuPerformanceInfoAndRemoteHintFetching}, {}); - SetDataSaverEnabled(false); SetUrlKeyedAnonymizedDataCollectionEnabled(false); EXPECT_TRUE(IsUserPermittedToFetchFromRemoteOptimizationGuide( @@ -150,7 +123,6 @@ TEST_F(OptimizationGuidePermissionsUtilTest, optimization_guide::features:: kContextMenuPerformanceInfoAndRemoteHintFetching}, {}); - SetDataSaverEnabled(true); SetUrlKeyedAnonymizedDataCollectionEnabled(true); EXPECT_FALSE(IsUserPermittedToFetchFromRemoteOptimizationGuide( diff --git a/chromium/components/optimization_guide/core/optimization_guide_store.cc b/chromium/components/optimization_guide/core/optimization_guide_store.cc index 97a50f7812a..7b99acaa5ad 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_store.cc +++ b/chromium/components/optimization_guide/core/optimization_guide_store.cc @@ -3,10 +3,13 @@ // found in the LICENSE file. #include "components/optimization_guide/core/optimization_guide_store.h" +#include <memory> +#include <string> #include "base/bind.h" #include "base/callback_helpers.h" #include "base/files/file_util.h" +#include "base/logging.h" #include "base/metrics/histogram_functions.h" #include "base/metrics/histogram_macros.h" #include "base/sequence_checker.h" @@ -17,9 +20,11 @@ #include "components/leveldb_proto/public/proto_database_provider.h" #include "components/leveldb_proto/public/shared_proto_database_client_list.h" #include "components/optimization_guide/core/memory_hint.h" +#include "components/optimization_guide/core/model_util.h" #include "components/optimization_guide/core/optimization_guide_features.h" #include "components/optimization_guide/core/optimization_guide_util.h" #include "components/optimization_guide/proto/hint_cache.pb.h" +#include "third_party/abseil-cpp/absl/types/optional.h" namespace optimization_guide { @@ -62,9 +67,7 @@ enum class OptimizationGuideHintCacheLevelDBStoreLoadMetadataResult { // recorded when it goes out of scope and its destructor is called. class ScopedLoadMetadataResultRecorder { public: - ScopedLoadMetadataResultRecorder() - : result_(OptimizationGuideHintCacheLevelDBStoreLoadMetadataResult:: - kSuccess) {} + ScopedLoadMetadataResultRecorder() = default; ~ScopedLoadMetadataResultRecorder() { UMA_HISTOGRAM_ENUMERATION( "OptimizationGuide.HintCacheLevelDBStore.LoadMetadataResult", result_); @@ -76,7 +79,8 @@ class ScopedLoadMetadataResultRecorder { } private: - OptimizationGuideHintCacheLevelDBStoreLoadMetadataResult result_; + OptimizationGuideHintCacheLevelDBStoreLoadMetadataResult result_ = + OptimizationGuideHintCacheLevelDBStoreLoadMetadataResult::kSuccess; }; void RecordStatusChange(OptimizationGuideStore::Status status) { @@ -289,20 +293,6 @@ void OptimizationGuideStore::PurgeExpiredFetchedHints() { weak_ptr_factory_.GetWeakPtr())); } -void OptimizationGuideStore::PurgeExpiredHostModelFeatures() { - DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - - if (!IsAvailable()) - return; - - // Load all the host model features to check their expiry times. - database_->LoadKeysAndEntriesWithFilter( - base::BindRepeating(&DatabasePrefixFilter, - GetHostModelFeaturesEntryKeyPrefix()), - base::BindOnce(&OptimizationGuideStore::OnLoadEntriesToPurgeExpired, - weak_ptr_factory_.GetWeakPtr())); -} - void OptimizationGuideStore::PurgeInactiveModels() { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); @@ -453,14 +443,6 @@ base::Time OptimizationGuideStore::GetFetchedHintsUpdateTime() const { return fetched_update_time_; } -base::Time OptimizationGuideStore::GetHostModelFeaturesUpdateTime() const { - // If the store is not available, the metadata entries have not been loaded - // so there are no host model features. - if (!IsAvailable()) - return base::Time(); - return host_model_features_update_time_; -} - // static const char OptimizationGuideStore::kStoreSchemaVersion[] = "1"; @@ -530,14 +512,6 @@ OptimizationGuideStore::GetOptimizationTargetFromPredictionModelEntryKey( return static_cast<proto::OptimizationTarget>(optimization_target_number); } -// static -OptimizationGuideStore::EntryKeyPrefix -OptimizationGuideStore::GetHostModelFeaturesEntryKeyPrefix() { - return base::NumberToString(static_cast<int>( - OptimizationGuideStore::StoreEntryType::kHostModelFeatures)) + - kKeySectionDelimiter; -} - void OptimizationGuideStore::UpdateStatus(Status new_status) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); @@ -780,24 +754,6 @@ void OptimizationGuideStore::OnLoadMetadata( fetched_update_time_ = base::Time(); } - auto host_model_features_entry = metadata_entries->find( - GetMetadataTypeEntryKey(MetadataType::kHostModelFeatures)); - bool host_model_features_metadata_loaded = false; - host_model_features_update_time_ = base::Time(); - if (host_model_features_entry != metadata_entries->end()) { - DCHECK(host_model_features_entry->second.has_update_time_secs()); - host_model_features_update_time_ = base::Time::FromDeltaSinceWindowsEpoch( - base::Seconds(host_model_features_entry->second.update_time_secs())); - host_model_features_metadata_loaded = true; - } - // TODO(crbug/1001194): Metrics should be separated so that stores maintaining - // different information types only record metrics for the types of entries - // they store. - UMA_HISTOGRAM_BOOLEAN( - "OptimizationGuide.PredictionModelStore." - "HostModelFeaturesLoadMetadataResult", - host_model_features_metadata_loaded); - UpdateStatus(Status::kAvailable); MaybeLoadEntryKeys(std::move(callback)); } @@ -963,14 +919,30 @@ void OptimizationGuideStore::OnLoadModelsToBeUpdated( bool had_entries_to_update_or_remove = !update_vector->empty() || !remove_vector->empty(); for (const auto& entry : *entries) { - bool should_delete_download_file = had_entries_to_update_or_remove; + absl::optional<std::string> delete_download_file; + if (had_entries_to_update_or_remove && + entry.second.has_prediction_model() && + !entry.second.prediction_model().model().download_url().empty()) { + delete_download_file = + entry.second.prediction_model().model().download_url(); + } + // Only check expiry if we weren't explicitly passed in entries to update or // remove. if (!had_entries_to_update_or_remove) { + if (entry.second.keep_beyond_valid_duration()) { + continue; + } if (entry.second.has_expiry_time_secs()) { if (entry.second.expiry_time_secs() <= now_since_epoch) { + // Update the entry to remove the model. + if (entry.second.has_prediction_model() && + !entry.second.prediction_model().model().download_url().empty()) { + delete_download_file = + entry.second.prediction_model().model().download_url(); + } + remove_vector->push_back(entry.first); - should_delete_download_file = true; proto::OptimizationTarget optimization_target = GetOptimizationTargetFromPredictionModelEntryKey(entry.first); base::UmaHistogramBoolean( @@ -984,20 +956,17 @@ void OptimizationGuideStore::OnLoadModelsToBeUpdated( update_vector->push_back(entry); update_vector->back().second.set_expiry_time_secs( now_since_epoch + - features::StoredModelsInactiveDuration().InSeconds()); + features::StoredModelsValidDuration().InSeconds()); } } // Delete files (the model itself and any additional files) that are // provided by the model in its directory. - if (should_delete_download_file && entry.second.has_prediction_model() && - !entry.second.prediction_model().model().download_url().empty()) { - // |GetFilePathFromPredictionModel| only returns nullopt when - // |model().download_url()| is empty. + if (delete_download_file) { + // |StringToFilePath| only returns nullopt when + // |delete_download_file| is empty. base::FilePath model_file_path = - StringToFilePath( - entry.second.prediction_model().model().download_url()) - .value(); + StringToFilePath(*delete_download_file).value(); base::FilePath path_to_delete; // Backwards compatibility: Once upon a time (<M93), model files were @@ -1043,9 +1012,8 @@ bool OptimizationGuideStore::FindPredictionModelEntryKey( *out_prediction_model_entry_key = GetPredictionModelEntryKeyPrefix() + base::NumberToString(static_cast<int>(optimization_target)); - if (entry_keys_->find(*out_prediction_model_entry_key) != entry_keys_->end()) - return true; - return false; + return entry_keys_->find(*out_prediction_model_entry_key) != + entry_keys_->end(); } bool OptimizationGuideStore::RemovePredictionModelFromEntryKey( @@ -1107,6 +1075,17 @@ void OptimizationGuideStore::OnLoadPredictionModel( std::move(callback).Run(std::move(loaded_prediction_model)); return; } + // Also ensure that nothing is returned if the model is expired. + int64_t now_since_epoch = + base::Time::Now().ToDeltaSinceWindowsEpoch().InSeconds(); + if (!entry->keep_beyond_valid_duration() && + entry->expiry_time_secs() <= now_since_epoch) { + // Leave the actual model deletions to |PurgeInactiveModels| and return + // early. + std::unique_ptr<proto::PredictionModel> loaded_prediction_model(nullptr); + std::move(callback).Run(std::move(loaded_prediction_model)); + return; + } std::unique_ptr<proto::PredictionModel> loaded_prediction_model( entry->release_prediction_model()); @@ -1166,167 +1145,4 @@ void OptimizationGuideStore::OnModelFilePathVerified( std::move(callback).Run(nullptr); } -std::unique_ptr<StoreUpdateData> -OptimizationGuideStore::CreateUpdateDataForHostModelFeatures( - base::Time host_model_features_update_time, - base::Time expiry_time) const { - // Create and returns a StoreUpdateData object. This object has host model - // features from the GetModelsResponse moved into and organizes them in a - // format usable by the store. The object will be stored with - // UpdateHostModelFeatures(). - return StoreUpdateData::CreateHostModelFeaturesStoreUpdateData( - host_model_features_update_time, expiry_time); -} - -void OptimizationGuideStore::UpdateHostModelFeatures( - std::unique_ptr<StoreUpdateData> host_model_features_update_data, - base::OnceClosure callback) { - DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - DCHECK(host_model_features_update_data->update_time()); - - if (!IsAvailable()) { - std::move(callback).Run(); - return; - } - - host_model_features_update_time_ = - *host_model_features_update_data->update_time(); - - entry_keys_.reset(); - - // This will remove the host model features metadata entry and insert all the - // entries currently in |host_model_features_update_data|. - database_->UpdateEntriesWithRemoveFilter( - host_model_features_update_data->TakeUpdateEntries(), - base::BindRepeating( - &DatabasePrefixFilter, - GetMetadataTypeEntryKey(MetadataType::kHostModelFeatures)), - base::BindOnce(&OptimizationGuideStore::OnUpdateStore, - weak_ptr_factory_.GetWeakPtr(), std::move(callback))); -} - -bool OptimizationGuideStore::FindHostModelFeaturesEntryKey( - const std::string& host, - OptimizationGuideStore::EntryKey* out_host_model_features_entry_key) const { - DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - - if (!entry_keys_) - return false; - - return FindEntryKeyForHostWithPrefix(host, out_host_model_features_entry_key, - GetHostModelFeaturesEntryKeyPrefix()); -} - -void OptimizationGuideStore::LoadAllHostModelFeatures( - AllHostModelFeaturesLoadedCallback callback) { - DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - - if (!IsAvailable()) { - std::move(callback).Run({}); - return; - } - - // Load all the host model features within the store. - database_->LoadEntriesWithFilter( - base::BindRepeating(&DatabasePrefixFilter, - GetHostModelFeaturesEntryKeyPrefix()), - base::BindOnce(&OptimizationGuideStore::OnLoadAllHostModelFeatures, - weak_ptr_factory_.GetWeakPtr(), std::move(callback))); -} - -void OptimizationGuideStore::LoadHostModelFeatures( - const EntryKey& host_model_features_entry_key, - HostModelFeaturesLoadedCallback callback) { - DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - - if (!IsAvailable()) { - std::move(callback).Run({}); - return; - } - - // Load all the host model features within the store. - database_->GetEntry( - host_model_features_entry_key, - base::BindOnce(&OptimizationGuideStore::OnLoadHostModelFeatures, - weak_ptr_factory_.GetWeakPtr(), std::move(callback))); -} - -void OptimizationGuideStore::OnLoadHostModelFeatures( - HostModelFeaturesLoadedCallback callback, - bool success, - std::unique_ptr<proto::StoreEntry> entry) { - DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - - // If either the request failed or the store was set to unavailable after the - // request was started, then the loaded host model features should not - // be considered valid. Reset the entry so that nothing is returned to the - // requester. - if (!success || !IsAvailable()) { - entry.reset(); - } - if (!entry || !entry->has_host_model_features()) { - std::move(callback).Run(nullptr); - return; - } - - std::unique_ptr<proto::HostModelFeatures> loaded_host_model_features( - entry->release_host_model_features()); - std::move(callback).Run(std::move(loaded_host_model_features)); -} - -void OptimizationGuideStore::OnLoadAllHostModelFeatures( - AllHostModelFeaturesLoadedCallback callback, - bool success, - std::unique_ptr<std::vector<proto::StoreEntry>> entries) { - DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - - // If either the request failed or the store was set to unavailable after the - // request was started, then the loaded host model features should not - // be considered valid. Reset the entry so that nothing is returned to the - // requester. - if (!success || !IsAvailable()) { - entries.reset(); - } - - if (!entries || entries->size() == 0) { - std::unique_ptr<std::vector<proto::HostModelFeatures>> - loaded_host_model_features(nullptr); - std::move(callback).Run(std::move(loaded_host_model_features)); - return; - } - - std::unique_ptr<std::vector<proto::HostModelFeatures>> - loaded_host_model_features = - std::make_unique<std::vector<proto::HostModelFeatures>>(); - for (auto& entry : *entries.get()) { - if (!entry.has_host_model_features()) - continue; - loaded_host_model_features->emplace_back(entry.host_model_features()); - } - std::move(callback).Run(std::move(loaded_host_model_features)); -} - -void OptimizationGuideStore::ClearHostModelFeaturesFromDatabase() { - DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - - base::UmaHistogramBoolean( - "OptimizationGuide.ClearHostModelFeatures.StoreAvailable", IsAvailable()); - if (!IsAvailable()) - return; - - auto entries_to_save = std::make_unique<EntryVector>(); - - entry_keys_.reset(); - - // Removes all |kHostModelFeatures| store entries. OnUpdateStore will handle - // updating status and re-filling entry_keys with the entries still in the - // store. - database_->UpdateEntriesWithRemoveFilter( - std::move(entries_to_save), // this should be empty. - base::BindRepeating(&DatabasePrefixFilter, - GetHostModelFeaturesEntryKeyPrefix()), - base::BindOnce(&OptimizationGuideStore::OnUpdateStore, - weak_ptr_factory_.GetWeakPtr(), base::DoNothing())); -} - } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/optimization_guide_store.h b/chromium/components/optimization_guide/core/optimization_guide_store.h index 2b7a51f85af..38b70f16814 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_store.h +++ b/chromium/components/optimization_guide/core/optimization_guide_store.h @@ -40,10 +40,6 @@ class OptimizationGuideStore { base::OnceCallback<void(const std::string&, std::unique_ptr<MemoryHint>)>; using PredictionModelLoadedCallback = base::OnceCallback<void(std::unique_ptr<proto::PredictionModel>)>; - using HostModelFeaturesLoadedCallback = - base::OnceCallback<void(std::unique_ptr<proto::HostModelFeatures>)>; - using AllHostModelFeaturesLoadedCallback = base::OnceCallback<void( - std::unique_ptr<std::vector<proto::HostModelFeatures>>)>; using EntryKey = std::string; using StoreEntryProtoDatabase = leveldb_proto::ProtoDatabase<proto::StoreEntry>; @@ -80,16 +76,14 @@ class OptimizationGuideStore { // cannot be changed, but new types can be added to the end. // StoreEntryType should remain synchronized with the // HintCacheStoreEntryType in enums.xml. - // Also ensure to add to the OptimizationGuide.StoreEntryTypes histogram - // suffixes if adding a new one. enum class StoreEntryType { kEmpty = 0, kMetadata = 1, kComponentHint = 2, kFetchedHint = 3, kPredictionModel = 4, - kHostModelFeatures = 5, - kMaxValue = kHostModelFeatures, + kDeprecatedHostModelFeatures = 5, // deprecated. + kMaxValue = kDeprecatedHostModelFeatures, }; OptimizationGuideStore( @@ -172,13 +166,9 @@ class OptimizationGuideStore { // removed. void PurgeExpiredFetchedHints(); - // Removes all host model features that have expired from the store. - // |entry_keys_| is updated after the expired host model features are - // removed. - void PurgeExpiredHostModelFeatures(); - // Removes all models that have not been loaded in the max inactive duration // configured. |entry_keys| is updated after the inactive models are removed. + // Respects models' |keep_beyond_valid_duration| setting. void PurgeInactiveModels(); // Creates and returns a StoreUpdateData object for Prediction Models. This @@ -196,9 +186,9 @@ class OptimizationGuideStore { std::unique_ptr<StoreUpdateData> prediction_models_update_data, base::OnceClosure callback); - // Finds the entry key for the prediction model if it is known to the store. - // Returns true if an entry key is found and |out_prediction_model_entry_key| - // is populated with the matching key. + // Finds the entry key for the prediction model if it is still valid in the + // store. Returns true if an entry key is valid and + // |out_prediction_model_entry_key| is populated with any matching key. // Virtualized for testing. virtual bool FindPredictionModelEntryKey( proto::OptimizationTarget optimization_target, @@ -218,60 +208,11 @@ class OptimizationGuideStore { // false otherwise. bool RemovePredictionModelFromEntryKey(const EntryKey& entry_key); - // Creates and returns a StoreUpdateData object for host model features. This - // object is used to collect a batch of host model features in a format that - // is usable to update the store on a background thread. This is always - // created when host model features have been successfully fetched from the - // remote Optimization Guide Service so the store can update old host model - // features. - std::unique_ptr<StoreUpdateData> CreateUpdateDataForHostModelFeatures( - base::Time host_model_features_update_time, - base::Time expiry_time) const; - - // Updates the host model features contained in the store. The callback is run - // asynchronously after the database stores the host model features. - // Virtualized for testing. - virtual void UpdateHostModelFeatures( - std::unique_ptr<StoreUpdateData> host_model_features_update_data, - base::OnceClosure callback); - - // Finds the entry key for the host model features for |host| if it is known - // to the store. Returns true if an entry key is found and - // |out_host_model_features_entry_key| is populated with the matching key. - bool FindHostModelFeaturesEntryKey( - const std::string& host, - OptimizationGuideStore::EntryKey* out_host_model_features_entry_key) - const; - - // Loads the host model features specified by |host_model_features_entry_key|. - // After the load finishes, the host model features data is passed to - // |callback|. In the case where the host model features cannot be loaded, the - // callback is run with a nullptr. Depending on the load result, the callback - // may be synchronous or asynchronous. - void LoadHostModelFeatures(const EntryKey& host_model_features_entry_key, - HostModelFeaturesLoadedCallback callback); - - // Loads all the host model features known to the store. After the load - // finishes, the host model features data is passed back to |callback|. In the - // case where the host model features cannot be loaded, the callback is run - // with a nullptr. Depending on the load result, the callback may be - // synchronous or asynchronous. - // Virtualized for testing. - virtual void LoadAllHostModelFeatures( - AllHostModelFeaturesLoadedCallback callback); - - // Returns the time that the host model features in the store can be updated. - // If |this| is not available, base::Time() is returned. - base::Time GetHostModelFeaturesUpdateTime() const; - // Removes fetched hints whose keys are in |hint_keys| and runs |on_success| // if successful, otherwise the callback is not run. void RemoveFetchedHintsByKey(base::OnceClosure on_success, const base::flat_set<std::string>& hint_keys); - // Clears all host model features from the database and resets the entry keys. - void ClearHostModelFeaturesFromDatabase(); - // Returns true if the current status is Status::kAvailable. bool IsAvailable() const; @@ -304,7 +245,7 @@ class OptimizationGuideStore { kSchema = 1, kComponent = 2, kFetched = 3, - kHostModelFeatures = 4, + kDeprecatedHostModelFeatures = 4, // deprecated. }; // Current schema version of the hint cache store. When this is changed, @@ -331,9 +272,6 @@ class OptimizationGuideStore { // Returns prefix of the key of every prediction model entry: "4_". static EntryKeyPrefix GetPredictionModelEntryKeyPrefix(); - // Returns prefix of the key of every host model features entry: "5_". - static EntryKeyPrefix GetHostModelFeaturesEntryKeyPrefix(); - // Returns the OptimizationTarget from |prediction_model_entry_key|. static proto::OptimizationTarget GetOptimizationTargetFromPredictionModelEntryKey( @@ -466,29 +404,6 @@ class OptimizationGuideStore { PredictionModelLoadedCallback callback, bool success); - // Callback that runs after a host model features entry is loaded from the - // database. If there's currently an in-flight update, then the data could be - // invalidated, so loaded host model features data is discarded. Otherwise, - // the host model features are released into the callback, allowing the caller - // to own the host model features without copying it. Regardless of the - // success or failure of retrieving the key, the callback always runs (it - // simply runs with a nullptr on failure). - void OnLoadHostModelFeatures(HostModelFeaturesLoadedCallback callback, - bool success, - std::unique_ptr<proto::StoreEntry> entry); - - // Callback that runs after all the host model features entries are loaded - // from the database. If there's currently an in-flight update, then the data - // could be invalidated, so loaded host model features data is discarded. - // Otherwise, the host model features are released into the callback, allowing - // the caller to own the host model features without copying it. Regardless of - // the success or failure of retrieving the key, the callback always runs (it - // simply runs with a nullptr on failure). - void OnLoadAllHostModelFeatures( - AllHostModelFeaturesLoadedCallback callback, - bool success, - std::unique_ptr<std::vector<proto::StoreEntry>> entry); - // Proto database used by the store. std::unique_ptr<StoreEntryProtoDatabase> database_; diff --git a/chromium/components/optimization_guide/core/optimization_guide_store_unittest.cc b/chromium/components/optimization_guide/core/optimization_guide_store_unittest.cc index bf60d8a865d..6baa370ec39 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_store_unittest.cc +++ b/chromium/components/optimization_guide/core/optimization_guide_store_unittest.cc @@ -15,6 +15,7 @@ #include "base/test/metrics/histogram_tester.h" #include "base/test/task_environment.h" #include "components/leveldb_proto/testing/fake_db.h" +#include "components/optimization_guide/core/model_util.h" #include "components/optimization_guide/core/optimization_guide_features.h" #include "components/optimization_guide/core/optimization_guide_util.h" #include "components/optimization_guide/core/store_update_data.h" @@ -60,8 +61,8 @@ std::unique_ptr<proto::PredictionModel> CreatePredictionModel() { model_info->set_version(1); model_info->set_optimization_target( proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); + model_info->add_supported_model_engine_versions( + proto::ModelEngineVersion::MODEL_ENGINE_VERSION_DECISION_TREE); return prediction_model; } @@ -87,8 +88,6 @@ class OptimizationGuideStoreTest : public testing::Test { MetadataSchemaState state, absl::optional<size_t> component_hint_count = absl::optional<size_t>(), absl::optional<base::Time> fetched_hints_update = - absl::optional<base::Time>(), - absl::optional<base::Time> host_model_features_update = absl::optional<base::Time>()) { db_store_.clear(); @@ -135,13 +134,6 @@ class OptimizationGuideStoreTest : public testing::Test { .set_update_time_secs( fetched_hints_update->ToDeltaSinceWindowsEpoch().InSeconds()); } - if (host_model_features_update) { - db_store_[OptimizationGuideStore::GetMetadataTypeEntryKey( - OptimizationGuideStore::MetadataType::kHostModelFeatures)] - .set_update_time_secs( - host_model_features_update->ToDeltaSinceWindowsEpoch() - .InSeconds()); - } } // Moves the specified number of component hints into the update data. @@ -176,38 +168,27 @@ class OptimizationGuideStoreTest : public testing::Test { StoreUpdateData* update_data, optimization_guide::proto::OptimizationTarget optimization_target, absl::optional<base::FilePath> model_file_path = absl::nullopt, - base::flat_set<base::FilePath> additional_file_paths = {}) { + absl::optional<proto::ModelInfo> info = absl::nullopt, + absl::optional<base::Time> expiry_time = absl::nullopt) { std::unique_ptr<optimization_guide::proto::PredictionModel> prediction_model = CreatePredictionModel(); + if (info) + prediction_model->mutable_model_info()->MergeFrom(*info); prediction_model->mutable_model_info()->set_optimization_target( optimization_target); + if (expiry_time) { + auto diff = expiry_time.value() - base::Time::Now(); + prediction_model->mutable_model_info() + ->mutable_valid_duration() + ->set_seconds(diff.InSeconds()); + } if (model_file_path) { prediction_model->mutable_model()->set_download_url( FilePathToString(*model_file_path)); } - for (const base::FilePath& additional_file : additional_file_paths) { - prediction_model->mutable_model_info() - ->add_additional_files() - ->set_file_path(FilePathToString(additional_file)); - } update_data->CopyPredictionModelIntoUpdateData(*prediction_model); } - // Moves |host_model_features_count| into |update_data|. - void SeedHostModelFeaturesUpdateData(StoreUpdateData* update_data, - size_t host_model_features_count) { - for (size_t i = 0; i < host_model_features_count; i++) { - std::string host = GetHost(i); - proto::HostModelFeatures host_model_features; - proto::ModelFeature* model_feature = - host_model_features.add_model_features(); - model_feature->set_feature_name("host_feat1"); - model_feature->set_double_value(2.0); - host_model_features.set_host(host); - update_data->CopyHostModelFeaturesIntoUpdateData(host_model_features); - } - } - void CreateDatabase() { // Reset everything. db_ = nullptr; @@ -304,35 +285,12 @@ class OptimizationGuideStoreTest : public testing::Test { } } - void UpdateHostModelFeatures( - std::unique_ptr<StoreUpdateData> host_model_features_data, - bool update_success = true, - bool load_host_model_features_entry_keys_success = true) { - EXPECT_CALL(*this, OnUpdateStore()); - guide_store()->UpdateHostModelFeatures( - std::move(host_model_features_data), - base::BindOnce(&OptimizationGuideStoreTest::OnUpdateStore, - base::Unretained(this))); - // OnUpdateStore callback - db()->UpdateCallback(update_success); - if (update_success) { - // OnLoadEntryKeys callback - db()->LoadCallback(load_host_model_features_entry_keys_success); - } - } - void ClearFetchedHintsFromDatabase() { guide_store()->ClearFetchedHintsFromDatabase(); db()->UpdateCallback(true); db()->LoadCallback(true); } - void ClearHostModelFeaturesFromDatabase() { - guide_store()->ClearHostModelFeaturesFromDatabase(); - db()->UpdateCallback(true); - db()->LoadCallback(true); - } - void PurgeExpiredFetchedHints() { guide_store()->PurgeExpiredFetchedHints(); @@ -344,17 +302,6 @@ class OptimizationGuideStoreTest : public testing::Test { db()->LoadCallback(true); } - void PurgeExpiredHostModelFeatures() { - guide_store()->PurgeExpiredHostModelFeatures(); - - // OnLoadExpiredEntriesToPurge - db()->LoadCallback(true); - // OnUpdateStore - db()->UpdateCallback(true); - // OnLoadEntryKeys callback - db()->LoadCallback(true); - } - void PurgeInactiveModels() { guide_store()->PurgeInactiveModels(); @@ -388,23 +335,6 @@ class OptimizationGuideStoreTest : public testing::Test { } } - // Verifies that the host model features metadata has the expected next update - // time. - void ExpectHostModelFeaturesMetadata(base::Time update_time) const { - const auto& metadata_entry = - db_store_.find(OptimizationGuideStore::GetMetadataTypeEntryKey( - OptimizationGuideStore::MetadataType::kHostModelFeatures)); - if (metadata_entry != db_store_.end()) { - // The next update time should have same time up to the second as the - // metadata entry is stored in seconds. - EXPECT_TRUE(base::Time::FromDeltaSinceWindowsEpoch(base::Seconds( - metadata_entry->second.update_time_secs())) - - update_time < - base::Seconds(1)); - } else { - FAIL() << "No host model features metadata found"; - } - } // Verifies that the component metadata has the expected version and all // expected component hints are present. void ExpectComponentHintsPresent(const std::string& version, @@ -458,14 +388,6 @@ class OptimizationGuideStoreTest : public testing::Test { MemoryHint* last_loaded_hint() { return last_loaded_hint_.get(); } - proto::HostModelFeatures* last_loaded_host_model_features() { - return last_loaded_host_model_features_.get(); - } - - std::vector<proto::HostModelFeatures>* last_loaded_all_host_model_features() { - return last_loaded_all_host_model_features_.get(); - } - proto::PredictionModel* last_loaded_prediction_model() { return last_loaded_prediction_model_.get(); } @@ -476,18 +398,6 @@ class OptimizationGuideStoreTest : public testing::Test { last_loaded_hint_ = std::move(loaded_hint); } - void OnHostModelFeaturesLoaded( - std::unique_ptr<proto::HostModelFeatures> loaded_host_model_features) { - last_loaded_host_model_features_ = std::move(loaded_host_model_features); - } - - void OnAllHostModelFeaturesLoaded( - std::unique_ptr<std::vector<proto::HostModelFeatures>> - loaded_all_host_model_features) { - last_loaded_all_host_model_features_ = - std::move(loaded_all_host_model_features); - } - void OnPredictionModelLoaded( std::unique_ptr<proto::PredictionModel> loaded_prediction_model) { last_loaded_prediction_model_ = std::move(loaded_prediction_model); @@ -507,9 +417,6 @@ class OptimizationGuideStoreTest : public testing::Test { OptimizationGuideStore::EntryKey last_loaded_entry_key_; std::unique_ptr<MemoryHint> last_loaded_hint_; - std::unique_ptr<proto::HostModelFeatures> last_loaded_host_model_features_; - std::unique_ptr<std::vector<proto::HostModelFeatures>> - last_loaded_all_host_model_features_; std::unique_ptr<proto::PredictionModel> last_loaded_prediction_model_; }; @@ -737,10 +644,6 @@ TEST_F(OptimizationGuideStoreTest, histogram_tester.ExpectBucketCount( "OptimizationGuide.HintCacheLevelDBStore.LoadMetadataResult", 0 /* kSuccess */, 1); - histogram_tester.ExpectBucketCount( - "OptimizationGuide.PredictionModelStore." - "HostModelFeaturesLoadMetadataResult", - false, 1); histogram_tester.ExpectBucketCount( "OptimizationGuide.HintCacheLevelDBStore.Status", 0 /* kUninitialized */, @@ -839,11 +742,6 @@ TEST_F(OptimizationGuideStoreTest, InitializeSucceededWithValidSchemaEntry) { 6 /* kComponentAndFetchedMetadataMissing*/, 1); histogram_tester.ExpectBucketCount( - "OptimizationGuide.PredictionModelStore." - "HostModelFeaturesLoadMetadataResult", - false, 1); - - histogram_tester.ExpectBucketCount( "OptimizationGuide.HintCacheLevelDBStore.Status", 0 /* kUninitialized */, 1); histogram_tester.ExpectBucketCount( @@ -901,9 +799,6 @@ TEST_F(OptimizationGuideStoreTest, InitializeSucceededWithPurgeExistingData) { EXPECT_TRUE(IsMetadataSchemaEntryKeyPresent()); - histogram_tester.ExpectTotalCount( - "OptimizationGuide.HintCacheLevelDBStore.LoadMetadataResult", 0); - histogram_tester.ExpectBucketCount( "OptimizationGuide.HintCacheLevelDBStore.Status", 0 /* kUninitialized */, 1); @@ -923,15 +818,14 @@ TEST_F(OptimizationGuideStoreTest, MetadataSchemaState schema_state = MetadataSchemaState::kValid; size_t component_hint_count = 10; SeedInitialData(schema_state, component_hint_count, - base::Time().Now(), /* fetch_update_time */ - base::Time().Now() /* host_model_features_update_time */); + base::Time().Now() /* fetch_update_time */); CreateDatabase(); InitializeStore(schema_state); // The store should contain the schema metadata entry, the component metadata // entry, and all of the initial component hints. EXPECT_EQ(GetDBStoreEntryCount(), - static_cast<size_t>(component_hint_count + 4)); + static_cast<size_t>(component_hint_count + 3)); EXPECT_EQ(GetStoreEntryKeyCount(), component_hint_count); EXPECT_TRUE(IsMetadataSchemaEntryKeyPresent()); @@ -942,11 +836,6 @@ TEST_F(OptimizationGuideStoreTest, 0 /* kSuccess */, 1); histogram_tester.ExpectBucketCount( - "OptimizationGuide.PredictionModelStore." - "HostModelFeaturesLoadMetadataResult", - true, 1); - - histogram_tester.ExpectBucketCount( "OptimizationGuide.HintCacheLevelDBStore.Status", 0 /* kUninitialized */, 1); histogram_tester.ExpectBucketCount( @@ -988,11 +877,6 @@ TEST_F(OptimizationGuideStoreTest, 6 /* kComponentAndFetchedMetadataMissing*/, 0); histogram_tester.ExpectBucketCount( - "OptimizationGuide.PredictionModelStore." - "HostModelFeaturesLoadMetadataResult", - false, 1); - - histogram_tester.ExpectBucketCount( "OptimizationGuide.HintCacheLevelDBStore.Status", 0 /* kUninitialized */, 1); histogram_tester.ExpectBucketCount( @@ -1027,11 +911,6 @@ TEST_F(OptimizationGuideStoreTest, 4 /* kComponentMetadataMissing*/, 1); histogram_tester.ExpectBucketCount( - "OptimizationGuide.PredictionModelStore." - "HostModelFeaturesLoadMetadataResult", - false, 1); - - histogram_tester.ExpectBucketCount( "OptimizationGuide.HintCacheLevelDBStore.Status", 0 /* kUninitialized */, 1); histogram_tester.ExpectBucketCount( @@ -1945,7 +1824,7 @@ TEST_F(OptimizationGuideStoreTest, FindPredictionModelEntryKey) { std::unique_ptr<StoreUpdateData> update_data = guide_store()->CreateUpdateDataForPredictionModels( update_time + - optimization_guide::features::StoredModelsInactiveDuration()); + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data); SeedPredictionModelUpdateData(update_data.get(), proto::OPTIMIZATION_TARGET_UNKNOWN); @@ -1971,7 +1850,7 @@ TEST_F(OptimizationGuideStoreTest, std::unique_ptr<StoreUpdateData> update_data = guide_store()->CreateUpdateDataForPredictionModels( update_time + - optimization_guide::features::StoredModelsInactiveDuration()); + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data); SeedPredictionModelUpdateData(update_data.get(), proto::OPTIMIZATION_TARGET_UNKNOWN); @@ -1994,7 +1873,7 @@ TEST_F(OptimizationGuideStoreTest, FindAndRemovePredictionModelEntryKey) { std::unique_ptr<StoreUpdateData> update_data = guide_store()->CreateUpdateDataForPredictionModels( update_time + - optimization_guide::features::StoredModelsInactiveDuration()); + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data); SeedPredictionModelUpdateData(update_data.get(), proto::OPTIMIZATION_TARGET_UNKNOWN); @@ -2034,7 +1913,7 @@ TEST_F(OptimizationGuideStoreTest, LoadPredictionModel) { std::unique_ptr<StoreUpdateData> update_data = guide_store()->CreateUpdateDataForPredictionModels( update_time + - optimization_guide::features::StoredModelsInactiveDuration()); + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data); SeedPredictionModelUpdateData(update_data.get(), proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); @@ -2076,6 +1955,46 @@ TEST_F(OptimizationGuideStoreTest, LoadPredictionModelOnUnavailableStore) { EXPECT_FALSE(last_loaded_prediction_model()); } +TEST_F(OptimizationGuideStoreTest, LoadPredictionModelOnExpiredModel) { + base::HistogramTester histogram_tester; + size_t initial_hint_count = 10; + MetadataSchemaState schema_state = MetadataSchemaState::kValid; + SeedInitialData(schema_state, initial_hint_count); + CreateDatabase(); + InitializeStore(schema_state); + + // Add an update with models that are "inactive". + base::Time update_time = base::Time().Now(); + std::unique_ptr<StoreUpdateData> update_data = + guide_store()->CreateUpdateDataForPredictionModels( + update_time - + optimization_guide::features::StoredModelsValidDuration()); + ASSERT_TRUE(update_data); + SeedPredictionModelUpdateData( + update_data.get(), proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, + /*model_file_path=*/absl::nullopt, + /*info=*/{}, + update_time - optimization_guide::features::StoredModelsValidDuration()); + UpdatePredictionModels(std::move(update_data)); + // Since models haven't been purged yet it will initially show up as valid. + OptimizationGuideStore::EntryKey entry_key; + EXPECT_TRUE(guide_store()->FindPredictionModelEntryKey( + proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &entry_key)); + guide_store()->LoadPredictionModel( + entry_key, + base::BindOnce(&OptimizationGuideStoreTest::OnPredictionModelLoaded, + base::Unretained(this))); + // OnPredictionModelLoaded callback + db()->GetCallback(true); + // After load completes, the key will still exist until after purge. + EXPECT_TRUE(guide_store()->FindPredictionModelEntryKey( + proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &entry_key)); + + // Verify that the OnPredictionModelLoaded callback runs when the store is + // unavailable and that the prediction model was correctly not available. + EXPECT_FALSE(last_loaded_prediction_model()); +} + TEST_F(OptimizationGuideStoreTest, LoadPredictionModelWithUpdateInFlight) { base::HistogramTester histogram_tester; MetadataSchemaState schema_state = MetadataSchemaState::kValid; @@ -2087,7 +2006,7 @@ TEST_F(OptimizationGuideStoreTest, LoadPredictionModelWithUpdateInFlight) { std::unique_ptr<StoreUpdateData> update_data = guide_store()->CreateUpdateDataForPredictionModels( update_time + - optimization_guide::features::StoredModelsInactiveDuration()); + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data); SeedPredictionModelUpdateData(update_data.get(), proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); @@ -2117,7 +2036,7 @@ TEST_F(OptimizationGuideStoreTest, std::unique_ptr<StoreUpdateData> update_data = guide_store()->CreateUpdateDataForPredictionModels( update_time + - optimization_guide::features::StoredModelsInactiveDuration()); + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data); SeedPredictionModelUpdateData(update_data.get(), proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, @@ -2164,7 +2083,7 @@ TEST_F(OptimizationGuideStoreTest, std::unique_ptr<StoreUpdateData> update_data = guide_store()->CreateUpdateDataForPredictionModels( update_time + - optimization_guide::features::StoredModelsInactiveDuration()); + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data); base::FilePath model_file_path = temp_dir().AppendASCII("model.tflite"); @@ -2173,9 +2092,12 @@ TEST_F(OptimizationGuideStoreTest, base::FilePath additional_file_path = temp_dir().AppendASCII("doesntexist"); - SeedPredictionModelUpdateData( - update_data.get(), proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, - temp_dir().AppendASCII("doesntexist"), {additional_file_path}); + proto::ModelInfo info; + info.add_additional_files()->set_file_path( + FilePathToString(additional_file_path)); + SeedPredictionModelUpdateData(update_data.get(), + proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, + temp_dir().AppendASCII("doesntexist"), info); UpdatePredictionModels(std::move(update_data)); OptimizationGuideStore::EntryKey entry_key; @@ -2221,7 +2143,7 @@ TEST_F(OptimizationGuideStoreTest, std::unique_ptr<StoreUpdateData> update_data = guide_store()->CreateUpdateDataForPredictionModels( update_time + - optimization_guide::features::StoredModelsInactiveDuration()); + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data); base::FilePath model_file_path = temp_dir().AppendASCII("model.tflite"); @@ -2232,10 +2154,12 @@ TEST_F(OptimizationGuideStoreTest, temp_dir().AppendASCII("additional_file.txt"); ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(additional_file_path, "ah!", 3)); - + proto::ModelInfo info; + info.add_additional_files()->set_file_path( + FilePathToString(additional_file_path)); SeedPredictionModelUpdateData(update_data.get(), proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, - model_file_path, {additional_file_path}); + model_file_path, info); UpdatePredictionModels(std::move(update_data)); OptimizationGuideStore::EntryKey entry_key; @@ -2277,7 +2201,7 @@ TEST_F(OptimizationGuideStoreTest, std::unique_ptr<StoreUpdateData> update_data = guide_store()->CreateUpdateDataForPredictionModels( update_time + - optimization_guide::features::StoredModelsInactiveDuration()); + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data); base::FilePath file_path = temp_dir().AppendASCII("file"); ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(file_path, "boo", 3)); @@ -2324,7 +2248,7 @@ TEST_F(OptimizationGuideStoreTest, UpdatePredictionModelsDeletesOldFile) { std::unique_ptr<StoreUpdateData> update_data = guide_store()->CreateUpdateDataForPredictionModels( update_time + - optimization_guide::features::StoredModelsInactiveDuration()); + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data); base::FilePath old_file_path = old_dir.AppendASCII("model.tflite"); ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(old_file_path, "boo", 3)); @@ -2341,7 +2265,7 @@ TEST_F(OptimizationGuideStoreTest, UpdatePredictionModelsDeletesOldFile) { std::unique_ptr<StoreUpdateData> update_data2 = guide_store()->CreateUpdateDataForPredictionModels( update_time + - optimization_guide::features::StoredModelsInactiveDuration()); + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data2); base::FilePath new_file_path = new_dir.AppendASCII("model.tflite"); ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(new_file_path, "boo", 3)); @@ -2380,7 +2304,7 @@ TEST_F(OptimizationGuideStoreTest, std::unique_ptr<StoreUpdateData> update_data = guide_store()->CreateUpdateDataForPredictionModels( update_time + - optimization_guide::features::StoredModelsInactiveDuration()); + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data); base::FilePath old_file_path = old_dir.AppendASCII("model_v1.tflite"); ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(old_file_path, "boo", 3)); @@ -2397,7 +2321,7 @@ TEST_F(OptimizationGuideStoreTest, std::unique_ptr<StoreUpdateData> update_data2 = guide_store()->CreateUpdateDataForPredictionModels( update_time + - optimization_guide::features::StoredModelsInactiveDuration()); + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data2); base::FilePath new_file_path = new_dir.Append(GetBaseFileNameForModels()); ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(new_file_path, "boo", 3)); @@ -2429,7 +2353,7 @@ TEST_F(OptimizationGuideStoreTest, RemovePredictionModelEntryKeyDeletesFile) { std::unique_ptr<StoreUpdateData> update_data = guide_store()->CreateUpdateDataForPredictionModels( update_time + - optimization_guide::features::StoredModelsInactiveDuration()); + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data); base::FilePath file_path = temp_dir().AppendASCII("file"); ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(file_path, "boo", 3)); @@ -2460,203 +2384,79 @@ TEST_F(OptimizationGuideStoreTest, RemovePredictionModelEntryKeyDeletesFile) { EXPECT_FALSE(base::PathExists(file_path)); } -TEST_F(OptimizationGuideStoreTest, HostModelFeaturesMetadataStored) { - MetadataSchemaState schema_state = MetadataSchemaState::kValid; - base::Time update_time = base::Time().Now(); - SeedInitialData(schema_state, 10, update_time, - base::Time().Now() /* host_model_features_update_time */); - CreateDatabase(); - InitializeStore(schema_state); - - ExpectHostModelFeaturesMetadata(update_time); -} - -TEST_F(OptimizationGuideStoreTest, FindEntryKeyForHostModelFeatures) { - MetadataSchemaState schema_state = MetadataSchemaState::kValid; - size_t update_host_model_features_count = 5; - base::Time update_time = base::Time().Now(); - SeedInitialData(schema_state, 0, - base::Time().Now() /* host_model_features_update_time */); - CreateDatabase(); - InitializeStore(schema_state); - - std::unique_ptr<StoreUpdateData> update_data = - guide_store()->CreateUpdateDataForHostModelFeatures( - update_time, update_time + - optimization_guide::features:: - StoredHostModelFeaturesFreshnessDuration()); - ASSERT_TRUE(update_data); - SeedHostModelFeaturesUpdateData(update_data.get(), - update_host_model_features_count); - UpdateHostModelFeatures(std::move(update_data)); - - for (size_t i = 0; i < update_host_model_features_count; ++i) { - std::string host = GetHost(i); - OptimizationGuideStore::EntryKey entry_key; - bool success = - guide_store()->FindHostModelFeaturesEntryKey(host, &entry_key); - EXPECT_EQ(success, i < update_host_model_features_count); - } -} - -TEST_F(OptimizationGuideStoreTest, LoadHostModelFeaturesForHost) { +TEST_F(OptimizationGuideStoreTest, PurgeInactiveModels) { base::HistogramTester histogram_tester; - size_t update_host_model_features_count = 5; - MetadataSchemaState schema_state = MetadataSchemaState::kValid; - base::Time update_time = base::Time().Now(); - SeedInitialData(schema_state, 0, base::Time().Now()); - CreateDatabase(); - InitializeStore(schema_state); - - std::unique_ptr<StoreUpdateData> update_data = - guide_store()->CreateUpdateDataForHostModelFeatures( - update_time, update_time + - optimization_guide::features:: - StoredHostModelFeaturesFreshnessDuration()); - ASSERT_TRUE(update_data); - SeedHostModelFeaturesUpdateData(update_data.get(), - update_host_model_features_count); - UpdateHostModelFeatures(std::move(update_data)); - - for (size_t i = 0; i < update_host_model_features_count; ++i) { - std::string host = GetHost(i); - OptimizationGuideStore::EntryKey entry_key; - bool success = - guide_store()->FindHostModelFeaturesEntryKey(host, &entry_key); - EXPECT_TRUE(success); - - guide_store()->LoadHostModelFeatures( - entry_key, - base::BindOnce(&OptimizationGuideStoreTest::OnHostModelFeaturesLoaded, - base::Unretained(this))); - // OnPredictionModelLoaded callback - db()->GetCallback(true); - - if (!last_loaded_host_model_features()) { - FAIL() << "Loaded host model features was null for entry key: " - << entry_key; - } - - EXPECT_EQ(last_loaded_host_model_features()->host(), host); - } -} - -TEST_F(OptimizationGuideStoreTest, LoadAllHostModelFeatures) { - base::HistogramTester histogram_tester; - size_t update_host_model_features_count = 5; MetadataSchemaState schema_state = MetadataSchemaState::kValid; - base::Time update_time = base::Time().Now(); - SeedInitialData(schema_state, 0, base::Time().Now()); + SeedInitialData(schema_state, 0); CreateDatabase(); InitializeStore(schema_state); - std::unique_ptr<StoreUpdateData> update_data = - guide_store()->CreateUpdateDataForHostModelFeatures( - update_time, update_time + - optimization_guide::features:: - StoredHostModelFeaturesFreshnessDuration()); - ASSERT_TRUE(update_data); - SeedHostModelFeaturesUpdateData(update_data.get(), - update_host_model_features_count); - UpdateHostModelFeatures(std::move(update_data)); - guide_store()->LoadAllHostModelFeatures( - base::BindOnce(&OptimizationGuideStoreTest::OnAllHostModelFeaturesLoaded, - base::Unretained(this))); - - // OnAllHostModelFeaturesLoaded callback - db()->LoadCallback(true); - - std::vector<proto::HostModelFeatures>* all_host_model_features = - last_loaded_all_host_model_features(); - EXPECT_TRUE(all_host_model_features); - EXPECT_EQ(update_host_model_features_count, all_host_model_features->size()); - - // Build a list of the hosts that are stored in the store. - base::flat_set<std::string> hosts = {}; - for (size_t i = 0; i < update_host_model_features_count; i++) - hosts.insert(GetHost(i)); - - // Make sure all of the hosts of the host model features are returned. - for (const auto& host_model_features : *all_host_model_features) - EXPECT_NE(hosts.find(host_model_features.host()), hosts.end()); -} - -TEST_F(OptimizationGuideStoreTest, ClearHostModelFeatures) { - base::HistogramTester histogram_tester; - size_t update_host_model_features_count = 5; - MetadataSchemaState schema_state = MetadataSchemaState::kValid; + // Add an update with models that are "inactive". base::Time update_time = base::Time().Now(); - SeedInitialData(schema_state, 0, base::Time().Now()); - CreateDatabase(); - InitializeStore(schema_state); - std::unique_ptr<StoreUpdateData> update_data = - guide_store()->CreateUpdateDataForHostModelFeatures( - update_time, update_time + - optimization_guide::features:: - StoredHostModelFeaturesFreshnessDuration()); + guide_store()->CreateUpdateDataForPredictionModels( + update_time - + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data); - SeedHostModelFeaturesUpdateData(update_data.get(), - update_host_model_features_count); - UpdateHostModelFeatures(std::move(update_data)); - - for (size_t i = 0; i < update_host_model_features_count; ++i) { - std::string host = GetHost(i); - OptimizationGuideStore::EntryKey entry_key; - EXPECT_TRUE(guide_store()->FindHostModelFeaturesEntryKey(host, &entry_key)); - } + base::FilePath old_file_path = temp_dir().AppendASCII("model_v1.tflite"); + ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(old_file_path, "boo", 3)); + SeedPredictionModelUpdateData( + update_data.get(), proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, + old_file_path, + /*info=*/{}, + update_time - optimization_guide::features::StoredModelsValidDuration()); + UpdatePredictionModels(std::move(update_data)); - // Remove the host model features from the OptimizationGuideStore. - ClearHostModelFeaturesFromDatabase(); - histogram_tester.ExpectBucketCount( - "OptimizationGuide.ClearHostModelFeatures.StoreAvailable", true, 1); + // Add an update with models that are "active". + std::unique_ptr<StoreUpdateData> update_data2 = + guide_store()->CreateUpdateDataForPredictionModels( + update_time + + optimization_guide::features::StoredModelsValidDuration()); + ASSERT_TRUE(update_data2); + SeedPredictionModelUpdateData(update_data2.get(), + proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION); + UpdatePredictionModels(std::move(update_data2)); - for (size_t i = 0; i < update_host_model_features_count; ++i) { - std::string host = GetHost(i); - OptimizationGuideStore::EntryKey entry_key; - EXPECT_FALSE( - guide_store()->FindHostModelFeaturesEntryKey(host, &entry_key)); - } -} + // Make sure both models are in the store. + OptimizationGuideStore::EntryKey entry_key; + bool success = guide_store()->FindPredictionModelEntryKey( + proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &entry_key); + ASSERT_TRUE(success); + success = guide_store()->FindPredictionModelEntryKey( + proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, &entry_key); + ASSERT_TRUE(success); -TEST_F(OptimizationGuideStoreTest, PurgeExpiredHostModelFeatures) { - base::HistogramTester histogram_tester; - size_t update_host_model_features_count = 5; - MetadataSchemaState schema_state = MetadataSchemaState::kValid; - base::Time update_time = base::Time().Now(); - SeedInitialData(schema_state, 0, base::Time().Now()); - CreateDatabase(); - InitializeStore(schema_state); + PurgeInactiveModels(); + RunUntilIdle(); + // The expired model should be removed. + EXPECT_FALSE(guide_store()->FindPredictionModelEntryKey( + proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &entry_key)); + EXPECT_FALSE(base::PathExists(old_file_path)); - std::unique_ptr<StoreUpdateData> update_data = - guide_store()->CreateUpdateDataForHostModelFeatures( - update_time, update_time - - optimization_guide::features:: - StoredHostModelFeaturesFreshnessDuration()); - ASSERT_TRUE(update_data); - SeedHostModelFeaturesUpdateData(update_data.get(), - update_host_model_features_count); - UpdateHostModelFeatures(std::move(update_data)); + // Should not purge models that are still active. + EXPECT_TRUE(guide_store()->FindPredictionModelEntryKey( + proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, &entry_key)); - for (size_t i = 0; i < update_host_model_features_count; ++i) { - std::string host = GetHost(i); - OptimizationGuideStore::EntryKey entry_key; - EXPECT_TRUE(guide_store()->FindHostModelFeaturesEntryKey(host, &entry_key)); - } + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.PredictionModelExpired.PainfulPageLoad", true, 1); + histogram_tester.ExpectTotalCount( + "OptimizationGuide.PredictionModelExpired.LanguageDetection", 0); +} - // Remove expired host model features from the opt. guide store. - PurgeExpiredHostModelFeatures(); +struct ValidityTestCase { + std::string test_name; + bool keep_beyond_valid_duration; + bool initially_expired; + bool expect_kept; +}; - for (size_t i = 0; i < update_host_model_features_count; ++i) { - std::string host = GetHost(i); - OptimizationGuideStore::EntryKey entry_key; - EXPECT_FALSE( - guide_store()->FindHostModelFeaturesEntryKey(host, &entry_key)); - } -} +class OptimizationGuideStoreValidityTest + : public OptimizationGuideStoreTest, + public ::testing::WithParamInterface<ValidityTestCase> {}; -TEST_F(OptimizationGuideStoreTest, PurgeInactiveModels) { +TEST_P(OptimizationGuideStoreValidityTest, PurgeInactiveModels) { + const ValidityTestCase& test_case = GetParam(); base::HistogramTester histogram_tester; MetadataSchemaState schema_state = MetadataSchemaState::kValid; @@ -2664,22 +2464,30 @@ TEST_F(OptimizationGuideStoreTest, PurgeInactiveModels) { CreateDatabase(); InitializeStore(schema_state); - // Add an update with models that are "inactive". + // Add an update with one model according to ValidityTestCase settings. base::Time update_time = base::Time().Now(); + if (test_case.initially_expired) { + update_time -= optimization_guide::features::StoredModelsValidDuration(); + } else { + update_time += optimization_guide::features::StoredModelsValidDuration(); + } std::unique_ptr<StoreUpdateData> update_data = - guide_store()->CreateUpdateDataForPredictionModels( - update_time - - optimization_guide::features::StoredModelsInactiveDuration()); + guide_store()->CreateUpdateDataForPredictionModels(update_time); ASSERT_TRUE(update_data); + proto::ModelInfo info; + info.set_keep_beyond_valid_duration(test_case.keep_beyond_valid_duration); + base::FilePath old_file_path = temp_dir().AppendASCII("model_v1.tflite"); + ASSERT_EQ(static_cast<int32_t>(3), base::WriteFile(old_file_path, "boo", 3)); SeedPredictionModelUpdateData(update_data.get(), - proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); + proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, + old_file_path, info, update_time); UpdatePredictionModels(std::move(update_data)); - // Add an update with models that are "active". + // Add an update with models that are "active" and should be unaffected. std::unique_ptr<StoreUpdateData> update_data2 = guide_store()->CreateUpdateDataForPredictionModels( - update_time + - optimization_guide::features::StoredModelsInactiveDuration()); + base::Time().Now() + + optimization_guide::features::StoredModelsValidDuration()); ASSERT_TRUE(update_data2); SeedPredictionModelUpdateData(update_data2.get(), proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION); @@ -2690,22 +2498,67 @@ TEST_F(OptimizationGuideStoreTest, PurgeInactiveModels) { bool success = guide_store()->FindPredictionModelEntryKey( proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &entry_key); ASSERT_TRUE(success); + EXPECT_TRUE(base::PathExists(old_file_path)); + success = guide_store()->FindPredictionModelEntryKey( proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, &entry_key); ASSERT_TRUE(success); PurgeInactiveModels(); - - EXPECT_FALSE(guide_store()->FindPredictionModelEntryKey( - proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &entry_key)); - // Should not purge models that are still active. + RunUntilIdle(); + // Verify that the model file, entry key and histogram match expectations for + // PageLoad. + EXPECT_EQ(test_case.expect_kept, + guide_store()->FindPredictionModelEntryKey( + proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &entry_key)); + EXPECT_EQ(test_case.expect_kept, base::PathExists(old_file_path)); + + if (test_case.expect_kept) { + histogram_tester.ExpectTotalCount( + "OptimizationGuide.PredictionModelExpired.PainfulPageLoad", 0); + } else { + histogram_tester.ExpectTotalCount( + "OptimizationGuide.PredictionModelExpired.PainfulPageLoad", 1); + } + // Verify that the other model is not deleted. EXPECT_TRUE(guide_store()->FindPredictionModelEntryKey( proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, &entry_key)); - - histogram_tester.ExpectUniqueSample( - "OptimizationGuide.PredictionModelExpired.PainfulPageLoad", true, 1); histogram_tester.ExpectTotalCount( "OptimizationGuide.PredictionModelExpired.LanguageDetection", 0); } +INSTANTIATE_TEST_SUITE_P( + OptimizationGuideStoreValidityTests, + OptimizationGuideStoreValidityTest, + testing::ValuesIn<ValidityTestCase>({ + { + "KeepDespiteInvalidModel", + /*keep_beyond_valid_duration=*/true, + /*initially_expired=*/true, + /*expect_kept=*/true, + }, + { + "KeepAndInitiallyValid", + /*keep_beyond_valid_duration=*/true, + /*initially_expired=*/false, + /*expect_kept=*/true, + }, + { + "DeleteAndInitiallyValid", + /*keep_beyond_valid_duration=*/false, + /*initially_expired=*/false, + /*expect_kept=*/true, + }, + // Only in this case should the model be removed. + { + "DeleteAndInvalidModel", + /*keep_beyond_valid_duration=*/false, + /*initially_expired=*/true, + /*expect_kept=*/false, + }, + }), + [](const testing::TestParamInfo< + OptimizationGuideStoreValidityTest::ParamType>& info) { + return info.param.test_name; + }); } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/optimization_guide_switches.cc b/chromium/components/optimization_guide/core/optimization_guide_switches.cc index 8c55155bf31..e2026ddc1cf 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_switches.cc +++ b/chromium/components/optimization_guide/core/optimization_guide_switches.cc @@ -27,12 +27,6 @@ const char kHintsProtoOverride[] = "optimization_guide_hints_override"; // hosts. const char kFetchHintsOverride[] = "optimization-guide-fetch-hints-override"; -// Overrides scheduling and time delays for fetching prediction models and host -// model features. This causes a prediction model and host model features fetch -// immediately on start up. -const char kFetchModelsAndHostModelFeaturesOverrideTimer[] = - "optimization-guide-fetch-models-and-features-override"; - // Overrides the hints fetch scheduling and delay, causing a hints fetch // immediately on start up using the TopHostProvider. This is meant for testing. const char kFetchHintsOverrideTimer[] = @@ -87,6 +81,14 @@ const char kModelOverride[] = "optimization-guide-model-override"; // Triggers validation of the model. Used for manual testing. const char kModelValidate[] = "optimization-guide-model-validate"; +// Prevents any models from being executing when in annotating a batch +// of visits. This is used for testing only. +const char kStopHistoryVisitBatchAnnotateForTesting[] = + "stop-history-visit-batch-annotate"; + +const char kPageContentAnnotationsLoggingEnabled[] = + "enable-page-content-annotations-logging"; + bool IsHintComponentProcessingDisabled() { return base::CommandLine::ForCurrentProcess()->HasSwitch(kHintsProtoOverride); } @@ -134,11 +136,6 @@ bool ShouldOverrideFetchHintsTimer() { kFetchHintsOverrideTimer); } -bool ShouldOverrideFetchModelsAndFeaturesTimer() { - return base::CommandLine::ForCurrentProcess()->HasSwitch( - kFetchModelsAndHostModelFeaturesOverrideTimer); -} - std::unique_ptr<optimization_guide::proto::Configuration> ParseComponentConfigFromCommandLine() { base::CommandLine* cmd_line = base::CommandLine::ForCurrentProcess(); @@ -191,7 +188,7 @@ bool ShouldValidateModel() { } absl::optional<std::string> GetModelOverride() { -#if defined(OS_WIN) +#if BUILDFLAG(IS_WIN) // TODO(crbug/1227996): The parsing below is not supported on Windows because // ':' is used as a delimiter, but this must be used in the absolute file path // on Windows. @@ -206,5 +203,17 @@ absl::optional<std::string> GetModelOverride() { #endif } +bool StopHistoryVisitBatchAnnotateForTesting() { + base::CommandLine* command_line = base::CommandLine::ForCurrentProcess(); + if (command_line->HasSwitch(kStopHistoryVisitBatchAnnotateForTesting)) + return true; + return false; +} + +bool ShouldLogPageContentAnnotationsInput() { + return base::CommandLine::ForCurrentProcess()->HasSwitch( + kPageContentAnnotationsLoggingEnabled); +} + } // namespace switches } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/optimization_guide_switches.h b/chromium/components/optimization_guide/core/optimization_guide_switches.h index 188d06c574b..c7e80ceeabb 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_switches.h +++ b/chromium/components/optimization_guide/core/optimization_guide_switches.h @@ -22,7 +22,6 @@ namespace switches { extern const char kHintsProtoOverride[]; extern const char kFetchHintsOverride[]; extern const char kFetchHintsOverrideTimer[]; -extern const char kFetchModelsAndHostModelFeaturesOverrideTimer[]; extern const char kOptimizationGuideServiceGetHintsURL[]; extern const char kOptimizationGuideServiceGetModelsURL[]; extern const char kOptimizationGuideServiceAPIKey[]; @@ -34,6 +33,8 @@ extern const char kDisableModelDownloadVerificationForTesting[]; extern const char kModelOverride[]; extern const char kDebugLoggingEnabled[]; extern const char kModelValidate[]; +extern const char kStopHistoryVisitBatchAnnotateForTesting[]; +extern const char kPageContentAnnotationsLoggingEnabled[]; // Returns whether the hint component should be processed. // Available hint components are only processed if a proto override isn't being @@ -58,10 +59,6 @@ ParseHintsFetchOverrideFromCommandLine(); // Whether the hints fetcher timer should be overridden. bool ShouldOverrideFetchHintsTimer(); -// Whether the prediction model and host model features fetcher timer should be -// overridden. -bool ShouldOverrideFetchModelsAndFeaturesTimer(); - // Attempts to parse a base64 encoded Optimization Guide Configuration proto // from the command line. If no proto is given or if it is encoded incorrectly, // nullptr is returned. @@ -92,6 +89,13 @@ absl::optional<std::string> GetModelOverride(); // Returns true if debug logs are enabled for the optimization guide. bool IsDebugLogsEnabled(); +// Whether to prevent annotations from happening when in a batch. For testing +// purposes only. +bool StopHistoryVisitBatchAnnotateForTesting(); + +// Returns true if page content annotations input should be logged. +bool ShouldLogPageContentAnnotationsInput(); + } // namespace switches } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/optimization_guide_switches_unittest.cc b/chromium/components/optimization_guide/core/optimization_guide_switches_unittest.cc index ab401708455..267633cb400 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_switches_unittest.cc +++ b/chromium/components/optimization_guide/core/optimization_guide_switches_unittest.cc @@ -14,7 +14,7 @@ namespace optimization_guide { namespace switches { -#if !defined(OS_WIN) +#if !BUILDFLAG(IS_WIN) TEST(OptimizationGuideSwitchesTest, ParseHintsFetchOverrideFromCommandLine) { base::CommandLine::ForCurrentProcess()->AppendSwitchASCII(kFetchHintsOverride, diff --git a/chromium/components/optimization_guide/core/optimization_guide_test_util.cc b/chromium/components/optimization_guide/core/optimization_guide_test_util.cc index f1e88b67197..e349570a811 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_test_util.cc +++ b/chromium/components/optimization_guide/core/optimization_guide_test_util.cc @@ -9,7 +9,7 @@ namespace optimization_guide { -#if defined(OS_WIN) +#if BUILDFLAG(IS_WIN) const char kTestAbsoluteFilePath[] = "C:\\absolute/file/path"; const char kTestRelativeFilePath[] = "relative/file/path"; #else diff --git a/chromium/components/optimization_guide/core/optimization_guide_util.cc b/chromium/components/optimization_guide/core/optimization_guide_util.cc index 392f806ac5a..15acd17ad35 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_util.cc +++ b/chromium/components/optimization_guide/core/optimization_guide_util.cc @@ -19,50 +19,6 @@ namespace optimization_guide { -// These names are persisted to histograms, so don't change them. -std::string GetStringNameForOptimizationTarget( - optimization_guide::proto::OptimizationTarget optimization_target) { - switch (optimization_target) { - case proto::OPTIMIZATION_TARGET_UNKNOWN: - return "Unknown"; - case proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD: - return "PainfulPageLoad"; - case proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION: - return "LanguageDetection"; - case proto::OPTIMIZATION_TARGET_PAGE_TOPICS: - return "PageTopics"; - case proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB: - return "SegmentationNewTab"; - case proto::OPTIMIZATION_TARGET_SEGMENTATION_SHARE: - return "SegmentationShare"; - case proto::OPTIMIZATION_TARGET_SEGMENTATION_VOICE: - return "SegmentationVoice"; - case proto::OPTIMIZATION_TARGET_MODEL_VALIDATION: - return "ModelValidation"; - case proto::OPTIMIZATION_TARGET_PAGE_ENTITIES: - return "PageEntities"; - case proto::OPTIMIZATION_TARGET_NOTIFICATION_PERMISSION_PREDICTIONS: - return "NotificationPermissions"; - case proto::OPTIMIZATION_TARGET_SEGMENTATION_DUMMY: - return "SegmentationDummyFeature"; - case proto::OPTIMIZATION_TARGET_SEGMENTATION_CHROME_START_ANDROID: - return "SegmentationChromeStartAndroid"; - case proto::OPTIMIZATION_TARGET_SEGMENTATION_QUERY_TILES: - return "SegmentationQueryTiles"; - case proto::OPTIMIZATION_TARGET_PAGE_VISIBILITY: - return "PageVisibility"; - case proto::OPTIMIZATION_TARGET_AUTOFILL_ASSISTANT: - return "AutofillAssistant"; - case proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2: - return "PageTopicsV2"; - // Whenever a new value is added, make sure to add it to the OptTarget - // variant list in - // //tools/metrics/histograms/metadata/optimization/histograms.xml. - } - NOTREACHED(); - return std::string(); -} - bool IsHostValidToFetchFromRemoteOptimizationGuide(const std::string& host) { if (net::HostStringIsLocalhost(host)) return false; @@ -108,29 +64,6 @@ GetActiveFieldTrialsAllowedForFetch() { return filtered_active_field_trials; } -absl::optional<base::FilePath> StringToFilePath(const std::string& str_path) { - if (str_path.empty()) - return absl::nullopt; - -#if defined(OS_WIN) - return base::FilePath(base::UTF8ToWide(str_path)); -#else - return base::FilePath(str_path); -#endif -} - -std::string FilePathToString(const base::FilePath& file_path) { -#if defined(OS_WIN) - return base::WideToUTF8(file_path.value()); -#else - return file_path.value(); -#endif -} - -base::FilePath GetBaseFileNameForModels() { - return base::FilePath(FILE_PATH_LITERAL("model.tflite")); -} - std::string GetStringForOptimizationGuideDecision( OptimizationGuideDecision decision) { switch (decision) { @@ -149,7 +82,7 @@ absl::optional< std::pair<std::string, absl::optional<optimization_guide::proto::Any>>> GetModelOverrideForOptimizationTarget( optimization_guide::proto::OptimizationTarget optimization_target) { -#if defined(OS_WIN) +#if BUILDFLAG(IS_WIN) // TODO(crbug/1227996): The parsing below is not supported on Windows because // ':' is used as a delimiter, but this must be used in the absolute file path // on Windows. diff --git a/chromium/components/optimization_guide/core/optimization_guide_util.h b/chromium/components/optimization_guide/core/optimization_guide_util.h index f425dde56a1..90b70f34bc1 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_util.h +++ b/chromium/components/optimization_guide/core/optimization_guide_util.h @@ -7,24 +7,28 @@ #include <string> -#include "base/files/file_path.h" #include "base/strings/string_split.h" +#include "base/time/time.h" #include "components/optimization_guide/core/optimization_guide_enums.h" #include "components/optimization_guide/proto/common_types.pb.h" #include "components/optimization_guide/proto/models.pb.h" #include "third_party/abseil-cpp/absl/types/optional.h" +#define OPTIMIZATION_GUIDE_LOG(optimization_guide_logger, message) \ + do { \ + if (optimization_guide_logger && \ + optimization_guide_logger->ShouldEnableDebugLogs()) { \ + optimization_guide_logger->OnLogMessageAdded( \ + base::Time::Now(), __FILE__, __LINE__, message); \ + } \ + if (optimization_guide::switches::IsDebugLogsEnabled()) \ + DVLOG(0) << message; \ + } while (0) + namespace optimization_guide { enum class OptimizationGuideDecision; -// Returns the string than can be used to record histograms for the optimization -// target. If adding a histogram to use the string or adding an optimization -// target, update the OptimizationGuide.OptimizationTargets histogram suffixes -// in histograms.xml. -std::string GetStringNameForOptimizationTarget( - proto::OptimizationTarget optimization_target); - // Returns false if the host is an IP address, localhosts, or an invalid // host that is not supported by the remote optimization guide. bool IsHostValidToFetchFromRemoteOptimizationGuide(const std::string& host); @@ -34,18 +38,6 @@ bool IsHostValidToFetchFromRemoteOptimizationGuide(const std::string& host); google::protobuf::RepeatedPtrField<proto::FieldTrial> GetActiveFieldTrialsAllowedForFetch(); -// Returns the file path represented by the given string, handling platform -// differences in the conversion. nullopt is only returned iff the passed string -// is empty. -absl::optional<base::FilePath> StringToFilePath(const std::string& str_path); - -// Returns a string representation of the given |file_path|, handling platform -// differences in the conversion. -std::string FilePathToString(const base::FilePath& file_path); - -// Returns the base file name to use for storing all prediction models. -base::FilePath GetBaseFileNameForModels(); - // Validates that the metadata stored in |any_metadata_| is of the same type // and is parseable as |T|. Will return metadata if all checks pass. template <class T, diff --git a/chromium/components/optimization_guide/core/optimization_guide_util_unittest.cc b/chromium/components/optimization_guide/core/optimization_guide_util_unittest.cc index 1d400833462..f3a7fee11ae 100644 --- a/chromium/components/optimization_guide/core/optimization_guide_util_unittest.cc +++ b/chromium/components/optimization_guide/core/optimization_guide_util_unittest.cc @@ -63,7 +63,7 @@ TEST(OptimizationGuideUtilTest, ParsedAnyMetadataTest) { EXPECT_TRUE(parsed_subresource.preconnect_only()); } -#if !defined(OS_WIN) +#if !BUILDFLAG(IS_WIN) TEST(OptimizationGuideUtilTest, GetModelOverrideForOptimizationTargetSwitchNotSet) { diff --git a/chromium/components/optimization_guide/core/optimization_hints_component_update_listener.cc b/chromium/components/optimization_guide/core/optimization_hints_component_update_listener.cc index 4ea36ecd4cb..019c0d464ef 100644 --- a/chromium/components/optimization_guide/core/optimization_hints_component_update_listener.cc +++ b/chromium/components/optimization_guide/core/optimization_hints_component_update_listener.cc @@ -5,6 +5,7 @@ #include "components/optimization_guide/core/optimization_hints_component_update_listener.h" #include "base/metrics/histogram_functions.h" +#include "base/no_destructor.h" namespace optimization_guide { diff --git a/chromium/components/optimization_guide/core/page_content_annotation_job.cc b/chromium/components/optimization_guide/core/page_content_annotation_job.cc index 4ac77c0091b..45104edb920 100644 --- a/chromium/components/optimization_guide/core/page_content_annotation_job.cc +++ b/chromium/components/optimization_guide/core/page_content_annotation_job.cc @@ -5,6 +5,7 @@ #include "components/optimization_guide/core/page_content_annotation_job.h" #include "base/check_op.h" +#include "base/metrics/histogram_functions.h" namespace optimization_guide { @@ -14,11 +15,35 @@ PageContentAnnotationJob::PageContentAnnotationJob( AnnotationType type) : on_complete_callback_(std::move(on_complete_callback)), type_(type), - inputs_(inputs.begin(), inputs.end()) { + inputs_(inputs.begin(), inputs.end()), + job_creation_time_(base::TimeTicks::Now()) { DCHECK(!inputs_.empty()); } -PageContentAnnotationJob::~PageContentAnnotationJob() = default; +PageContentAnnotationJob::~PageContentAnnotationJob() { + if (!job_execution_start_time_) + return; + + base::TimeDelta job_scheduling_wait_time = + *job_execution_start_time_ - job_creation_time_; + base::TimeDelta job_exec_time = + base::TimeTicks::Now() - *job_execution_start_time_; + + base::UmaHistogramMediumTimes( + "OptimizationGuide.PageContentAnnotations.JobExecutionTime." + + AnnotationTypeToString(type()), + job_exec_time); + + base::UmaHistogramMediumTimes( + "OptimizationGuide.PageContentAnnotations.JobScheduleTime." + + AnnotationTypeToString(type()), + job_scheduling_wait_time); + + base::UmaHistogramBoolean( + "OptimizationGuide.PageContentAnnotations.BatchSuccess." + + AnnotationTypeToString(type()), + HadAnySuccess()); +} void PageContentAnnotationJob::FillWithNullOutputs() { while (auto input = GetNextInput()) { @@ -59,6 +84,10 @@ size_t PageContentAnnotationJob::CountOfRemainingNonNullInputs() const { } absl::optional<std::string> PageContentAnnotationJob::GetNextInput() { + if (!job_execution_start_time_) { + job_execution_start_time_ = base::TimeTicks::Now(); + } + if (inputs_.empty()) { return absl::nullopt; } @@ -72,4 +101,20 @@ void PageContentAnnotationJob::PostNewResult( results_.push_back(result); } +bool PageContentAnnotationJob::HadAnySuccess() const { + for (const BatchAnnotationResult& result : results_) { + if (result.type() == AnnotationType::kPageTopics && result.topics()) { + return true; + } + if (result.type() == AnnotationType::kPageEntities && result.entities()) { + return true; + } + if (result.type() == AnnotationType::kContentVisibility && + result.visibility_score()) { + return true; + } + } + return false; +} + } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/page_content_annotation_job.h b/chromium/components/optimization_guide/core/page_content_annotation_job.h index 989736df1e1..9ba68090c00 100644 --- a/chromium/components/optimization_guide/core/page_content_annotation_job.h +++ b/chromium/components/optimization_guide/core/page_content_annotation_job.h @@ -10,6 +10,7 @@ #include <vector> #include "base/callback.h" +#include "base/time/time.h" #include "components/optimization_guide/core/page_content_annotations_common.h" #include "third_party/abseil-cpp/absl/types/optional.h" @@ -21,8 +22,6 @@ namespace optimization_guide { // container that matches the I/O of a single call to the PCA Service. class PageContentAnnotationJob { public: - using WeightedCategories = std::vector<WeightedString>; - PageContentAnnotationJob(BatchAnnotationCallback on_complete_callback, const std::vector<std::string>& inputs, AnnotationType type); @@ -46,6 +45,10 @@ class PageContentAnnotationJob { // Posts a new result after an execution has completed. void PostNewResult(const BatchAnnotationResult& result); + // Returns true if any element of |results_| was a successful execution. We + // expect that if one result is successful, many more will be as well. + bool HadAnySuccess() const; + AnnotationType type() const { return type_; } PageContentAnnotationJob(const PageContentAnnotationJob&) = delete; @@ -65,6 +68,12 @@ class PageContentAnnotationJob { // Filled by |PostNewResult| with the complete annotations, specified by // |type_|. std::vector<BatchAnnotationResult> results_; + + // The time the job was constructed. + const base::TimeTicks job_creation_time_; + + // Set when |GetNextInput| is called for the first time. + absl::optional<base::TimeTicks> job_execution_start_time_; }; } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/page_content_annotation_job_executor.cc b/chromium/components/optimization_guide/core/page_content_annotation_job_executor.cc index 38f617cef94..779524c94d2 100644 --- a/chromium/components/optimization_guide/core/page_content_annotation_job_executor.cc +++ b/chromium/components/optimization_guide/core/page_content_annotation_job_executor.cc @@ -73,6 +73,8 @@ void PageContentAnnotationJobExecutor::OnJobExecutionComplete( std::unique_ptr<PageContentAnnotationJob> job) { job->OnComplete(); // Intentionally reset |job| here to make lifetime clearer and less bug-prone. + // Note that the job dtor also records some timing metrics which is better to + // do now rather than after the following callback. job.reset(); std::move(on_job_complete_callback_from_caller).Run(); diff --git a/chromium/components/optimization_guide/core/page_content_annotation_job_executor_unittest.cc b/chromium/components/optimization_guide/core/page_content_annotation_job_executor_unittest.cc index d16ffb6859f..5fe0a6a48a6 100644 --- a/chromium/components/optimization_guide/core/page_content_annotation_job_executor_unittest.cc +++ b/chromium/components/optimization_guide/core/page_content_annotation_job_executor_unittest.cc @@ -15,7 +15,7 @@ namespace optimization_guide { namespace { -const std::vector<WeightedString> kOutput{WeightedString("output", 1.0)}; +const std::vector<WeightedIdentifier> kOutput{WeightedIdentifier(1337, 1.0)}; } class TestJobExecutor : public PageContentAnnotationJobExecutor { diff --git a/chromium/components/optimization_guide/core/page_content_annotations_common.cc b/chromium/components/optimization_guide/core/page_content_annotations_common.cc index 08bd362e404..62f637d7443 100644 --- a/chromium/components/optimization_guide/core/page_content_annotations_common.cc +++ b/chromium/components/optimization_guide/core/page_content_annotations_common.cc @@ -13,6 +13,9 @@ namespace optimization_guide { +// Each of these string values is used in UMA histograms so please update the +// variants there when any changes are made. +// //tools/metrics/histograms/metadata/optimization/histograms.xml std::string AnnotationTypeToString(AnnotationType type) { switch (type) { case AnnotationType::kUnknown: @@ -26,26 +29,25 @@ std::string AnnotationTypeToString(AnnotationType type) { } } -WeightedString::WeightedString(const std::string& value, double weight) +WeightedIdentifier::WeightedIdentifier(int32_t value, double weight) : value_(value), weight_(weight) { DCHECK_GE(weight_, 0.0); DCHECK_LE(weight_, 1.0); } -WeightedString::WeightedString(const WeightedString&) = default; -WeightedString::~WeightedString() = default; +WeightedIdentifier::WeightedIdentifier(const WeightedIdentifier&) = default; +WeightedIdentifier::~WeightedIdentifier() = default; -bool WeightedString::operator==(const WeightedString& other) const { +bool WeightedIdentifier::operator==(const WeightedIdentifier& other) const { constexpr double kWeightTolerance = 1e-6; return this->value_ == other.value_ && abs(this->weight_ - other.weight_) <= kWeightTolerance; } -std::string WeightedString::ToString() const { - return base::StringPrintf("WeightedString{\"%s\",%f}", value().c_str(), - weight()); +std::string WeightedIdentifier::ToString() const { + return base::StringPrintf("WeightedIdentifier{%d,%f}", value(), weight()); } -std::ostream& operator<<(std::ostream& stream, const WeightedString& ws) { +std::ostream& operator<<(std::ostream& stream, const WeightedIdentifier& ws) { stream << ws.ToString(); return stream; } @@ -58,11 +60,11 @@ BatchAnnotationResult::~BatchAnnotationResult() = default; std::string BatchAnnotationResult::ToString() const { std::string output = "nullopt"; if (topics_) { - std::vector<std::string> all_weighted_strings; - for (const WeightedString& ws : *topics_) { - all_weighted_strings.push_back(ws.ToString()); + std::vector<std::string> all_weighted_ids; + for (const WeightedIdentifier& wi : *topics_) { + all_weighted_ids.push_back(wi.ToString()); } - output = "{" + base::JoinString(all_weighted_strings, ",") + "}"; + output = "{" + base::JoinString(all_weighted_ids, ",") + "}"; } else if (entities_) { std::vector<std::string> all_entities; for (const ScoredEntityMetadata& md : *entities_) { @@ -89,7 +91,7 @@ std::ostream& operator<<(std::ostream& stream, // static BatchAnnotationResult BatchAnnotationResult::CreatePageTopicsResult( const std::string& input, - absl::optional<std::vector<WeightedString>> topics) { + absl::optional<std::vector<WeightedIdentifier>> topics) { BatchAnnotationResult result; result.input_ = input; result.topics_ = topics; @@ -98,7 +100,7 @@ BatchAnnotationResult BatchAnnotationResult::CreatePageTopicsResult( // Always sort the result (if present) by the given score. if (result.topics_) { std::sort(result.topics_->begin(), result.topics_->end(), - [](const WeightedString& a, const WeightedString& b) { + [](const WeightedIdentifier& a, const WeightedIdentifier& b) { return a.weight() < b.weight(); }); } diff --git a/chromium/components/optimization_guide/core/page_content_annotations_common.h b/chromium/components/optimization_guide/core/page_content_annotations_common.h index 5679064b8d0..4f77386485c 100644 --- a/chromium/components/optimization_guide/core/page_content_annotations_common.h +++ b/chromium/components/optimization_guide/core/page_content_annotations_common.h @@ -15,6 +15,10 @@ namespace optimization_guide { // The type of annotation that is being done on the given input. +// +// Each of these is used in UMA histograms so please update the variants there +// when any changes are made. +// //tools/metrics/histograms/metadata/optimization/histograms.xml enum class AnnotationType { kUnknown, @@ -33,25 +37,25 @@ enum class AnnotationType { std::string AnnotationTypeToString(AnnotationType type); -// A weighted string value. -class WeightedString { +// A weighted ID value. +class WeightedIdentifier { public: - WeightedString(const std::string& value, double weight); - WeightedString(const WeightedString&); - ~WeightedString(); + WeightedIdentifier(int32_t value, double weight); + WeightedIdentifier(const WeightedIdentifier&); + ~WeightedIdentifier(); - std::string value() const { return value_; } + int32_t value() const { return value_; } double weight() const { return weight_; } std::string ToString() const; - bool operator==(const WeightedString& other) const; + bool operator==(const WeightedIdentifier& other) const; friend std::ostream& operator<<(std::ostream& stream, - const WeightedString& ws); + const WeightedIdentifier& ws); private: - std::string value_; + int32_t value_; // In the range of [0.0, 1.0]. double weight_ = 0; @@ -63,7 +67,7 @@ class BatchAnnotationResult { // Creates a result for a page topics annotation. static BatchAnnotationResult CreatePageTopicsResult( const std::string& input, - absl::optional<std::vector<WeightedString>> topics); + absl::optional<std::vector<WeightedIdentifier>> topics); // Creates a result for a page entities annotation. static BatchAnnotationResult CreatePageEntitiesResult( @@ -84,7 +88,9 @@ class BatchAnnotationResult { std::string input() const { return input_; } AnnotationType type() const { return type_; } - absl::optional<std::vector<WeightedString>> topics() const { return topics_; } + absl::optional<std::vector<WeightedIdentifier>> topics() const { + return topics_; + } absl::optional<std::vector<ScoredEntityMetadata>> entities() const { return entities_; } @@ -105,7 +111,7 @@ class BatchAnnotationResult { // Output for page topics annotations, set only if the |type_| matches and the // execution was successful. - absl::optional<std::vector<WeightedString>> topics_; + absl::optional<std::vector<WeightedIdentifier>> topics_; // Output for page entities annotations, set only if the |type_| matches and // the execution was successful. diff --git a/chromium/components/optimization_guide/core/page_entities_model_executor.h b/chromium/components/optimization_guide/core/page_entities_model_executor.h index 919e7b76f76..5ea55c6e243 100644 --- a/chromium/components/optimization_guide/core/page_entities_model_executor.h +++ b/chromium/components/optimization_guide/core/page_entities_model_executor.h @@ -15,7 +15,7 @@ namespace optimization_guide { -// TODO(crbug/1249632): Remove this entirely. +// TODO(crbug/1278828): Remove this entirely. class HumanReadablePageEntitiesModelExecutor { public: virtual ~HumanReadablePageEntitiesModelExecutor() = default; diff --git a/chromium/components/optimization_guide/core/page_entities_model_executor_impl.cc b/chromium/components/optimization_guide/core/page_entities_model_executor_impl.cc new file mode 100644 index 00000000000..d3a34832f6b --- /dev/null +++ b/chromium/components/optimization_guide/core/page_entities_model_executor_impl.cc @@ -0,0 +1,230 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/optimization_guide/core/page_entities_model_executor_impl.h" + +#include "base/metrics/histogram_functions.h" +#include "base/threading/sequenced_task_runner_handle.h" +#include "base/timer/elapsed_timer.h" +#include "components/optimization_guide/core/entity_annotator_native_library.h" +#include "components/optimization_guide/core/optimization_guide_features.h" +#include "components/optimization_guide/core/optimization_guide_model_provider.h" +#include "components/optimization_guide/proto/page_entities_model_metadata.pb.h" + +namespace optimization_guide { + +namespace { + +const char kPageEntitiesModelMetadataTypeUrl[] = + "type.googleapis.com/" + "google.internal.chrome.optimizationguide.v1.PageEntitiesModelMetadata"; + +} // namespace + +EntityAnnotatorHolder::EntityAnnotatorHolder( + scoped_refptr<base::SequencedTaskRunner> background_task_runner, + scoped_refptr<base::SequencedTaskRunner> reply_task_runner) + : background_task_runner_(background_task_runner), + reply_task_runner_(reply_task_runner) {} + +EntityAnnotatorHolder::~EntityAnnotatorHolder() { + DCHECK(background_task_runner_->RunsTasksInCurrentSequence()); + + if (features::ShouldResetPageEntitiesModelOnShutdown()) { + ResetEntityAnnotator(); + } +} + +void EntityAnnotatorHolder:: + InitializeEntityAnnotatorNativeLibraryOnBackgroundThread( + base::OnceCallback<void(int32_t)> init_callback) { + DCHECK(background_task_runner_->RunsTasksInCurrentSequence()); + + DCHECK(!entity_annotator_native_library_); + if (entity_annotator_native_library_) { + // We should only be initialized once but in case someone does something + // wrong in a non-debug build, we invoke the callback anyway. + reply_task_runner_->PostTask( + FROM_HERE, + base::BindOnce( + std::move(init_callback), + entity_annotator_native_library_->GetMaxSupportedFeatureFlag())); + return; + } + + entity_annotator_native_library_ = EntityAnnotatorNativeLibrary::Create(); + if (!entity_annotator_native_library_) { + reply_task_runner_->PostTask(FROM_HERE, + base::BindOnce(std::move(init_callback), -1)); + return; + } + + int32_t max_supported_feature_flag = + entity_annotator_native_library_->GetMaxSupportedFeatureFlag(); + reply_task_runner_->PostTask( + FROM_HERE, + base::BindOnce(std::move(init_callback), max_supported_feature_flag)); +} + +void EntityAnnotatorHolder::ResetEntityAnnotator() { + DCHECK(background_task_runner_->RunsTasksInCurrentSequence()); + + if (entity_annotator_) { + DCHECK(entity_annotator_native_library_); + entity_annotator_native_library_->DeleteEntityAnnotator(entity_annotator_); + + entity_annotator_ = nullptr; + } +} + +void EntityAnnotatorHolder::CreateAndSetEntityAnnotatorOnBackgroundThread( + const ModelInfo& model_info) { + DCHECK(background_task_runner_->RunsTasksInCurrentSequence()); + + if (!entity_annotator_native_library_) { + return; + } + + ResetEntityAnnotator(); + + entity_annotator_ = + entity_annotator_native_library_->CreateEntityAnnotator(model_info); +} + +void EntityAnnotatorHolder::AnnotateEntitiesMetadataModelOnBackgroundThread( + const std::string& text, + PageEntitiesMetadataModelExecutedCallback callback) { + DCHECK(background_task_runner_->RunsTasksInCurrentSequence()); + base::ElapsedThreadTimer annotate_timer; + + absl::optional<std::vector<ScoredEntityMetadata>> scored_md; + if (entity_annotator_) { + DCHECK(entity_annotator_native_library_); + base::TimeTicks start_time = base::TimeTicks::Now(); + scored_md = + entity_annotator_native_library_->AnnotateText(entity_annotator_, text); + // The max of the below histograms is 1 hour because we want to understand + // tail behavior and catch long running model executions. + base::UmaHistogramLongTimes( + "OptimizationGuide.PageContentAnnotationsService.ModelExecutionLatency." + "PageEntities", + base::TimeTicks::Now() - start_time); + + base::UmaHistogramLongTimes( + "OptimizationGuide.PageContentAnnotationsService." + "ModelThreadExecutionLatency.PageEntities", + annotate_timer.Elapsed()); + } + reply_task_runner_->PostTask(FROM_HERE, + base::BindOnce(std::move(callback), scored_md)); +} + +void EntityAnnotatorHolder::GetMetadataForEntityIdOnBackgroundThread( + const std::string& entity_id, + PageEntitiesModelExecutor::PageEntitiesModelEntityMetadataRetrievedCallback + callback) { + DCHECK(background_task_runner_->RunsTasksInCurrentSequence()); + + absl::optional<EntityMetadata> entity_metadata; + if (entity_annotator_) { + DCHECK(entity_annotator_native_library_); + entity_metadata = + entity_annotator_native_library_->GetEntityMetadataForEntityId( + entity_annotator_, entity_id); + } + reply_task_runner_->PostTask( + FROM_HERE, + base::BindOnce(std::move(callback), std::move(entity_metadata))); +} + +base::WeakPtr<EntityAnnotatorHolder> +EntityAnnotatorHolder::GetBackgroundWeakPtr() { + return background_weak_ptr_factory_.GetWeakPtr(); +} + +PageEntitiesModelExecutorImpl::PageEntitiesModelExecutorImpl( + OptimizationGuideModelProvider* optimization_guide_model_provider, + scoped_refptr<base::SequencedTaskRunner> background_task_runner) + : background_task_runner_(background_task_runner), + entity_annotator_holder_(std::make_unique<EntityAnnotatorHolder>( + background_task_runner_, + base::SequencedTaskRunnerHandle::Get())) { + background_task_runner_->PostTask( + FROM_HERE, + base::BindOnce( + &EntityAnnotatorHolder:: + InitializeEntityAnnotatorNativeLibraryOnBackgroundThread, + entity_annotator_holder_->GetBackgroundWeakPtr(), + base::BindOnce(&PageEntitiesModelExecutorImpl:: + OnEntityAnnotatorLibraryInitialized, + weak_ptr_factory_.GetWeakPtr(), + optimization_guide_model_provider))); +} + +void PageEntitiesModelExecutorImpl::OnEntityAnnotatorLibraryInitialized( + OptimizationGuideModelProvider* optimization_guide_model_provider, + int32_t max_model_format_feature_flag) { + if (max_model_format_feature_flag <= 0) { + return; + } + + proto::Any any_metadata; + any_metadata.set_type_url(kPageEntitiesModelMetadataTypeUrl); + proto::PageEntitiesModelMetadata model_metadata; + model_metadata.set_max_model_format_feature_flag( + max_model_format_feature_flag); + model_metadata.SerializeToString(any_metadata.mutable_value()); + optimization_guide_model_provider->AddObserverForOptimizationTargetModel( + proto::OptimizationTarget::OPTIMIZATION_TARGET_PAGE_ENTITIES, + any_metadata, this); +} + +PageEntitiesModelExecutorImpl::~PageEntitiesModelExecutorImpl() { + // |entity_annotator_holder_|'s WeakPtrs are used on the background thread, + // so that is also where the class must be destroyed. + background_task_runner_->DeleteSoon(FROM_HERE, + std::move(entity_annotator_holder_)); +} + +void PageEntitiesModelExecutorImpl::OnModelUpdated( + proto::OptimizationTarget optimization_target, + const ModelInfo& model_info) { + if (optimization_target != proto::OPTIMIZATION_TARGET_PAGE_ENTITIES) + return; + + background_task_runner_->PostTask( + FROM_HERE, + base::BindOnce( + &EntityAnnotatorHolder::CreateAndSetEntityAnnotatorOnBackgroundThread, + entity_annotator_holder_->GetBackgroundWeakPtr(), model_info)); +} + +void PageEntitiesModelExecutorImpl::HumanReadableExecuteModelWithInput( + const std::string& text, + PageEntitiesMetadataModelExecutedCallback callback) { + if (text.empty()) { + std::move(callback).Run(absl::nullopt); + return; + } + + background_task_runner_->PostTask( + FROM_HERE, + base::BindOnce(&EntityAnnotatorHolder:: + AnnotateEntitiesMetadataModelOnBackgroundThread, + entity_annotator_holder_->GetBackgroundWeakPtr(), text, + std::move(callback))); +} + +void PageEntitiesModelExecutorImpl::GetMetadataForEntityId( + const std::string& entity_id, + PageEntitiesModelEntityMetadataRetrievedCallback callback) { + background_task_runner_->PostTask( + FROM_HERE, + base::BindOnce( + &EntityAnnotatorHolder::GetMetadataForEntityIdOnBackgroundThread, + entity_annotator_holder_->GetBackgroundWeakPtr(), entity_id, + std::move(callback))); +} + +} // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/page_entities_model_executor_impl.h b/chromium/components/optimization_guide/core/page_entities_model_executor_impl.h new file mode 100644 index 00000000000..d39944fc261 --- /dev/null +++ b/chromium/components/optimization_guide/core/page_entities_model_executor_impl.h @@ -0,0 +1,115 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_PAGE_ENTITIES_MODEL_EXECUTOR_IMPL_H_ +#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_PAGE_ENTITIES_MODEL_EXECUTOR_IMPL_H_ + +#include "base/task/sequenced_task_runner.h" +#include "base/task/task_traits.h" +#include "base/task/thread_pool.h" +#include "components/optimization_guide/core/entity_metadata.h" +#include "components/optimization_guide/core/optimization_target_model_observer.h" +#include "components/optimization_guide/core/page_entities_model_executor.h" + +namespace optimization_guide { + +class EntityAnnotatorNativeLibrary; +class OptimizationGuideModelProvider; + +// An object used to hold an entity annotator on a background thread. +class EntityAnnotatorHolder { + public: + EntityAnnotatorHolder( + scoped_refptr<base::SequencedTaskRunner> background_task_runner, + scoped_refptr<base::SequencedTaskRunner> reply_task_runner); + ~EntityAnnotatorHolder(); + + // Initializes the native library on a background thread. Will invoke + // |init_callback| on |reply_task_runner_| with the max version supported for + // the entity annotator on success. Otherwise, -1. + void InitializeEntityAnnotatorNativeLibraryOnBackgroundThread( + base::OnceCallback<void(int32_t)> init_callback); + + // Creates an entity annotator on the background thread and sets it to + // |entity_annotator_|. Should be invoked on |background_task_runner_|. + void CreateAndSetEntityAnnotatorOnBackgroundThread( + const ModelInfo& model_info); + + // Requests for |entity_annotator_| to execute its model for |text| and map + // the entities back to their metadata. Should be invoked on + // |background_task_runner_|. + using PageEntitiesMetadataModelExecutedCallback = base::OnceCallback<void( + const absl::optional<std::vector<ScoredEntityMetadata>>&)>; + void AnnotateEntitiesMetadataModelOnBackgroundThread( + const std::string& text, + PageEntitiesMetadataModelExecutedCallback callback); + + // Returns entity metadata from |entity_annotator_| for |entity_id|. + // Should be invoked on |background_task_runner_|. + void GetMetadataForEntityIdOnBackgroundThread( + const std::string& entity_id, + PageEntitiesModelExecutor:: + PageEntitiesModelEntityMetadataRetrievedCallback callback); + + // Gets the weak ptr to |this| on the background thread. + base::WeakPtr<EntityAnnotatorHolder> GetBackgroundWeakPtr(); + + private: + void ResetEntityAnnotator(); + + scoped_refptr<base::SequencedTaskRunner> background_task_runner_; + scoped_refptr<base::SequencedTaskRunner> reply_task_runner_; + + std::unique_ptr<EntityAnnotatorNativeLibrary> + entity_annotator_native_library_; + void* entity_annotator_ = nullptr; + + base::WeakPtrFactory<EntityAnnotatorHolder> background_weak_ptr_factory_{ + this}; +}; + +// Manages the loading and execution of the page entities model. +class PageEntitiesModelExecutorImpl : public OptimizationTargetModelObserver, + public PageEntitiesModelExecutor { + public: + PageEntitiesModelExecutorImpl( + OptimizationGuideModelProvider* model_provider, + scoped_refptr<base::SequencedTaskRunner> background_task_runner = + base::ThreadPool::CreateSequencedTaskRunner( + {base::MayBlock(), base::TaskPriority::BEST_EFFORT})); + ~PageEntitiesModelExecutorImpl() override; + PageEntitiesModelExecutorImpl(const PageEntitiesModelExecutorImpl&) = delete; + PageEntitiesModelExecutorImpl& operator=( + const PageEntitiesModelExecutorImpl&) = delete; + + // PageEntitiesModelExecutor: + void GetMetadataForEntityId( + const std::string& entity_id, + PageEntitiesModelEntityMetadataRetrievedCallback callback) override; + void HumanReadableExecuteModelWithInput( + const std::string& text, + PageEntitiesMetadataModelExecutedCallback callback) override; + + // OptimizationTargetModelObserver: + void OnModelUpdated(proto::OptimizationTarget optimization_target, + const ModelInfo& model_info) override; + + private: + // Invoked on the UI thread when entity annotator library has been + // initialized. + void OnEntityAnnotatorLibraryInitialized( + OptimizationGuideModelProvider* model_provider, + int32_t max_model_format_feature_flag); + + scoped_refptr<base::SequencedTaskRunner> background_task_runner_; + + // The holder used to hold the annotator used to annotate entities. + std::unique_ptr<EntityAnnotatorHolder> entity_annotator_holder_; + + base::WeakPtrFactory<PageEntitiesModelExecutorImpl> weak_ptr_factory_{this}; +}; + +} // namespace optimization_guide + +#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_PAGE_ENTITIES_MODEL_EXECUTOR_IMPL_H_ diff --git a/chromium/components/optimization_guide/core/page_entities_model_executor_impl_unittest.cc b/chromium/components/optimization_guide/core/page_entities_model_executor_impl_unittest.cc new file mode 100644 index 00000000000..e1954454829 --- /dev/null +++ b/chromium/components/optimization_guide/core/page_entities_model_executor_impl_unittest.cc @@ -0,0 +1,268 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/optimization_guide/core/page_entities_model_executor_impl.h" + +#include "base/observer_list.h" +#include "base/path_service.h" +#include "base/run_loop.h" +#include "base/test/task_environment.h" +#include "components/optimization_guide/core/model_util.h" +#include "components/optimization_guide/core/optimization_guide_util.h" +#include "components/optimization_guide/core/test_model_info_builder.h" +#include "components/optimization_guide/core/test_optimization_guide_model_provider.h" +#include "components/optimization_guide/proto/page_entities_model_metadata.pb.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace optimization_guide { +namespace { + +using ::testing::ElementsAre; + +class ModelObserverTracker : public TestOptimizationGuideModelProvider { + public: + void AddObserverForOptimizationTargetModel( + proto::OptimizationTarget target, + const absl::optional<proto::Any>& model_metadata, + OptimizationTargetModelObserver* observer) override { + registered_model_metadata_.insert_or_assign(target, model_metadata); + registered_observers_.AddObserver(observer); + } + + void RemoveObserverForOptimizationTargetModel( + proto::OptimizationTarget target, + OptimizationTargetModelObserver* observer) override { + registered_observers_.RemoveObserver(observer); + } + + bool DidRegisterForTarget( + proto::OptimizationTarget target, + absl::optional<proto::Any>* out_model_metadata) const { + auto it = registered_model_metadata_.find(target); + if (it == registered_model_metadata_.end()) + return false; + *out_model_metadata = registered_model_metadata_.at(target); + return true; + } + + void PushModelInfoToObservers(const ModelInfo& model_info) { + for (auto& observer : registered_observers_) { + observer.OnModelUpdated(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, + model_info); + } + } + + private: + base::flat_map<proto::OptimizationTarget, absl::optional<proto::Any>> + registered_model_metadata_; + base::ObserverList<OptimizationTargetModelObserver> registered_observers_; +}; + +class PageEntitiesModelExecutorImplTest : public testing::Test { + public: + void SetUp() override { + model_observer_tracker_ = std::make_unique<ModelObserverTracker>(); + model_executor_ = std::make_unique<PageEntitiesModelExecutorImpl>( + model_observer_tracker_.get()); + + // Wait for PageEntitiesModelExecutor to set everything up. + task_environment_.RunUntilIdle(); + } + + void TearDown() override { + model_executor_.reset(); + model_observer_tracker_.reset(); + + // Wait for PageEntitiesModelExecutor to clean everything up. + task_environment_.RunUntilIdle(); + } + + absl::optional<std::vector<ScoredEntityMetadata>> ExecuteHumanReadableModel( + const std::string& text) { + absl::optional<std::vector<ScoredEntityMetadata>> entity_metadata; + + base::RunLoop run_loop; + model_executor_->HumanReadableExecuteModelWithInput( + text, base::BindOnce( + [](base::RunLoop* run_loop, + absl::optional<std::vector<ScoredEntityMetadata>>* + out_entity_metadata, + const absl::optional<std::vector<ScoredEntityMetadata>>& + entity_metadata) { + *out_entity_metadata = entity_metadata; + run_loop->Quit(); + }, + &run_loop, &entity_metadata)); + run_loop.Run(); + + // Sort the result by score to make validating the output easier. + if (entity_metadata) { + std::sort( + entity_metadata->begin(), entity_metadata->end(), + [](const ScoredEntityMetadata& a, const ScoredEntityMetadata& b) { + return a.score > b.score; + }); + } + return entity_metadata; + } + + absl::optional<EntityMetadata> GetMetadataForEntityId( + const std::string& entity_id) { + absl::optional<EntityMetadata> entity_metadata; + + base::RunLoop run_loop; + model_executor_->GetMetadataForEntityId( + entity_id, + base::BindOnce( + [](base::RunLoop* run_loop, + absl::optional<EntityMetadata>* out_entity_metadata, + const absl::optional<EntityMetadata>& entity_metadata) { + *out_entity_metadata = entity_metadata; + run_loop->Quit(); + }, + &run_loop, &entity_metadata)); + run_loop.Run(); + + return entity_metadata; + } + + ModelObserverTracker* model_observer_tracker() const { + return model_observer_tracker_.get(); + } + + base::FilePath GetModelTestDataDir() { + base::FilePath source_root_dir; + base::PathService::Get(base::DIR_SOURCE_ROOT, &source_root_dir); + return source_root_dir.AppendASCII("components") + .AppendASCII("optimization_guide") + .AppendASCII("internal") + .AppendASCII("testdata"); + } + + void PushModelInfoToObservers(const ModelInfo& model_info) { + model_observer_tracker_->PushModelInfoToObservers(model_info); + task_environment_.RunUntilIdle(); + } + + private: + base::test::TaskEnvironment task_environment_; + std::unique_ptr<ModelObserverTracker> model_observer_tracker_; + std::unique_ptr<PageEntitiesModelExecutorImpl> model_executor_; +}; + +TEST_F(PageEntitiesModelExecutorImplTest, CreateNoMetadata) { + std::unique_ptr<ModelInfo> model_info = TestModelInfoBuilder().Build(); + ASSERT_TRUE(model_info); + PushModelInfoToObservers(*model_info); + + // We expect that there will be no model to evaluate even for this input that + // has output in the test model. + EXPECT_EQ(ExecuteHumanReadableModel("Taylor Swift singer"), absl::nullopt); +} + +TEST_F(PageEntitiesModelExecutorImplTest, CreateMetadataWrongType) { + proto::Any any; + any.set_type_url(any.GetTypeName()); + proto::FieldTrial garbage; + garbage.SerializeToString(any.mutable_value()); + + proto::PredictionModel model; + model.mutable_model()->set_download_url( + FilePathToString(GetModelTestDataDir().AppendASCII("model.tflite"))); + model.mutable_model_info()->set_version(123); + *model.mutable_model_info()->mutable_model_metadata() = any; + std::unique_ptr<ModelInfo> model_info = ModelInfo::Create(model); + ASSERT_TRUE(model_info); + PushModelInfoToObservers(*model_info); + + // We expect that there will be no model to evaluate even for this input that + // has output in the test model. + EXPECT_EQ(ExecuteHumanReadableModel("Taylor Swift singer"), absl::nullopt); +} + +TEST_F(PageEntitiesModelExecutorImplTest, CreateNoSlices) { + proto::Any any; + proto::PageEntitiesModelMetadata metadata; + any.set_type_url(metadata.GetTypeName()); + metadata.SerializeToString(any.mutable_value()); + + proto::PredictionModel model; + model.mutable_model()->set_download_url( + FilePathToString(GetModelTestDataDir().AppendASCII("model.tflite"))); + model.mutable_model_info()->set_version(123); + *model.mutable_model_info()->mutable_model_metadata() = any; + std::unique_ptr<ModelInfo> model_info = ModelInfo::Create(model); + ASSERT_TRUE(model_info); + PushModelInfoToObservers(*model_info); + + // We expect that there will be no model to evaluate even for this input that + // has output in the test model. + EXPECT_EQ(ExecuteHumanReadableModel("Taylor Swift singer"), absl::nullopt); +} + +TEST_F(PageEntitiesModelExecutorImplTest, CreateMissingFiles) { + proto::Any any; + proto::PageEntitiesModelMetadata metadata; + metadata.add_slice("global"); + any.set_type_url(metadata.GetTypeName()); + metadata.SerializeToString(any.mutable_value()); + + base::FilePath dir_path = GetModelTestDataDir(); + base::flat_set<std::string> expected_additional_files = { + FilePathToString(dir_path.AppendASCII("model_metadata.pb")), + FilePathToString(dir_path.AppendASCII("word_embeddings")), + FilePathToString(dir_path.AppendASCII("global-entities_names")), + FilePathToString(dir_path.AppendASCII("global-entities_metadata")), + FilePathToString(dir_path.AppendASCII("global-entities_names_filter")), + FilePathToString(dir_path.AppendASCII("global-entities_prefixes_filter")), + }; + // Remove one file for each iteration and make sure it fails. + for (const auto& missing_file_name : expected_additional_files) { + // Make a copy of the expected files and remove the one file from the set. + base::flat_set<std::string> additional_files = expected_additional_files; + additional_files.erase(missing_file_name); + + proto::PredictionModel model; + model.mutable_model()->set_download_url( + FilePathToString(dir_path.AppendASCII("model.tflite"))); + model.mutable_model_info()->set_version(123); + *model.mutable_model_info()->mutable_model_metadata() = any; + for (const auto& additional_file : additional_files) { + model.mutable_model_info()->add_additional_files()->set_file_path( + additional_file); + } + std::unique_ptr<ModelInfo> model_info = ModelInfo::Create(model); + ASSERT_TRUE(model_info); + PushModelInfoToObservers(*model_info); + + // We expect that there will be no model to evaluate even for this input + // that has output in the test model. + EXPECT_EQ(ExecuteHumanReadableModel("Taylor Swift singer"), absl::nullopt); + } +} + +TEST_F(PageEntitiesModelExecutorImplTest, GetMetadataForEntityIdNoModel) { + EXPECT_EQ(GetMetadataForEntityId("/m/0dl567"), absl::nullopt); +} + +TEST_F(PageEntitiesModelExecutorImplTest, ExecuteHumanReadableModelNoModel) { + EXPECT_EQ(ExecuteHumanReadableModel("Taylor Swift singer"), absl::nullopt); +} + +TEST_F(PageEntitiesModelExecutorImplTest, + SetsUpModelCorrectlyBasedOnFeatureParams) { + absl::optional<proto::Any> registered_model_metadata; + EXPECT_TRUE(model_observer_tracker()->DidRegisterForTarget( + proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, ®istered_model_metadata)); + EXPECT_TRUE(registered_model_metadata.has_value()); + absl::optional<proto::PageEntitiesModelMetadata> + page_entities_model_metadata = + ParsedAnyMetadata<proto::PageEntitiesModelMetadata>( + *registered_model_metadata); + EXPECT_TRUE(page_entities_model_metadata.has_value()); +} + +} // namespace +} // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/page_topics_model_executor.cc b/chromium/components/optimization_guide/core/page_topics_model_executor.cc index b1c1707bc8e..dbc9a425889 100644 --- a/chromium/components/optimization_guide/core/page_topics_model_executor.cc +++ b/chromium/components/optimization_guide/core/page_topics_model_executor.cc @@ -17,7 +17,7 @@ namespace { // The ID of the NONE category in the taxonomy. This node always exists. // Semantically, the none category is attached to data for which we can say // with certainty that no single label in the taxonomy is appropriate. -const char kNoneCategoryId[] = "-2"; +const int32_t kNoneCategoryId = -2; } // namespace @@ -52,7 +52,7 @@ void PageTopicsModelExecutor::PostprocessCategoriesToBatchAnnotationResult( const absl::optional<std::vector<tflite::task::core::Category>>& output) { DCHECK_EQ(annotation_type, AnnotationType::kPageTopics); - absl::optional<std::vector<WeightedString>> categories; + absl::optional<std::vector<WeightedIdentifier>> categories; if (output) { categories = ExtractCategoriesFromModelOutput(*output); } @@ -60,7 +60,7 @@ void PageTopicsModelExecutor::PostprocessCategoriesToBatchAnnotationResult( BatchAnnotationResult::CreatePageTopicsResult(input, categories)); } -absl::optional<std::vector<WeightedString>> +absl::optional<std::vector<WeightedIdentifier>> PageTopicsModelExecutor::ExtractCategoriesFromModelOutput( const std::vector<tflite::task::core::Category>& model_output) const { absl::optional<proto::PageTopicsModelMetadata> model_metadata = @@ -79,7 +79,7 @@ PageTopicsModelExecutor::ExtractCategoriesFromModelOutput( .category_name()) : absl::nullopt; - std::vector<std::pair<std::string, float>> category_candidates; + std::vector<std::pair<int32_t, float>> category_candidates; for (const auto& category : model_output) { if (visibility_category_name && @@ -89,8 +89,8 @@ PageTopicsModelExecutor::ExtractCategoriesFromModelOutput( // Assume everything else is for categories. int category_id; if (base::StringToInt(category.class_name, &category_id)) { - category_candidates.emplace_back(std::make_pair( - category.class_name, static_cast<float>(category.score))); + category_candidates.emplace_back( + std::make_pair(category_id, static_cast<float>(category.score))); } } @@ -103,20 +103,19 @@ PageTopicsModelExecutor::ExtractCategoriesFromModelOutput( model_metadata->output_postprocessing_params().category_params(); // Determine the categories with the highest weights. - std::sort(category_candidates.begin(), category_candidates.end(), - [](const std::pair<std::string, float>& a, - const std::pair<std::string, float>& b) { - return a.second > b.second; - }); + std::sort( + category_candidates.begin(), category_candidates.end(), + [](const std::pair<int32_t, float>& a, + const std::pair<int32_t, float>& b) { return a.second > b.second; }); size_t max_categories = static_cast<size_t>(category_params.max_categories()); float total_weight = 0.0; float sum_positive_scores = 0.0; absl::optional<std::pair<size_t, float>> none_idx_and_weight; - std::vector<std::pair<std::string, float>> categories; + std::vector<std::pair<int32_t, float>> categories; categories.reserve(max_categories); for (size_t i = 0; i < category_candidates.size() && i < max_categories; i++) { - std::pair<std::string, float> candidate = category_candidates[i]; + std::pair<int32_t, float> candidate = category_candidates[i]; categories.push_back(candidate); total_weight += candidate.second; @@ -132,7 +131,7 @@ PageTopicsModelExecutor::ExtractCategoriesFromModelOutput( if (category_params.min_category_weight() > 0) { categories.erase( std::remove_if(categories.begin(), categories.end(), - [&](const std::pair<std::string, float>& category) { + [&](const std::pair<int32_t, float>& category) { return category.second < category_params.min_category_weight(); }), @@ -159,19 +158,19 @@ PageTopicsModelExecutor::ExtractCategoriesFromModelOutput( categories.erase( std::remove_if( categories.begin(), categories.end(), - [&](const std::pair<std::string, float>& category) { + [&](const std::pair<int32_t, float>& category) { return (category.second / normalization_factor) < category_params.min_normalized_weight_within_top_n(); }), categories.end()); - std::vector<WeightedString> final_categories; + std::vector<WeightedIdentifier> final_categories; final_categories.reserve(categories.size()); for (const auto& category : categories) { // We expect the weight to be between 0 and 1. DCHECK(category.second >= 0.0 && category.second <= 1.0); final_categories.emplace_back( - WeightedString(category.first, category.second)); + WeightedIdentifier(category.first, category.second)); } DCHECK_LE(final_categories.size(), max_categories); diff --git a/chromium/components/optimization_guide/core/page_topics_model_executor.h b/chromium/components/optimization_guide/core/page_topics_model_executor.h index fa396fafc4e..30e263872e5 100644 --- a/chromium/components/optimization_guide/core/page_topics_model_executor.h +++ b/chromium/components/optimization_guide/core/page_topics_model_executor.h @@ -42,7 +42,8 @@ class PageTopicsModelExecutor : public PageContentAnnotationJobExecutor, // Extracts the scored categories from the output of the model. // Public for testing. - absl::optional<std::vector<WeightedString>> ExtractCategoriesFromModelOutput( + absl::optional<std::vector<WeightedIdentifier>> + ExtractCategoriesFromModelOutput( const std::vector<tflite::task::core::Category>& model_output) const; private: diff --git a/chromium/components/optimization_guide/core/page_topics_model_executor_unittest.cc b/chromium/components/optimization_guide/core/page_topics_model_executor_unittest.cc index 7fefca3bdc0..6ad49e17067 100644 --- a/chromium/components/optimization_guide/core/page_topics_model_executor_unittest.cc +++ b/chromium/components/optimization_guide/core/page_topics_model_executor_unittest.cc @@ -125,13 +125,13 @@ TEST_F( {"0", 0.0001}, {"1", 0.1}, {"not an int", 0.9}, {"2", 0.2}, {"3", 0.3}, }; - absl::optional<std::vector<WeightedString>> categories = + absl::optional<std::vector<WeightedIdentifier>> categories = model_executor()->ExtractCategoriesFromModelOutput(model_output); ASSERT_TRUE(categories); EXPECT_THAT(*categories, - testing::UnorderedElementsAre(WeightedString("1", 0.1), - WeightedString("2", 0.2), - WeightedString("3", 0.3))); + testing::UnorderedElementsAre(WeightedIdentifier(1, 0.1), + WeightedIdentifier(2, 0.2), + WeightedIdentifier(3, 0.3))); } TEST_F(PageTopicsModelExecutorTest, @@ -157,7 +157,7 @@ TEST_F(PageTopicsModelExecutorTest, {"1", 0.2}, }; - absl::optional<std::vector<WeightedString>> categories = + absl::optional<std::vector<WeightedIdentifier>> categories = model_executor()->ExtractCategoriesFromModelOutput(model_output); EXPECT_FALSE(categories); } @@ -183,13 +183,13 @@ TEST_F(PageTopicsModelExecutorTest, {"-2", 0.1}, {"0", 0.3}, {"1", 0.2}, {"2", 0.4}, {"3", 0.05}, }; - absl::optional<std::vector<WeightedString>> categories = + absl::optional<std::vector<WeightedIdentifier>> categories = model_executor()->ExtractCategoriesFromModelOutput(model_output); ASSERT_TRUE(categories); EXPECT_THAT(*categories, - testing::UnorderedElementsAre(WeightedString("0", 0.3), - WeightedString("1", 0.2), - WeightedString("2", 0.4))); + testing::UnorderedElementsAre(WeightedIdentifier(0, 0.3), + WeightedIdentifier(1, 0.2), + WeightedIdentifier(2, 0.4))); } TEST_F(PageTopicsModelExecutorTest, @@ -216,13 +216,13 @@ TEST_F(PageTopicsModelExecutorTest, {"3", 0.05}, }; - absl::optional<std::vector<WeightedString>> categories = + absl::optional<std::vector<WeightedIdentifier>> categories = model_executor()->ExtractCategoriesFromModelOutput(model_output); ASSERT_TRUE(categories); EXPECT_THAT(*categories, - testing::UnorderedElementsAre(WeightedString("0", 0.3), - WeightedString("1", 0.25), - WeightedString("2", 0.4))); + testing::UnorderedElementsAre(WeightedIdentifier(0, 0.3), + WeightedIdentifier(1, 0.25), + WeightedIdentifier(2, 0.4))); } TEST_F(PageTopicsModelExecutorTest, @@ -260,10 +260,10 @@ TEST_F(PageTopicsModelExecutorTest, &topics_result), AnnotationType::kPageTopics, "input", model_output); EXPECT_EQ(topics_result, BatchAnnotationResult::CreatePageTopicsResult( - "input", std::vector<WeightedString>{ - WeightedString("0", 0.3), - WeightedString("1", 0.25), - WeightedString("2", 0.4), + "input", std::vector<WeightedIdentifier>{ + WeightedIdentifier(0, 0.3), + WeightedIdentifier(1, 0.25), + WeightedIdentifier(2, 0.4), })); } diff --git a/chromium/components/optimization_guide/core/prediction_model.cc b/chromium/components/optimization_guide/core/prediction_model.cc deleted file mode 100644 index 1cf077bd71a..00000000000 --- a/chromium/components/optimization_guide/core/prediction_model.cc +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2020 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "components/optimization_guide/core/prediction_model.h" - -#include <utility> - -#include "components/optimization_guide/core/decision_tree_prediction_model.h" - -namespace optimization_guide { - -// static -std::unique_ptr<PredictionModel> PredictionModel::Create( - const proto::PredictionModel& prediction_model) { - // TODO(crbug/1009123): Add a histogram to record if the provided model is - // constructed successfully or not. - // TODO(crbug/1009123): Adding timing metrics around initialization due to - // potential validation overhead. - if (!prediction_model.has_model()) - return nullptr; - - if (!prediction_model.has_model_info()) - return nullptr; - - if (!prediction_model.model_info().has_version()) - return nullptr; - - // Enforce that only one ModelType is specified for the PredictionModel. - if (prediction_model.model_info().supported_model_types_size() != 1) { - return nullptr; - } - - // Check that the client supports this type of model and is not an unknown - // type. - if (!proto::ModelType_IsValid( - prediction_model.model_info().supported_model_types(0)) || - prediction_model.model_info().supported_model_types(0) == - proto::ModelType::MODEL_TYPE_UNKNOWN) { - return nullptr; - } - - std::unique_ptr<PredictionModel> model; - // The Decision Tree model type is currently the only supported model type. - if (prediction_model.model_info().supported_model_types(0) != - proto::ModelType::MODEL_TYPE_DECISION_TREE) { - return nullptr; - } - model = std::make_unique<DecisionTreePredictionModel>(prediction_model); - - // Any constructed model must be validated for correctness according to its - // model type before being returned. - if (!model->ValidatePredictionModel()) - return nullptr; - - return model; -} - -namespace { - -std::vector<std::string> ComputeModelFeatures( - const proto::ModelInfo& model_info) { - std::vector<std::string> features; - features.reserve(model_info.supported_host_model_features_size()); - // Insert all the host model features for the owned |model_|. - for (const auto& host_model_feature : - model_info.supported_host_model_features()) { - features.push_back(host_model_feature); - } - return features; -} - -} // namespace - -PredictionModel::PredictionModel(const proto::PredictionModel& prediction_model) - : model_(prediction_model.model()), - model_features_(ComputeModelFeatures(prediction_model.model_info())), - version_(prediction_model.model_info().version()) {} - -PredictionModel::~PredictionModel() = default; - -} // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/prediction_model.h b/chromium/components/optimization_guide/core/prediction_model.h deleted file mode 100644 index c4c3c024d0b..00000000000 --- a/chromium/components/optimization_guide/core/prediction_model.h +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2020 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_PREDICTION_MODEL_H_ -#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_PREDICTION_MODEL_H_ - -#include <stdint.h> -#include <memory> -#include <string> - -#include "base/containers/flat_map.h" -#include "base/containers/flat_set.h" -#include "components/optimization_guide/core/optimization_guide_enums.h" -#include "components/optimization_guide/proto/models.pb.h" - -namespace optimization_guide { - -// A PredictionModel supported by the optimization guide that makes an -// OptimizationTargetDecision by evaluating a prediction model. -class PredictionModel { - public: - PredictionModel(const PredictionModel&) = delete; - PredictionModel& operator=(const PredictionModel&) = delete; - - virtual ~PredictionModel(); - - // Creates an Prediction model of the correct ModelType specified in - // |prediction_model|. The validation overhead of this factory can be high and - // should should be called in the background. - static std::unique_ptr<PredictionModel> Create( - const proto::PredictionModel& prediction_model); - - // Returns the OptimizationTargetDecision by evaluating the |model_| - // using the provided |model_features|. |prediction_score| will be populated - // with the score output by the model. - virtual OptimizationTargetDecision Predict( - const base::flat_map<std::string, float>& model_features, - double* prediction_score) = 0; - - // Provide the version of the |model_| by |this|. - int64_t GetVersion() const { return version_; } - - // Provide the model features required for evaluation of the |model_| by - // |this|. - const base::flat_set<std::string>& GetModelFeatures() const { - return model_features_; - } - - protected: - explicit PredictionModel(const proto::PredictionModel& prediction_model); - - // The in-memory model used for prediction. - const proto::Model model_; - - private: - // Determines if the |model_| is complete and can be successfully evaluated by - // |this|. - virtual bool ValidatePredictionModel() const = 0; - - // The set of features required by the |model_| to be evaluated. - const base::flat_set<std::string> model_features_; - - // The version of the |model_|. - const int64_t version_; -}; - -} // namespace optimization_guide - -#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_PREDICTION_MODEL_H_ diff --git a/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.cc b/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.cc index 79282f12248..1aba12cfa54 100644 --- a/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.cc +++ b/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.cc @@ -22,7 +22,6 @@ #include "net/http/http_response_headers.h" #include "net/http/http_status_code.h" #include "net/traffic_annotation/network_traffic_annotation.h" -#include "services/network/public/cpp/network_connection_tracker.h" #include "services/network/public/cpp/resource_request.h" #include "services/network/public/cpp/shared_url_loader_factory.h" #include "services/network/public/cpp/simple_url_loader.h" @@ -32,16 +31,14 @@ namespace optimization_guide { PredictionModelFetcherImpl::PredictionModelFetcherImpl( scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, - const GURL& optimization_guide_service_get_models_url, - network::NetworkConnectionTracker* network_connection_tracker) + const GURL& optimization_guide_service_get_models_url) : optimization_guide_service_get_models_url_( net::AppendOrReplaceQueryParameter( optimization_guide_service_get_models_url, "key", optimization_guide::features:: GetOptimizationGuideServiceAPIKey())), - url_loader_factory_(url_loader_factory), - network_connection_tracker_(network_connection_tracker) { + url_loader_factory_(url_loader_factory) { CHECK(optimization_guide_service_get_models_url_.SchemeIs(url::kHttpsScheme)); } @@ -55,11 +52,6 @@ bool PredictionModelFetcherImpl::FetchOptimizationGuideServiceModels( ModelsFetchedCallback models_fetched_callback) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - if (network_connection_tracker_->IsOffline()) { - std::move(models_fetched_callback).Run(absl::nullopt); - return false; - } - if (url_loader_) return false; diff --git a/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.h b/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.h index 2d2557e8c15..969af65685a 100644 --- a/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.h +++ b/chromium/components/optimization_guide/core/prediction_model_fetcher_impl.h @@ -19,7 +19,6 @@ #include "url/gurl.h" namespace network { -class NetworkConnectionTracker; class SharedURLLoaderFactory; class SimpleURLLoader; } // namespace network @@ -34,8 +33,7 @@ class PredictionModelFetcherImpl : public PredictionModelFetcher { public: PredictionModelFetcherImpl( scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, - const GURL& optimization_guide_service_get_models_url, - network::NetworkConnectionTracker* network_connection_tracker); + const GURL& optimization_guide_service_get_models_url); PredictionModelFetcherImpl(const PredictionModelFetcherImpl&) = delete; PredictionModelFetcherImpl& operator=(const PredictionModelFetcherImpl&) = @@ -82,10 +80,6 @@ class PredictionModelFetcherImpl : public PredictionModelFetcher { // Used for creating a |url_loader_| when needed for request hints. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_; - // Listens to changes around the network connection. Not owned. Guaranteed to - // outlive |this|. - raw_ptr<network::NetworkConnectionTracker> network_connection_tracker_; - SEQUENCE_CHECKER(sequence_checker_); }; diff --git a/chromium/components/optimization_guide/core/prediction_model_fetcher_unittest.cc b/chromium/components/optimization_guide/core/prediction_model_fetcher_unittest.cc index 1b59d314c01..a4ca4b85255 100644 --- a/chromium/components/optimization_guide/core/prediction_model_fetcher_unittest.cc +++ b/chromium/components/optimization_guide/core/prediction_model_fetcher_unittest.cc @@ -21,7 +21,6 @@ #include "net/base/url_util.h" #include "services/network/public/cpp/shared_url_loader_factory.h" #include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h" -#include "services/network/test/test_network_connection_tracker.h" #include "services/network/test/test_url_loader_factory.h" #include "testing/gtest/include/gtest/gtest.h" #include "third_party/abseil-cpp/absl/types/optional.h" @@ -37,11 +36,9 @@ class PredictionModelFetcherTest : public testing::Test { : task_environment_(base::test::TaskEnvironment::MainThreadType::UI), shared_url_loader_factory_( base::MakeRefCounted<network::WeakWrapperSharedURLLoaderFactory>( - &test_url_loader_factory_)), - network_tracker_(network::TestNetworkConnectionTracker::GetInstance()) { + &test_url_loader_factory_)) { prediction_model_fetcher_ = std::make_unique<PredictionModelFetcherImpl>( - shared_url_loader_factory_, GURL(optimization_guide_service_url), - network_tracker_); + shared_url_loader_factory_, GURL(optimization_guide_service_url)); } PredictionModelFetcherTest(const PredictionModelFetcherTest&) = delete; @@ -58,16 +55,6 @@ class PredictionModelFetcherTest : public testing::Test { bool models_fetched() { return models_fetched_; } - void SetConnectionOffline() { - network_tracker_->SetConnectionType( - network::mojom::ConnectionType::CONNECTION_NONE); - } - - void SetConnectionOnline() { - network_tracker_->SetConnectionType( - network::mojom::ConnectionType::CONNECTION_4G); - } - protected: bool FetchModels(const std::vector<proto::ModelInfo> models_request_info, const std::vector<proto::FieldTrial>& active_field_trials, @@ -116,7 +103,6 @@ class PredictionModelFetcherTest : public testing::Test { scoped_refptr<network::SharedURLLoaderFactory> shared_url_loader_factory_; network::TestURLLoaderFactory test_url_loader_factory_; - raw_ptr<network::TestNetworkConnectionTracker> network_tracker_; }; TEST_F(PredictionModelFetcherTest, FetchOptimizationGuideServiceModels) { @@ -173,26 +159,6 @@ TEST_F(PredictionModelFetcherTest, FetchReturnBadResponse) { EXPECT_FALSE(models_fetched()); } -TEST_F(PredictionModelFetcherTest, FetchAttemptWhenNetworkOffline) { - SetConnectionOffline(); - std::string response_content; - proto::ModelInfo model_info; - model_info.set_optimization_target( - proto::OptimizationTarget::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); - EXPECT_FALSE(FetchModels({model_info}, /*active_field_trials=*/{}, - proto::RequestContext::CONTEXT_BATCH_UPDATE_MODELS, - "en-US")); - EXPECT_FALSE(models_fetched()); - - SetConnectionOnline(); - EXPECT_TRUE(FetchModels({model_info}, /*active_field_trials=*/{}, - proto::RequestContext::CONTEXT_BATCH_UPDATE_MODELS, - "en-US")); - VerifyHasPendingFetchRequests(); - EXPECT_TRUE(SimulateResponse(response_content, net::HTTP_OK)); - EXPECT_TRUE(models_fetched()); -} - TEST_F(PredictionModelFetcherTest, EmptyModelInfo) { base::HistogramTester histogram_tester; std::string response_content; diff --git a/chromium/components/optimization_guide/core/prediction_model_unittest.cc b/chromium/components/optimization_guide/core/prediction_model_unittest.cc deleted file mode 100644 index c928abdd5cf..00000000000 --- a/chromium/components/optimization_guide/core/prediction_model_unittest.cc +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2020 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "components/optimization_guide/core/prediction_model.h" - -#include <utility> - -#include "components/optimization_guide/proto/models.pb.h" -#include "testing/gtest/include/gtest/gtest.h" - -namespace optimization_guide { - -TEST(PredictionModelTest, ValidPredictionModel) { - proto::PredictionModel prediction_model; - prediction_model.mutable_model()->mutable_threshold()->set_value(5.0); - - proto::DecisionTree decision_tree_model = proto::DecisionTree(); - decision_tree_model.set_weight(2.0); - - proto::TreeNode* tree_node = decision_tree_model.add_nodes(); - tree_node->mutable_node_id()->set_value(0); - tree_node->mutable_binary_node()->mutable_left_child_id()->set_value(1); - tree_node->mutable_binary_node()->mutable_right_child_id()->set_value(2); - tree_node->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->mutable_feature_id() - ->mutable_id() - ->set_value("agg1"); - tree_node->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->set_type(proto::InequalityTest::LESS_OR_EQUAL); - tree_node->mutable_binary_node() - ->mutable_inequality_left_child_test() - ->mutable_threshold() - ->set_float_value(1.0); - - tree_node = decision_tree_model.add_nodes(); - tree_node->mutable_node_id()->set_value(1); - tree_node->mutable_leaf()->mutable_vector()->add_value()->set_double_value( - 2.); - - tree_node = decision_tree_model.add_nodes(); - tree_node->mutable_node_id()->set_value(2); - tree_node->mutable_leaf()->mutable_vector()->add_value()->set_double_value( - 4.); - - *prediction_model.mutable_model()->mutable_decision_tree() = - decision_tree_model; - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_host_model_features("agg1"); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(prediction_model); - - EXPECT_EQ(1, model->GetVersion()); - EXPECT_EQ(1u, model->GetModelFeatures().size()); - EXPECT_TRUE(model->GetModelFeatures().count("agg1")); -} - -TEST(PredictionModelTest, NoModel) { - proto::PredictionModel prediction_model; - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(prediction_model); - EXPECT_FALSE(model); -} - -TEST(PredictionModelTest, NoModelVersion) { - proto::PredictionModel prediction_model; - - proto::DecisionTree* decision_tree_model = - prediction_model.mutable_model()->mutable_decision_tree(); - decision_tree_model->set_weight(2.0); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(prediction_model); - EXPECT_FALSE(model); -} - -TEST(PredictionModelTest, NoModelType) { - proto::PredictionModel prediction_model; - - proto::DecisionTree* decision_tree_model = - prediction_model.mutable_model()->mutable_decision_tree(); - decision_tree_model->set_weight(2.0); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(std::move(prediction_model)); - EXPECT_FALSE(model); -} - -TEST(PredictionModelTest, UnknownModelType) { - proto::PredictionModel prediction_model; - - proto::DecisionTree* decision_tree_model = - prediction_model.mutable_model()->mutable_decision_tree(); - decision_tree_model->set_weight(2.0); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types(proto::ModelType::MODEL_TYPE_UNKNOWN); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(prediction_model); - EXPECT_FALSE(model); -} - -TEST(PredictionModelTest, MultipleModelTypes) { - proto::PredictionModel prediction_model; - - proto::DecisionTree* decision_tree_model = - prediction_model.mutable_model()->mutable_decision_tree(); - decision_tree_model->set_weight(2.0); - - proto::ModelInfo* model_info = prediction_model.mutable_model_info(); - model_info->set_version(1); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); - model_info->add_supported_model_types(proto::ModelType::MODEL_TYPE_UNKNOWN); - - std::unique_ptr<PredictionModel> model = - PredictionModel::Create(prediction_model); - EXPECT_FALSE(model); -} - -} // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/store_update_data.cc b/chromium/components/optimization_guide/core/store_update_data.cc index 0afd12d3e80..d5c636ef16b 100644 --- a/chromium/components/optimization_guide/core/store_update_data.cc +++ b/chromium/components/optimization_guide/core/store_update_data.cc @@ -38,40 +38,6 @@ StoreUpdateData::CreatePredictionModelStoreUpdateData(base::Time expiry_time) { return base::WrapUnique<StoreUpdateData>(new StoreUpdateData(expiry_time)); } -// static -std::unique_ptr<StoreUpdateData> -StoreUpdateData::CreateHostModelFeaturesStoreUpdateData( - base::Time host_model_features_update_time, - base::Time expiry_time) { - std::unique_ptr<StoreUpdateData> host_model_features_update_data( - new StoreUpdateData(host_model_features_update_time, expiry_time)); - return host_model_features_update_data; -} - -StoreUpdateData::StoreUpdateData(base::Time host_model_features_update_time, - base::Time expiry_time) - : update_time_(host_model_features_update_time), - expiry_time_(expiry_time), - entries_to_save_(std::make_unique<EntryVector>()) { - entry_key_prefix_ = - OptimizationGuideStore::GetHostModelFeaturesEntryKeyPrefix(); - proto::StoreEntry metadata_host_model_features_entry; - metadata_host_model_features_entry.set_entry_type( - static_cast<proto::StoreEntryType>( - OptimizationGuideStore::StoreEntryType::kMetadata)); - metadata_host_model_features_entry.set_update_time_secs( - host_model_features_update_time.ToDeltaSinceWindowsEpoch().InSeconds()); - entries_to_save_->emplace_back( - OptimizationGuideStore::GetMetadataTypeEntryKey( - OptimizationGuideStore::MetadataType::kHostModelFeatures), - std::move(metadata_host_model_features_entry)); - - // |this| may be modified on another thread after construction but all - // future modifications, from that call forward, must be made on the same - // thread. - DETACH_FROM_SEQUENCE(sequence_checker_); -} - StoreUpdateData::StoreUpdateData(base::Time expiry_time) : expiry_time_(expiry_time), entries_to_save_(std::make_unique<EntryVector>()) { @@ -161,28 +127,6 @@ void StoreUpdateData::MoveHintIntoUpdateData(proto::Hint&& hint) { std::move(entry_proto)); } -void StoreUpdateData::CopyHostModelFeaturesIntoUpdateData( - const proto::HostModelFeatures& host_model_features) { - // All future modifications must be made by the same thread. Note, |this| may - // have been constructed on another thread. - DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - DCHECK(!entry_key_prefix_.empty()); - DCHECK(expiry_time_); - - // To avoid any unnecessary copying, the host model feature data is moved into - // proto::StoreEntry. - OptimizationGuideStore::EntryKey host_model_features_entry_key = - entry_key_prefix_ + host_model_features.host(); - proto::StoreEntry entry_proto; - entry_proto.set_entry_type(static_cast<proto::StoreEntryType>( - OptimizationGuideStore::StoreEntryType::kHostModelFeatures)); - entry_proto.set_expiry_time_secs( - expiry_time_->ToDeltaSinceWindowsEpoch().InSeconds()); - entry_proto.mutable_host_model_features()->CopyFrom(host_model_features); - entries_to_save_->emplace_back(std::move(host_model_features_entry_key), - std::move(entry_proto)); -} - void StoreUpdateData::CopyPredictionModelIntoUpdateData( const proto::PredictionModel& prediction_model) { // All future modifications must be made by the same thread. Note, |this| may @@ -200,8 +144,19 @@ void StoreUpdateData::CopyPredictionModelIntoUpdateData( proto::StoreEntry entry_proto; entry_proto.set_entry_type(static_cast<proto::StoreEntryType>( OptimizationGuideStore::StoreEntryType::kPredictionModel)); + + base::TimeDelta expiry_duration; + if (prediction_model.model_info().has_valid_duration()) { + expiry_duration = + base::Seconds(prediction_model.model_info().valid_duration().seconds()); + } else { + expiry_duration = features::StoredFetchedHintsFreshnessDuration(); + } + expiry_time_ = base::Time::Now() + expiry_duration; entry_proto.set_expiry_time_secs( - expiry_time_->ToDeltaSinceWindowsEpoch().InSeconds()); + expiry_time_.value().ToDeltaSinceWindowsEpoch().InSeconds()); + entry_proto.set_keep_beyond_valid_duration( + prediction_model.model_info().keep_beyond_valid_duration()); entry_proto.mutable_prediction_model()->CopyFrom(prediction_model); entries_to_save_->emplace_back(std::move(prediction_model_entry_key), std::move(entry_proto)); diff --git a/chromium/components/optimization_guide/core/store_update_data.h b/chromium/components/optimization_guide/core/store_update_data.h index 68a098a8ec3..05dbc9a872e 100644 --- a/chromium/components/optimization_guide/core/store_update_data.h +++ b/chromium/components/optimization_guide/core/store_update_data.h @@ -16,7 +16,6 @@ namespace optimization_guide { namespace proto { class Hint; -class HostModelFeatures; class PredictionModel; class StoreEntry; } // namespace proto @@ -24,8 +23,7 @@ class StoreEntry; using EntryVector = leveldb_proto::ProtoDatabase<proto::StoreEntry>::KeyEntryVector; -// Holds hint, prediction model, or host model features data for updating the -// OptimizationGuideStore. +// Holds hint or prediction model data for updating the OptimizationGuideStore. class StoreUpdateData { public: StoreUpdateData(const StoreUpdateData&) = delete; @@ -45,12 +43,6 @@ class StoreUpdateData { static std::unique_ptr<StoreUpdateData> CreatePredictionModelStoreUpdateData( base::Time expiry_time); - // Creates an update data object for a host model features update. - static std::unique_ptr<StoreUpdateData> - CreateHostModelFeaturesStoreUpdateData( - base::Time host_model_features_update_time, - base::Time expiry_time); - // Returns the component version of a component hint update. const absl::optional<base::Version> component_version() const { return component_version_; @@ -66,10 +58,6 @@ class StoreUpdateData { // called, |hint| is no longer valid. void MoveHintIntoUpdateData(proto::Hint&& hint); - // Copies |host_model_features| into this update data. - void CopyHostModelFeaturesIntoUpdateData( - const proto::HostModelFeatures& host_model_features); - // Copies |prediction_model| into this update data. void CopyPredictionModelIntoUpdateData( const proto::PredictionModel& prediction_model); @@ -81,8 +69,6 @@ class StoreUpdateData { StoreUpdateData(absl::optional<base::Version> component_version, absl::optional<base::Time> fetch_update_time, absl::optional<base::Time> expiry_time); - StoreUpdateData(base::Time host_model_features_update_time, - base::Time expiry_time); explicit StoreUpdateData(base::Time expiry_time); // The component version of the update data for a component update. diff --git a/chromium/components/optimization_guide/core/store_update_data_unittest.cc b/chromium/components/optimization_guide/core/store_update_data_unittest.cc index 785033692bd..e36087ea00c 100644 --- a/chromium/components/optimization_guide/core/store_update_data_unittest.cc +++ b/chromium/components/optimization_guide/core/store_update_data_unittest.cc @@ -120,11 +120,13 @@ TEST(StoreUpdateDataTest, BuildPredictionModelUpdateData) { model_info->set_version(1); model_info->set_optimization_target( proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); - model_info->add_supported_model_types( - proto::ModelType::MODEL_TYPE_DECISION_TREE); + model_info->add_supported_model_engine_versions( + proto::ModelEngineVersion::MODEL_ENGINE_VERSION_DECISION_TREE); + model_info->set_keep_beyond_valid_duration(false); - base::Time expected_expiry_time = - base::Time::Now() + features::StoredModelsInactiveDuration(); + model_info->mutable_valid_duration()->set_seconds(3); + + base::Time expected_expiry_time = base::Time::Now() + base::Seconds(3); std::unique_ptr<StoreUpdateData> prediction_model_update = StoreUpdateData::CreatePredictionModelStoreUpdateData( expected_expiry_time); @@ -143,39 +145,14 @@ TEST(StoreUpdateDataTest, BuildPredictionModelUpdateData) { found_prediction_model_entry = true; EXPECT_EQ(expected_expiry_time.ToDeltaSinceWindowsEpoch().InSeconds(), store_entry.expiry_time_secs()); + EXPECT_EQ(store_entry.keep_beyond_valid_duration(), + model_info->keep_beyond_valid_duration()); break; } } EXPECT_TRUE(found_prediction_model_entry); } -TEST(StoreUpdateDataTest, BuildHostModelFeaturesUpdateData) { - // Verify creating a Prediction Model update data. - base::Time host_model_features_update_time = base::Time::Now(); - - proto::HostModelFeatures host_model_features; - host_model_features.set_host("foo.com"); - proto::ModelFeature* model_feature = host_model_features.add_model_features(); - model_feature->set_feature_name("host_feat1"); - model_feature->set_double_value(2.0); - - std::unique_ptr<StoreUpdateData> host_model_features_update = - StoreUpdateData::CreateHostModelFeaturesStoreUpdateData( - host_model_features_update_time, - host_model_features_update_time + - optimization_guide::features:: - StoredHostModelFeaturesFreshnessDuration()); - host_model_features_update->CopyHostModelFeaturesIntoUpdateData( - std::move(host_model_features)); - EXPECT_FALSE(host_model_features_update->component_version().has_value()); - EXPECT_TRUE(host_model_features_update->update_time().has_value()); - EXPECT_EQ(host_model_features_update_time, - *host_model_features_update->update_time()); - // Verify there are 2 store entries, 1 for the metadata entry and 1 for the - // added host model features entry. - EXPECT_EQ(2ul, host_model_features_update->TakeUpdateEntries()->size()); -} - } // namespace } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/test_model_executor.cc b/chromium/components/optimization_guide/core/test_model_executor.cc index dbd8c4072b5..07c75aafcd7 100644 --- a/chromium/components/optimization_guide/core/test_model_executor.cc +++ b/chromium/components/optimization_guide/core/test_model_executor.cc @@ -7,13 +7,13 @@ namespace optimization_guide { void TestModelExecutor::SendForExecution( - ExecutionCallback ui_callback_on_complete, + ExecutionCallback callback_on_complete, base::TimeTicks start_time, const std::vector<float>& args) { std::vector<float> results = std::vector<float>(); for (auto& arg : args) results.push_back(arg); - std::move(ui_callback_on_complete).Run(std::move(results)); + std::move(callback_on_complete).Run(std::move(results)); } } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/test_model_executor.h b/chromium/components/optimization_guide/core/test_model_executor.h index 42dd3cfcff1..987942810e8 100644 --- a/chromium/components/optimization_guide/core/test_model_executor.h +++ b/chromium/components/optimization_guide/core/test_model_executor.h @@ -6,7 +6,6 @@ #define COMPONENTS_OPTIMIZATION_GUIDE_CORE_TEST_MODEL_EXECUTOR_H_ #include "components/optimization_guide/core/model_executor.h" -#include "third_party/abseil-cpp/absl/status/status.h" namespace optimization_guide { @@ -16,7 +15,7 @@ class TestModelExecutor TestModelExecutor() = default; ~TestModelExecutor() override = default; - void InitializeAndMoveToBackgroundThread( + void InitializeAndMoveToExecutionThread( proto::OptimizationTarget, scoped_refptr<base::SequencedTaskRunner>, scoped_refptr<base::SequencedTaskRunner>) override {} @@ -29,7 +28,7 @@ class TestModelExecutor using ExecutionCallback = base::OnceCallback<void(const absl::optional<std::vector<float>>&)>; - void SendForExecution(ExecutionCallback ui_callback_on_complete, + void SendForExecution(ExecutionCallback callback_on_complete, base::TimeTicks start_time, const std::vector<float>& args) override; }; diff --git a/chromium/components/optimization_guide/core/test_model_info_builder.cc b/chromium/components/optimization_guide/core/test_model_info_builder.cc index 9efa9d4892d..517462239da 100644 --- a/chromium/components/optimization_guide/core/test_model_info_builder.cc +++ b/chromium/components/optimization_guide/core/test_model_info_builder.cc @@ -4,8 +4,8 @@ #include "components/optimization_guide/core/test_model_info_builder.h" +#include "components/optimization_guide/core/model_util.h" #include "components/optimization_guide/core/optimization_guide_test_util.h" -#include "components/optimization_guide/core/optimization_guide_util.h" namespace optimization_guide { diff --git a/chromium/components/optimization_guide/core/test_tflite_model_executor.cc b/chromium/components/optimization_guide/core/test_tflite_model_executor.cc index 9b7ea803949..a310697de19 100644 --- a/chromium/components/optimization_guide/core/test_tflite_model_executor.cc +++ b/chromium/components/optimization_guide/core/test_tflite_model_executor.cc @@ -8,17 +8,19 @@ namespace optimization_guide { -absl::Status TestTFLiteModelExecutor::Preprocess( +bool TestTFLiteModelExecutor::Preprocess( const std::vector<TfLiteTensor*>& input_tensors, const std::vector<float>& input) { - tflite::task::core::PopulateTensor<float>(input, input_tensors[0]); - return absl::OkStatus(); + return tflite::task::core::PopulateTensor<float>(input, input_tensors[0]) + .ok(); } -std::vector<float> TestTFLiteModelExecutor::Postprocess( +absl::optional<std::vector<float>> TestTFLiteModelExecutor::Postprocess( const std::vector<const TfLiteTensor*>& output_tensors) { std::vector<float> data; - tflite::task::core::PopulateVector<float>(output_tensors[0], &data); + absl::Status status = + tflite::task::core::PopulateVector<float>(output_tensors[0], &data); + DCHECK(status.ok()); return data; } diff --git a/chromium/components/optimization_guide/core/test_tflite_model_executor.h b/chromium/components/optimization_guide/core/test_tflite_model_executor.h index ac33fd33603..39c99e4582a 100644 --- a/chromium/components/optimization_guide/core/test_tflite_model_executor.h +++ b/chromium/components/optimization_guide/core/test_tflite_model_executor.h @@ -16,10 +16,10 @@ class TestTFLiteModelExecutor ~TestTFLiteModelExecutor() override = default; protected: - absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, - const std::vector<float>& input) override; + bool Preprocess(const std::vector<TfLiteTensor*>& input_tensors, + const std::vector<float>& input) override; - std::vector<float> Postprocess( + absl::optional<std::vector<float>> Postprocess( const std::vector<const TfLiteTensor*>& output_tensors) override; }; diff --git a/chromium/components/optimization_guide/core/tflite_model_executor.h b/chromium/components/optimization_guide/core/tflite_model_executor.h index fe615dc6ee9..7b733101746 100644 --- a/chromium/components/optimization_guide/core/tflite_model_executor.h +++ b/chromium/components/optimization_guide/core/tflite_model_executor.h @@ -17,9 +17,9 @@ #include "base/time/time.h" #include "base/trace_event/trace_event.h" #include "components/optimization_guide/core/execution_status.h" +#include "components/optimization_guide/core/model_enums.h" #include "components/optimization_guide/core/model_executor.h" -#include "components/optimization_guide/core/optimization_guide_enums.h" -#include "components/optimization_guide/core/optimization_guide_util.h" +#include "components/optimization_guide/core/model_util.h" #include "third_party/abseil-cpp/absl/types/optional.h" #include "third_party/tflite/src/tensorflow/lite/c/common.h" #include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h" @@ -34,8 +34,7 @@ class ScopedExecutionStatusResultRecorder { public: explicit ScopedExecutionStatusResultRecorder( proto::OptimizationTarget optimization_target) - : optimization_target_(optimization_target), - start_time_(base::TimeTicks::Now()) {} + : optimization_target_(optimization_target) {} ~ScopedExecutionStatusResultRecorder() { base::UmaHistogramEnumeration( @@ -43,12 +42,6 @@ class ScopedExecutionStatusResultRecorder { optimization_guide::GetStringNameForOptimizationTarget( optimization_target_), status_); - - base::UmaHistogramTimes( - "OptimizationGuide.ModelExecutor.ModelLoadingDuration." + - optimization_guide::GetStringNameForOptimizationTarget( - optimization_target_), - base::TimeTicks::Now() - start_time_); } ExecutionStatus* mutable_status() { return &status_; } @@ -61,9 +54,6 @@ class ScopedExecutionStatusResultRecorder { // The OptimizationTarget of the model being executed. const proto::OptimizationTarget optimization_target_; - // The time at which this instance was constructed. - const base::TimeTicks start_time_; - ExecutionStatus status_ = ExecutionStatus::kUnknown; }; @@ -88,26 +78,26 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> { } // Should be called on the same sequence as the ctor, but once called |this| - // must only be used from a background thread/sequence. - void InitializeAndMoveToBackgroundThread( + // must only be used from the |execution_task_runner| thread/sequence. + void InitializeAndMoveToExecutionThread( proto::OptimizationTarget optimization_target, - scoped_refptr<base::SequencedTaskRunner> background_task_runner, + scoped_refptr<base::SequencedTaskRunner> execution_task_runner, scoped_refptr<base::SequencedTaskRunner> reply_task_runner) override { - DCHECK(!background_task_runner_); + DCHECK(!execution_task_runner_); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK_NE(optimization_target, proto::OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN); DETACH_FROM_SEQUENCE(sequence_checker_); optimization_target_ = optimization_target; - background_task_runner_ = background_task_runner; + execution_task_runner_ = execution_task_runner; reply_task_runner_ = reply_task_runner; } // Called when a model file is available to load. Depending on feature flags, // the model may or may not be immediately loaded. void UpdateModelFile(const base::FilePath& file_path) override { - DCHECK(background_task_runner_->RunsTasksInCurrentSequence()); + DCHECK(execution_task_runner_->RunsTasksInCurrentSequence()); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); UnloadModel(); @@ -130,7 +120,7 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> { // called. False is the default behavior (see class comment). void SetShouldUnloadModelOnComplete( bool should_unload_model_on_complete) override { - DCHECK(background_task_runner_->RunsTasksInCurrentSequence()); + DCHECK(execution_task_runner_->RunsTasksInCurrentSequence()); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); should_unload_model_on_complete_ = should_unload_model_on_complete; } @@ -142,21 +132,21 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> { "OptimizationTarget", optimization_guide::GetStringNameForOptimizationTarget( optimization_target_)); - DCHECK(background_task_runner_->RunsTasksInCurrentSequence()); + DCHECK(execution_task_runner_->RunsTasksInCurrentSequence()); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); loaded_model_.reset(); model_fb_.reset(); } - // Starts the execution of the model. When complete, |ui_callback_on_complete| - // will be run on the UI thread with the output of the model. + // Starts the execution of the model. When complete, |callback_on_complete| + // will be run via |reply_task_runner_| with the output of the model. using ExecutionCallback = base::OnceCallback<void(const absl::optional<OutputType>&)>; - void SendForExecution(ExecutionCallback ui_callback_on_complete, + void SendForExecution(ExecutionCallback callback_on_complete, base::TimeTicks start_time, InputTypes... args) override { - DCHECK(background_task_runner_->RunsTasksInCurrentSequence()); + DCHECK(execution_task_runner_->RunsTasksInCurrentSequence()); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK(reply_task_runner_); @@ -175,7 +165,7 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> { if (!loaded_model_ && !LoadModelFile(status_recorder.mutable_status())) { reply_task_runner_->PostTask( FROM_HERE, - base::BindOnce(std::move(ui_callback_on_complete), absl::nullopt)); + base::BindOnce(std::move(callback_on_complete), absl::nullopt)); // Some error status is expected, and derived classes should have set the // status. DCHECK_NE(status_recorder.status(), ExecutionStatus::kUnknown); @@ -213,19 +203,13 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> { base::TimeTicks::Now() - execute_start_time); } - DCHECK(ui_callback_on_complete); + DCHECK(callback_on_complete); reply_task_runner_->PostTask( - FROM_HERE, base::BindOnce(std::move(ui_callback_on_complete), output)); + FROM_HERE, base::BindOnce(std::move(callback_on_complete), output)); OnExecutionComplete(); } - // IMPORTANT: These WeakPointers must only be dereferenced on the background - // thread. - base::WeakPtr<TFLiteModelExecutor> GetBackgroundWeakPtr() { - return background_weak_ptr_factory_.GetWeakPtr(); - } - TFLiteModelExecutor(const TFLiteModelExecutor&) = delete; TFLiteModelExecutor& operator=(const TFLiteModelExecutor&) = delete; @@ -252,7 +236,7 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> { "OptimizationTarget", optimization_guide::GetStringNameForOptimizationTarget( optimization_target_)); - DCHECK(background_task_runner_->RunsTasksInCurrentSequence()); + DCHECK(execution_task_runner_->RunsTasksInCurrentSequence()); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); UnloadModel(); @@ -267,6 +251,8 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> { return false; } + base::TimeTicks loading_start_time = base::TimeTicks::Now(); + std::unique_ptr<base::MemoryMappedFile> model_fb = std::make_unique<base::MemoryMappedFile>(); if (!model_fb->Initialize(*model_file_path_)) { @@ -277,11 +263,28 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> { loaded_model_ = BuildModelExecutionTask(model_fb_.get(), out_status); + if (!!loaded_model_) { + // We only want to record successful loading times. + base::UmaHistogramTimes( + "OptimizationGuide.ModelExecutor.ModelLoadingDuration2." + + optimization_guide::GetStringNameForOptimizationTarget( + optimization_target_), + base::TimeTicks::Now() - loading_start_time); + } + + // Local histogram used in integration testing. + base::BooleanHistogram::FactoryGet( + "OptimizationGuide.ModelExecutor.ModelLoadedSuccessfully." + + optimization_guide::GetStringNameForOptimizationTarget( + optimization_target_), + base::Histogram::kNoFlags) + ->Add(!!loaded_model_); + return !!loaded_model_; } void OnExecutionComplete() { - DCHECK(background_task_runner_->RunsTasksInCurrentSequence()); + DCHECK(execution_task_runner_->RunsTasksInCurrentSequence()); DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (should_unload_model_on_complete_) { UnloadModel(); @@ -293,7 +296,7 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> { bool should_unload_model_on_complete_ = true; - scoped_refptr<base::SequencedTaskRunner> background_task_runner_; + scoped_refptr<base::SequencedTaskRunner> execution_task_runner_; scoped_refptr<base::SequencedTaskRunner> reply_task_runner_; @@ -320,8 +323,6 @@ class TFLiteModelExecutor : public ModelExecutor<OutputType, InputTypes...> { GUARDED_BY_CONTEXT(sequence_checker_); SEQUENCE_CHECKER(sequence_checker_); - - base::WeakPtrFactory<TFLiteModelExecutor> background_weak_ptr_factory_{this}; }; } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/core/tflite_model_executor_unittest.cc b/chromium/components/optimization_guide/core/tflite_model_executor_unittest.cc index 1c519e51639..c0f5f76f91e 100644 --- a/chromium/components/optimization_guide/core/tflite_model_executor_unittest.cc +++ b/chromium/components/optimization_guide/core/tflite_model_executor_unittest.cc @@ -187,6 +187,11 @@ TEST_F(TFLiteModelExecutorTest, ExecuteWithLoadedModel) { optimization_guide::GetStringNameForOptimizationTarget( proto::OptimizationTarget::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD), ExecutionStatus::kSuccess, 1); + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.ModelExecutor.ModelLoadedSuccessfully." + + optimization_guide::GetStringNameForOptimizationTarget( + proto::OptimizationTarget::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD), + true, 1); } TEST_F(TFLiteModelExecutorTest, ExecuteTwiceWithLoadedModel) { @@ -228,6 +233,11 @@ TEST_F(TFLiteModelExecutorTest, ExecuteTwiceWithLoadedModel) { optimization_guide::GetStringNameForOptimizationTarget( proto::OptimizationTarget::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD), ExecutionStatus::kSuccess, 1); + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.ModelExecutor.ModelLoadedSuccessfully." + + optimization_guide::GetStringNameForOptimizationTarget( + proto::OptimizationTarget::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD), + true, 1); // Second run. run_loop = std::make_unique<base::RunLoop>(); @@ -253,6 +263,11 @@ TEST_F(TFLiteModelExecutorTest, ExecuteTwiceWithLoadedModel) { optimization_guide::GetStringNameForOptimizationTarget( proto::OptimizationTarget::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD), ExecutionStatus::kSuccess, 2); + histogram_tester.ExpectUniqueSample( + "OptimizationGuide.ModelExecutor.ModelLoadedSuccessfully." + + optimization_guide::GetStringNameForOptimizationTarget( + proto::OptimizationTarget::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD), + true, 2); histogram_tester.ExpectTotalCount( "OptimizationGuide.ModelExecutor.TaskExecutionLatency." + diff --git a/chromium/components/optimization_guide/core/tflite_op_resolver.cc b/chromium/components/optimization_guide/core/tflite_op_resolver.cc index 1cd89433031..8674eb8d700 100644 --- a/chromium/components/optimization_guide/core/tflite_op_resolver.cc +++ b/chromium/components/optimization_guide/core/tflite_op_resolver.cc @@ -371,6 +371,10 @@ TFLiteOpResolver::TFLiteOpResolver() { tflite::ops::builtin::Register_BATCH_MATMUL(), /* min_version = */ 1, /* max_version = */ 4); + AddBuiltin(tflite::BuiltinOperator_GELU, + tflite::ops::builtin::Register_GELU(), + /* min_version = */ 1, + /* max_version = */ 2); } } // namespace optimization_guide diff --git a/chromium/components/optimization_guide/features.gni b/chromium/components/optimization_guide/features.gni index 26d60df3f0d..2ed17e3165e 100644 --- a/chromium/components/optimization_guide/features.gni +++ b/chromium/components/optimization_guide/features.gni @@ -13,6 +13,18 @@ declare_args() { # You can set the variable 'build_with_internal_optimization_guide' to true # even in a developer build in args.gn. Setting this variable explicitly to true will # cause your build to fail if the internal files are missing. + # + # If changing the value of this, you MUST also update the following files depending on the + # platform: + # ChromeOS: //lib/chrome_util.py in the Chromite repo (ex: https://crrev.com/c/3437291) + # Linux: Internal archive files. //chrome/installer/linux/common/installer.include handles the + # relevant files not being present. + # Mac: //chrome/installer/mac/signing/parts.py + # Windows: //chrome/installer/mini_installer/chrome.release and internal archive files + # + # The library this pulls in depends on open-source LevelDB which is not supported for Fuchsia. + # Android and iOS should just work but are not included in the set we release for, so we do + # not needlessly increase the binary. build_with_internal_optimization_guide = - is_chrome_branded && !is_android && !is_ios + is_chrome_branded && !is_android && !is_ios && !is_fuchsia } diff --git a/chromium/components/optimization_guide/optimization_guide_internals/resources/BUILD.gn b/chromium/components/optimization_guide/optimization_guide_internals/resources/BUILD.gn new file mode 100644 index 00000000000..22359bed596 --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/resources/BUILD.gn @@ -0,0 +1,71 @@ +# Copyright 2022 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +import("//tools/grit/grit_rule.gni") +import("//tools/polymer/html_to_js.gni") +import("//tools/typescript/ts_library.gni") +import("//ui/webui/resources/tools/generate_grd.gni") + +tsc_folder = "tsc" + +grit("resources") { + # These arguments are needed since the grd is generated at build time. + enable_input_discovery_for_gn_analyze = false + source = "$target_gen_dir/resources.grd" + deps = [ ":build_grd" ] + + outputs = [ + "grit/optimization_guide_internals_resources.h", + "grit/optimization_guide_internals_resources_map.cc", + "grit/optimization_guide_internals_resources_map.h", + "optimization_guide_internals_resources.pak", + ] + + output_dir = "$root_gen_dir/components" +} + +generate_grd("build_grd") { + grd_prefix = "optimization_guide_internals" + out_grd = "$target_gen_dir/resources.grd" + deps = [ ":build_ts" ] + manifest_files = [ "$target_gen_dir/tsconfig.manifest" ] + input_files = [ "optimization_guide_internals.html" ] + input_files_base_dir = rebase_path(".", "//") +} + +html_to_js("web_components") { + js_files = [ "optimization_guide_internals.ts" ] +} + +copy("copy_proxy") { + sources = [ "optimization_guide_internals_browser_proxy.ts" ] + outputs = [ "$target_gen_dir/{{source_file_part}}" ] +} + +copy("copy_mojo") { + deps = [ "//components/optimization_guide/optimization_guide_internals/webui:mojo_bindings_webui_js" ] + mojom_folder = "$root_gen_dir/mojom-webui/components/optimization_guide/optimization_guide_internals/webui/" + sources = [ "$mojom_folder/optimization_guide_internals.mojom-webui.js" ] + outputs = [ "$target_gen_dir/{{source_file_part}}" ] +} + +ts_library("build_ts") { + root_dir = "$target_gen_dir" + out_dir = "$target_gen_dir/$tsc_folder" + tsconfig_base = "tsconfig_base.json" + in_files = [ + "optimization_guide_internals.ts", + "optimization_guide_internals_browser_proxy.ts", + "optimization_guide_internals.mojom-webui.js", + ] + deps = [ + "//ui/webui/resources:library", + "//ui/webui/resources/js/browser_command:build_ts", + ] + extra_deps = [ + ":copy_mojo", + ":copy_proxy", + ":web_components", + ] +} diff --git a/chromium/components/optimization_guide/optimization_guide_internals/resources/OWNERS b/chromium/components/optimization_guide/optimization_guide_internals/resources/OWNERS new file mode 100644 index 00000000000..1ee724d06f4 --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/resources/OWNERS @@ -0,0 +1 @@ +file://components/optimization_guide/OWNERS diff --git a/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals.html b/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals.html new file mode 100644 index 00000000000..8237d8a8f29 --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals.html @@ -0,0 +1,29 @@ +<!doctype html> +<html lang="en" dir="ltr"> + <head> + <style> + .segment { + border: 1px outset black; + margin: 2px 2px 2px 2px; + } + </style> + <meta charset="utf-8"> + <title>Optimization Guide Internals</title> + <meta name="viewport" content="width=device-width"> + <link rel="stylesheet" href="chrome://resources/css/text_defaults.css"> + </head> + <body> + <h1>Optimization Guide Internals - Debug Logs</h1> + <button id="log-messages-dump">Dump</button> + <table id="log-message-container"> + <thead> + <tr> + <th>Time</th> + <th>Source Location</th> + <th>Log Message</th> + </tr> + </thead> + </table> + <script type="module" src="optimization_guide_internals.js"></script> + </body> +</html>
\ No newline at end of file diff --git a/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals.ts b/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals.ts new file mode 100644 index 00000000000..df91b5bcf79 --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals.ts @@ -0,0 +1,80 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import {$} from 'chrome://resources/js/util.m.js'; +import {Time} from 'chrome://resources/mojo/mojo/public/mojom/base/time.mojom-webui.js'; + +import {OptimizationGuideInternalsBrowserProxy} from './optimization_guide_internals_browser_proxy.js'; + +// Contains all the log events received when the internals page is open. +const logMessages: + {eventTime: string, sourceLocation: string, message: string}[] = []; + +/** + * Converts a mojo time to a JS time. + * @param {!mojoBase.mojom.Time} mojoTime + * @return {!Date} + */ +function convertMojoTimeToJS(mojoTime: Time) { + // The JS Date() is based off of the number of milliseconds since the + // UNIX epoch (1970-01-01 00::00:00 UTC), while |internalValue| of the + // base::Time (represented in mojom.Time) represents the number of + // microseconds since the Windows FILETIME epoch (1601-01-01 00:00:00 UTC). + // This computes the final JS time by computing the epoch delta and the + // conversion from microseconds to milliseconds. + const windowsEpoch = Date.UTC(1601, 0, 1, 0, 0, 0, 0); + const unixEpoch = Date.UTC(1970, 0, 1, 0, 0, 0, 0); + // |epochDeltaInMs| equals to base::Time::kTimeTToMicrosecondsOffset. + const epochDeltaInMs = unixEpoch - windowsEpoch; + const timeInMs = Number(mojoTime.internalValue) / 1000; + + return new Date(timeInMs - epochDeltaInMs); +} + +/** + * The callback to button#log-messages-dump to save the logs to a file. + */ +function onLogMessagesDump() { + const data = JSON.stringify(logMessages); + const blob = new Blob([data], {'type': 'text/json'}); + const url = URL.createObjectURL(blob); + const filename = 'optimization_guide_internals_logs_dump.json'; + + const a = document.createElement('a'); + a.setAttribute('href', url); + a.setAttribute('download', filename); + + const event = document.createEvent('MouseEvent'); + event.initMouseEvent( + 'click', true, true, window, 0, 0, 0, 0, 0, false, false, false, false, 0, + null); + a.dispatchEvent(event); +} + +function getProxy(): OptimizationGuideInternalsBrowserProxy { + return OptimizationGuideInternalsBrowserProxy.getInstance(); +} + + +function initialize() { + const logMessageContainer = $('log-message-container') as HTMLTableElement; + + $('log-messages-dump').addEventListener('click', onLogMessagesDump); + + getProxy().getCallbackRouter().onLogMessageAdded.addListener( + (eventTime: Time, sourceFile: string, sourceLine: number, + message: string) => { + const eventTimeStr = convertMojoTimeToJS(eventTime).toISOString(); + const sourceLocation = `${sourceFile}(${sourceLine})`; + logMessages.push({eventTime: eventTimeStr, sourceLocation, message}); + if (logMessageContainer) { + const logmessage = logMessageContainer.insertRow(); + logmessage.insertCell().innerHTML = eventTimeStr; + logmessage.insertCell().innerHTML = sourceLocation; + logmessage.insertCell().innerHTML = message; + } + }); +} + +document.addEventListener('DOMContentLoaded', initialize);
\ No newline at end of file diff --git a/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals_browser_proxy.ts b/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals_browser_proxy.ts new file mode 100644 index 00000000000..b7636abbe61 --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/resources/optimization_guide_internals_browser_proxy.ts @@ -0,0 +1,26 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +import {PageCallbackRouter, PageHandlerFactory} from './optimization_guide_internals.mojom-webui.js'; + +export class OptimizationGuideInternalsBrowserProxy { + private callbackRouter: PageCallbackRouter; + + constructor() { + this.callbackRouter = new PageCallbackRouter(); + const factory = PageHandlerFactory.getRemote(); + factory.createPageHandler(this.callbackRouter.$.bindNewPipeAndPassRemote()); + } + + static getInstance(): OptimizationGuideInternalsBrowserProxy { + return instance || + (instance = new OptimizationGuideInternalsBrowserProxy()); + } + + getCallbackRouter(): PageCallbackRouter { + return this.callbackRouter; + } +} + +let instance: OptimizationGuideInternalsBrowserProxy|null = null;
\ No newline at end of file diff --git a/chromium/components/optimization_guide/optimization_guide_internals/resources/tsconfig_base.json b/chromium/components/optimization_guide/optimization_guide_internals/resources/tsconfig_base.json new file mode 100644 index 00000000000..cbf406ef813 --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/resources/tsconfig_base.json @@ -0,0 +1,6 @@ +{ + "extends": "../../../../tools/typescript/tsconfig_base.json", + "compilerOptions": { + "allowJs": true + } +}
\ No newline at end of file diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/BUILD.gn b/chromium/components/optimization_guide/optimization_guide_internals/webui/BUILD.gn new file mode 100644 index 00000000000..f5952e8ff2f --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/BUILD.gn @@ -0,0 +1,30 @@ +# Copyright 2022 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +import("//mojo/public/tools/bindings/mojom.gni") + +static_library("webui") { + sources = [ + "optimization_guide_internals_page_handler_impl.cc", + "optimization_guide_internals_page_handler_impl.h", + "optimization_guide_internals_ui.cc", + "optimization_guide_internals_ui.h", + "url_constants.cc", + "url_constants.h", + ] + deps = [ + "//base", + "//components/optimization_guide/core", + "//components/optimization_guide/optimization_guide_internals/resources:resources", + "//components/optimization_guide/optimization_guide_internals/webui:mojo_bindings", + "//third_party/abseil-cpp:absl", + "//ui/base", + "//ui/webui", + ] +} +mojom("mojo_bindings") { + sources = [ "optimization_guide_internals.mojom" ] + webui_module_path = "/" + public_deps = [ "//mojo/public/mojom/base" ] +} diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/DEPS b/chromium/components/optimization_guide/optimization_guide_internals/webui/DEPS new file mode 100644 index 00000000000..930de395723 --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/DEPS @@ -0,0 +1,7 @@ +include_rules = [ + "+mojo/public/cpp/bindings", + "+ui/base/webui/resource_path.h", + "+ui/webui/mojo_web_ui_controller.h", + "+components/grit/optimization_guide_internals_resources.h", + "+components/grit/optimization_guide_internals_resources_map.h", +] diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/OWNERS b/chromium/components/optimization_guide/optimization_guide_internals/webui/OWNERS new file mode 100644 index 00000000000..5245dc478f1 --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/OWNERS @@ -0,0 +1,4 @@ +file://components/optimization_guide/OWNERS + +per-file *.mojom=set noparent +per-file *.mojom=file://ipc/SECURITY_OWNERS
\ No newline at end of file diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals.mojom b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals.mojom new file mode 100644 index 00000000000..b4eec080fa0 --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals.mojom @@ -0,0 +1,24 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + + +module optimization_guide_internals.mojom; + +import "mojo/public/mojom/base/time.mojom"; + +// Used by the WebUI page to bootstrap bidirectional communication. +interface PageHandlerFactory { + // The WebUI calls this method when the page is first initialized. + CreatePageHandler(pending_remote<Page> page); +}; + +// Renderer-side handler for internal page to process the updates from +// the OptimizationGuide service. +interface Page { + // Notifies the page of a log event from the OptimizationGuide service. + OnLogMessageAdded(mojo_base.mojom.Time event_time, + string source_file, + int64 source_line, + string message); +}; diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.cc b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.cc new file mode 100644 index 00000000000..40ebfe372f0 --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.cc @@ -0,0 +1,31 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.h" + +#include "base/time/time.h" + +OptimizationGuideInternalsPageHandlerImpl:: + OptimizationGuideInternalsPageHandlerImpl( + mojo::PendingRemote<optimization_guide_internals::mojom::Page> page, + OptimizationGuideLogger* optimization_guide_logger) + : page_(std::move(page)), + optimization_guide_logger_(optimization_guide_logger) { + if (optimization_guide_logger_) + optimization_guide_logger_->AddObserver(this); +} + +OptimizationGuideInternalsPageHandlerImpl:: + ~OptimizationGuideInternalsPageHandlerImpl() { + if (optimization_guide_logger_) + optimization_guide_logger_->RemoveObserver(this); +} + +void OptimizationGuideInternalsPageHandlerImpl::OnLogMessageAdded( + base::Time event_time, + const std::string& source_file, + int source_line, + const std::string& message) { + page_->OnLogMessageAdded(event_time, source_file, source_line, message); +} diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.h b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.h new file mode 100644 index 00000000000..6816e3b68ec --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.h @@ -0,0 +1,44 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_OPTIMIZATION_GUIDE_INTERNALS_PAGE_HANDLER_IMPL_H_ +#define COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_OPTIMIZATION_GUIDE_INTERNALS_PAGE_HANDLER_IMPL_H_ + +#include <string> + +#include "components/optimization_guide/core/optimization_guide_logger.h" +#include "components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals.mojom.h" +#include "mojo/public/cpp/bindings/remote.h" + +// Handler for the internals page to receive and forward the log messages. +class OptimizationGuideInternalsPageHandlerImpl + : public OptimizationGuideLogger::Observer { + public: + OptimizationGuideInternalsPageHandlerImpl( + mojo::PendingRemote<optimization_guide_internals::mojom::Page> page, + OptimizationGuideLogger* optimization_guide_logger); + ~OptimizationGuideInternalsPageHandlerImpl() override; + + OptimizationGuideInternalsPageHandlerImpl( + const OptimizationGuideInternalsPageHandlerImpl&) = delete; + OptimizationGuideInternalsPageHandlerImpl& operator=( + const OptimizationGuideInternalsPageHandlerImpl&) = delete; + + private: + // optimization_guide::OptimizationGuideLogger::Observer overrides. + void OnLogMessageAdded(base::Time event_time, + const std::string& source_file, + int source_line, + const std::string& message) override; + + mojo::Remote<optimization_guide_internals::mojom::Page> page_; + + // Logger to receive the debug logs from the optimization guide service. Not + // owned. Guaranteed to outlive |this|, since the logger is owned by the + // optimization guide keyed service, while |this| is part of + // RenderFrameHostImpl::WebUIImpl. + raw_ptr<OptimizationGuideLogger> optimization_guide_logger_; +}; + +#endif // COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_OPTIMIZATION_GUIDE_INTERNALS_PAGE_HANDLER_IMPL_H_ diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.cc b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.cc new file mode 100644 index 00000000000..e69b01edbff --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.cc @@ -0,0 +1,38 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.h" + +#include "components/grit/optimization_guide_internals_resources.h" +#include "components/grit/optimization_guide_internals_resources_map.h" +#include "components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_page_handler_impl.h" + +OptimizationGuideInternalsUI::OptimizationGuideInternalsUI( + content::WebUI* web_ui, + OptimizationGuideLogger* optimization_guide_logger, + SetupWebUIDataSourceCallback set_up_data_source_callback) + : MojoWebUIController(web_ui, /*enable_chrome_send=*/true), + optimization_guide_logger_(optimization_guide_logger) { + std::move(set_up_data_source_callback) + .Run(base::make_span(kOptimizationGuideInternalsResources, + kOptimizationGuideInternalsResourcesSize), + IDR_OPTIMIZATION_GUIDE_INTERNALS_OPTIMIZATION_GUIDE_INTERNALS_HTML); +} + +OptimizationGuideInternalsUI::~OptimizationGuideInternalsUI() = default; + +void OptimizationGuideInternalsUI::BindInterface( + mojo::PendingReceiver< + optimization_guide_internals::mojom::PageHandlerFactory> receiver) { + optimization_guide_internals_page_factory_receiver_.Bind(std::move(receiver)); +} + +void OptimizationGuideInternalsUI::CreatePageHandler( + mojo::PendingRemote<optimization_guide_internals::mojom::Page> page) { + optimization_guide_internals_page_handler_ = + std::make_unique<OptimizationGuideInternalsPageHandlerImpl>( + std::move(page), optimization_guide_logger_); +} + +WEB_UI_CONTROLLER_TYPE_IMPL(OptimizationGuideInternalsUI) diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.h b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.h new file mode 100644 index 00000000000..4421de745c8 --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals_ui.h @@ -0,0 +1,59 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_OPTIMIZATION_GUIDE_INTERNALS_UI_H_ +#define COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_OPTIMIZATION_GUIDE_INTERNALS_UI_H_ + +#include "components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals.mojom.h" +#include "mojo/public/cpp/bindings/pending_receiver.h" +#include "ui/base/webui/resource_path.h" +#include "ui/webui/mojo_web_ui_controller.h" + +class OptimizationGuideLogger; +class OptimizationGuideInternalsPageHandlerImpl; + +// The WebUI controller for chrome://optimization-guide-internals. +class OptimizationGuideInternalsUI + : public ui::MojoWebUIController, + public optimization_guide_internals::mojom::PageHandlerFactory { + public: + using SetupWebUIDataSourceCallback = + base::OnceCallback<void(base::span<const webui::ResourcePath> resources, + int default_resource)>; + + explicit OptimizationGuideInternalsUI( + content::WebUI* web_ui, + OptimizationGuideLogger* optimization_guide_logger, + SetupWebUIDataSourceCallback set_up_data_source_callback); + ~OptimizationGuideInternalsUI() override; + + OptimizationGuideInternalsUI(const OptimizationGuideInternalsUI&) = delete; + OptimizationGuideInternalsUI& operator=(const OptimizationGuideInternalsUI&) = + delete; + + void BindInterface( + mojo::PendingReceiver< + optimization_guide_internals::mojom::PageHandlerFactory> receiver); + + private: + // optimization_guide_internals::mojom::PageHandlerFactory impls. + void CreatePageHandler( + mojo::PendingRemote<optimization_guide_internals::mojom::Page> page) + override; + + std::unique_ptr<OptimizationGuideInternalsPageHandlerImpl> + optimization_guide_internals_page_handler_; + mojo::Receiver<optimization_guide_internals::mojom::PageHandlerFactory> + optimization_guide_internals_page_factory_receiver_{this}; + + // Logger to receive the debug logs from the optimization guide service. Not + // owned. Guaranteed to outlive |this|, since the logger is owned by the + // optimization guide keyed service, while |this| is part of + // RenderFrameHostImpl::WebUIImpl. + raw_ptr<OptimizationGuideLogger> optimization_guide_logger_; + + WEB_UI_CONTROLLER_TYPE_DECL(); +}; + +#endif // COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_OPTIMIZATION_GUIDE_INTERNALS_UI_H_ diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/url_constants.cc b/chromium/components/optimization_guide/optimization_guide_internals/webui/url_constants.cc new file mode 100644 index 00000000000..3fc092ccddb --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/url_constants.cc @@ -0,0 +1,12 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/optimization_guide/optimization_guide_internals/webui/url_constants.h" + +namespace optimization_guide_internals { + +const char kChromeUIOptimizationGuideInternalsHost[] = + "optimization-guide-internals"; + +} diff --git a/chromium/components/optimization_guide/optimization_guide_internals/webui/url_constants.h b/chromium/components/optimization_guide/optimization_guide_internals/webui/url_constants.h new file mode 100644 index 00000000000..14a21360f89 --- /dev/null +++ b/chromium/components/optimization_guide/optimization_guide_internals/webui/url_constants.h @@ -0,0 +1,15 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_URL_CONSTANTS_H_ +#define COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_URL_CONSTANTS_H_ + +namespace optimization_guide_internals { + +// The host of the optimization guide internals page URL. +extern const char kChromeUIOptimizationGuideInternalsHost[]; + +} // namespace optimization_guide_internals + +#endif // COMPONENTS_OPTIMIZATION_GUIDE_OPTIMIZATION_GUIDE_INTERNALS_WEBUI_URL_CONSTANTS_H_ diff --git a/chromium/components/optimization_guide/proto/BUILD.gn b/chromium/components/optimization_guide/proto/BUILD.gn index 288e9fd2c25..2253201609c 100644 --- a/chromium/components/optimization_guide/proto/BUILD.gn +++ b/chromium/components/optimization_guide/proto/BUILD.gn @@ -17,6 +17,7 @@ proto_library("optimization_guide_proto") { "loading_predictor_metadata.proto", "models.proto", "page_entities_metadata.proto", + "page_entities_model_metadata.proto", "page_topics_model_metadata.proto", "performance_hints_metadata.proto", "public_image_metadata.proto", diff --git a/chromium/components/optimization_guide/proto/hint_cache.proto b/chromium/components/optimization_guide/proto/hint_cache.proto index a27430d3e53..e89dec22aa3 100644 --- a/chromium/components/optimization_guide/proto/hint_cache.proto +++ b/chromium/components/optimization_guide/proto/hint_cache.proto @@ -59,4 +59,6 @@ message StoreEntry { optional PredictionModel prediction_model = 6; // The actual HostModelFeature data. optional HostModelFeatures host_model_features = 7; + // Whether to delete a model once expiry_time_secs is past. + optional bool keep_beyond_valid_duration = 8; } diff --git a/chromium/components/optimization_guide/proto/models.proto b/chromium/components/optimization_guide/proto/models.proto index c69cb127eec..a916421b570 100644 --- a/chromium/components/optimization_guide/proto/models.proto +++ b/chromium/components/optimization_guide/proto/models.proto @@ -200,7 +200,7 @@ message AdditionalModelFile { // Metadata for a prediction model for a specific optimization target. // -// Next ID: 8 +// Next ID: 10 message ModelInfo { reserved 3; @@ -208,8 +208,9 @@ message ModelInfo { optional OptimizationTarget optimization_target = 1; // The version of the model, which is specific to the optimization target. optional int64 version = 2; - // The set of model types the requesting client can use to make predictions. - repeated ModelType supported_model_types = 4; + // The set of model engine versions the requesting client can use to do model + // inference. + repeated ModelEngineVersion supported_model_engine_versions = 4; // The set of host model features that are referenced by the model. // // Note that this should only be populated if part of the response. @@ -222,6 +223,11 @@ message ModelInfo { // This does not need to be sent to the server in the request for an update to // this model. The server will ignore this if sent. repeated AdditionalModelFile additional_files = 7; + // How long the model will remain valid in client storage. If + // |keep_beyond_valid_duration| is true, will be ignored. + optional Duration valid_duration = 8; + // Whether to delete the model once valid_duration has passed. + optional bool keep_beyond_valid_duration = 9; // Mechanism used for model owners to attach metadata to the request or // response. // @@ -267,31 +273,38 @@ enum OptimizationTarget { // Target for determining topics present on a page. // TODO(crbug/1266504): Remove PAGE_TOPICS in favor of this target. OPTIMIZATION_TARGET_PAGE_TOPICS_V2 = 15; + // Target for segmentation: Determine users with low engagement with chrome. + OPTIMIZATION_TARGET_SEGMENTATION_CHROME_LOW_USER_ENGAGEMENT = 16; } -// The types of models that can be evaluated. +// The model engine versions that can be used to do model inference. // // Please only update these enums when a new major version of TFLite rolls. // // For example: v1.2.3 // ^ // Change when this number increments. -enum ModelType { - MODEL_TYPE_UNKNOWN = 0; +enum ModelEngineVersion { + MODEL_ENGINE_VERSION_UNKNOWN = 0; // A decision tree. - MODEL_TYPE_DECISION_TREE = 1; + MODEL_ENGINE_VERSION_DECISION_TREE = 1; // A model using only operations that are supported by TensorflowLite 2.3.0. - MODEL_TYPE_TFLITE_2_3_0 = 2; + MODEL_ENGINE_VERSION_TFLITE_2_3_0 = 2; // A model using only operations that are supported by TensorflowLite 2.3.0 // with updated FULLY_CONNECTED and BATCH_MUL versions for quantized models. - MODEL_TYPE_TFLITE_2_3_0_1 = 3; + MODEL_ENGINE_VERSION_TFLITE_2_3_0_1 = 3; // TensorflowLite version 2.4.2, and a bit more up to internal rev number // 381280669. - MODEL_TYPE_TFLITE_2_4 = 4; + MODEL_ENGINE_VERSION_TFLITE_2_4 = 4; // TensorflowLite version 2.7.*. This is where regular ~HEAD rolls started. - MODEL_TYPE_TFLITE_2_7 = 5; + MODEL_ENGINE_VERSION_TFLITE_2_7 = 5; // A model using only operations that are supported by TensorflowLite 2.8.0. - MODEL_TYPE_TFLITE_2_8 = 6; + MODEL_ENGINE_VERSION_TFLITE_2_8 = 6; + // A model using only operations that are supported by TensorflowLite 2.9.0. + MODEL_ENGINE_VERSION_TFLITE_2_9 = 7; + // A model using only operations that are supported by TensorflowLite 2.9.0. + // This adds GELU to the supported ops in Optimiziation Guide. + MODEL_ENGINE_VERSION_TFLITE_2_9_0_1 = 8; } // A set of model features and the host that it applies to. diff --git a/chromium/components/optimization_guide/proto/page_entities_metadata.proto b/chromium/components/optimization_guide/proto/page_entities_metadata.proto index 96a8479f179..1b5fa334a18 100644 --- a/chromium/components/optimization_guide/proto/page_entities_metadata.proto +++ b/chromium/components/optimization_guide/proto/page_entities_metadata.proto @@ -24,7 +24,18 @@ message Entity { // entities on the page. // // It is only populated for the PAGE_ENTITIES optimization type. +// +// Note that the meaning of metadata here is in relation to a page load. message PageEntitiesMetadata { // A set of entities that are expected to be present on the page. repeated Entity entities = 1; } + +// The metadata associated with an |Entity|. +// +// Each |Entity| has some attached metadata about it which may be stored on +// device for later lookup. Notably, this includes it's human-readable name as +// opposed to the opaque entity_id. +message EntityMetadataStorage { + optional string entity_name = 1; +}
\ No newline at end of file diff --git a/chromium/components/optimization_guide/proto/page_entities_model_metadata.proto b/chromium/components/optimization_guide/proto/page_entities_model_metadata.proto new file mode 100644 index 00000000000..cf239e37352 --- /dev/null +++ b/chromium/components/optimization_guide/proto/page_entities_model_metadata.proto @@ -0,0 +1,25 @@ +// Copyright 2022 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +syntax = "proto2"; + +option optimize_for = LITE_RUNTIME; +option java_package = "org.chromium.components.optimization_guide.proto"; +option java_outer_classname = "PageEntitiesModelMetadataProto"; + +package optimization_guide.proto; + +message PageEntitiesModelMetadata { + // The maximum model format feature flag that is supported. + // + // If sent from the server, this is the maximum model format feature flag the + // returned model supports. If sent from the client, this is the maximum + // model format feature flag the client knows how to evaluate. + optional int32 max_model_format_feature_flag = 1; + + // The slices to load into the entity annotator. + // + // Will only be populated by the server. + repeated string slice = 2; +}
\ No newline at end of file diff --git a/chromium/components/optimization_guide/proto/page_topics_model_metadata.proto b/chromium/components/optimization_guide/proto/page_topics_model_metadata.proto index 59e4a0441be..42516c9358a 100644 --- a/chromium/components/optimization_guide/proto/page_topics_model_metadata.proto +++ b/chromium/components/optimization_guide/proto/page_topics_model_metadata.proto @@ -50,7 +50,24 @@ message PageTopicsOutputPostprocessingParams { optional PageTopicsCategoryPostprocessingParams category_params = 2; } +message Topic { + // The user-visible string of the taxonomy topic. + optional string topic_name = 1; + // The id of the topic. + optional int64 topic_id = 2; +} + +message TopicTaxonomy { + // The version of this specific taxonomy, which is separate from the model + // version. + optional int64 version = 1; + // The topics supported by this taxonomy. + repeated Topic topics = 2; +} + message PageTopicsModelMetadata { + reserved 4; + // The version of the model sent by the server, and thus, will only be // populated by the server. optional int64 version = 1; @@ -64,4 +81,6 @@ message PageTopicsModelMetadata { // populated by the server. optional PageTopicsOutputPostprocessingParams output_postprocessing_params = 3; + // The taxonomy used by this model. + optional TopicTaxonomy topic_taxonomy = 5; } |