diff options
Diffstat (limited to 'chromium/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc')
-rw-r--r-- | chromium/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc | 75 |
1 files changed, 71 insertions, 4 deletions
diff --git a/chromium/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc b/chromium/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc index 6e0f7f98ca6..33f089c9b51 100644 --- a/chromium/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc +++ b/chromium/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc @@ -5,6 +5,7 @@ #include "components/segmentation_platform/internal/segmentation_platform_service_impl.h" #include <string> +#include <utility> #include "base/bind.h" #include "base/files/file_path.h" @@ -41,13 +42,17 @@ #include "components/segmentation_platform/internal/signals/histogram_signal_handler.h" #include "components/segmentation_platform/internal/signals/signal_filter_processor.h" #include "components/segmentation_platform/internal/signals/user_action_signal_handler.h" +#include "components/segmentation_platform/internal/ukm_data_manager.h" #include "components/segmentation_platform/public/config.h" +#include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" #if BUILDFLAG(BUILD_WITH_TFLITE_LIB) #include "components/segmentation_platform/internal/execution/model_execution_manager_impl.h" #endif // BUILDFLAG(BUILD_WITH_TFLITE_LIB +using ::testing::_; + namespace segmentation_platform { namespace { @@ -94,6 +99,19 @@ std::vector<std::unique_ptr<Config>> CreateTestConfigs() { } // namespace +// A mock of the ServiceProxy::Observer. +class MockServiceProxyObserver : public ServiceProxy::Observer { + public: + MockServiceProxyObserver() = default; + ~MockServiceProxyObserver() override = default; + + MOCK_METHOD(void, OnServiceStatusChanged, (bool, int), (override)); + MOCK_METHOD(void, + OnClientInfoAvailable, + (const std::vector<ServiceProxy::ClientInfo>& client_info), + (override)); +}; + class SegmentationPlatformServiceImplTest : public testing::Test { public: SegmentationPlatformServiceImplTest() = default; @@ -124,8 +142,11 @@ class SegmentationPlatformServiceImplTest : public testing::Test { segmentation_platform_service_impl_ = std::make_unique<SegmentationPlatformServiceImpl>( std::move(segment_db), std::move(signal_db), - std::move(segment_storage_config_db), &model_provider_, - &pref_service_, task_runner_, &test_clock_, std::move(configs)); + std::move(segment_storage_config_db), &ukm_data_manager_, + &model_provider_, &pref_service_, task_runner_, &test_clock_, + std::move(configs)); + segmentation_platform_service_impl_->GetServiceProxy()->AddObserver( + &observer_); } void TearDown() override { @@ -137,7 +158,7 @@ class SegmentationPlatformServiceImplTest : public testing::Test { virtual void SetUpPrefs() { DictionaryPrefUpdate update(&pref_service_, kSegmentationResultPref); - base::DictionaryValue* dictionary = update.Get(); + base::Value* dictionary = update.Get(); base::Value segmentation_result(base::Value::Type::DICTIONARY); segmentation_result.SetIntKey( @@ -175,6 +196,20 @@ class SegmentationPlatformServiceImplTest : public testing::Test { loop.Run(); } + void AssertCachedSegment( + const std::string& segmentation_key, + bool is_ready, + OptimizationTarget expected = + OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN) { + SegmentSelectionResult result; + result.is_ready = is_ready; + if (is_ready) + result.segment = expected; + ASSERT_EQ(result, + segmentation_platform_service_impl_->GetCachedSegmentResult( + segmentation_key)); + } + protected: base::test::TaskEnvironment task_environment_{ base::test::TaskEnvironment::TimeSource::MOCK_TIME}; @@ -190,12 +225,15 @@ class SegmentationPlatformServiceImplTest : public testing::Test { optimization_guide::TestOptimizationGuideModelProvider model_provider_; TestingPrefServiceSimple pref_service_; base::SimpleTestClock test_clock_; + UkmDataManager ukm_data_manager_; std::unique_ptr<SegmentationPlatformServiceImpl> segmentation_platform_service_impl_; + MockServiceProxyObserver observer_; }; TEST_F(SegmentationPlatformServiceImplTest, InitializationFlow) { // Let the DB loading complete successfully. + EXPECT_CALL(observer_, OnServiceStatusChanged(true, 7)); segment_db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK); signal_db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK); segment_storage_config_db_->InitStatusCallback( @@ -253,6 +291,16 @@ TEST_F(SegmentationPlatformServiceImplTest, InitializationFlow) { OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE); AssertSelectedSegment(kTestSegmentationKey2, false); AssertSelectedSegment(kTestSegmentationKey3, false); + AssertCachedSegment( + kTestSegmentationKey1, true, + OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE); + AssertCachedSegment(kTestSegmentationKey2, false); + AssertCachedSegment(kTestSegmentationKey3, false); + + // ServiceProxy will load new segment info from the DB. + EXPECT_CALL(observer_, OnClientInfoAvailable(_)); + task_environment_.RunUntilIdle(); + segment_db_->LoadCallback(true); mem_impl->OnSegmentationModelUpdated( OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE, metadata); @@ -266,6 +314,11 @@ TEST_F(SegmentationPlatformServiceImplTest, InitializationFlow) { EXPECT_EQ( 2, histogram_tester.GetBucketCount( "SegmentationPlatform.Signals.ListeningCount.HistogramValue", 1)); + + // ServiceProxy will load new segment info from the DB. + EXPECT_CALL(observer_, OnClientInfoAvailable(_)); + task_environment_.RunUntilIdle(); + segment_db_->LoadCallback(true); #endif // BUILDFLAG(BUILD_WITH_TFLITE_LIB) // Database maintenance tasks should try to cleanup the signals after a short @@ -278,6 +331,11 @@ TEST_F(SegmentationPlatformServiceImplTest, InitializationFlow) { OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE); AssertSelectedSegment(kTestSegmentationKey2, false); AssertSelectedSegment(kTestSegmentationKey3, false); + AssertCachedSegment( + kTestSegmentationKey1, true, + OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE); + AssertCachedSegment(kTestSegmentationKey2, false); + AssertCachedSegment(kTestSegmentationKey3, false); } TEST_F(SegmentationPlatformServiceImplTest, @@ -302,6 +360,7 @@ class SegmentationPlatformServiceImplEmptyConfigTest TEST_F(SegmentationPlatformServiceImplEmptyConfigTest, InitializationFlow) { // Let the DB loading complete successfully. + EXPECT_CALL(observer_, OnServiceStatusChanged(true, 7)); segment_db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK); signal_db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK); segment_storage_config_db_->InitStatusCallback( @@ -317,7 +376,7 @@ class SegmentationPlatformServiceImplMultiClientTest : public SegmentationPlatformServiceImplTest { void SetUpPrefs() override { DictionaryPrefUpdate update(&pref_service_, kSegmentationResultPref); - base::DictionaryValue* dictionary = update.Get(); + base::Value* dictionary = update.Get(); base::Value segmentation_result(base::Value::Type::DICTIONARY); segmentation_result.SetIntKey( @@ -335,6 +394,7 @@ class SegmentationPlatformServiceImplMultiClientTest TEST_F(SegmentationPlatformServiceImplMultiClientTest, InitializationFlow) { // Let the DB loading complete successfully. + EXPECT_CALL(observer_, OnServiceStatusChanged(true, 7)); segment_db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK); signal_db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK); segment_storage_config_db_->InitStatusCallback( @@ -352,6 +412,13 @@ TEST_F(SegmentationPlatformServiceImplMultiClientTest, InitializationFlow) { kTestSegmentationKey2, true, OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE); AssertSelectedSegment(kTestSegmentationKey3, false); + AssertCachedSegment( + kTestSegmentationKey1, true, + OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE); + AssertCachedSegment( + kTestSegmentationKey2, true, + OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE); + AssertCachedSegment(kTestSegmentationKey3, false); } } // namespace segmentation_platform |