summaryrefslogtreecommitdiff
path: root/chromium/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc
diff options
context:
space:
mode:
Diffstat (limited to 'chromium/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc')
-rw-r--r--chromium/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc75
1 files changed, 71 insertions, 4 deletions
diff --git a/chromium/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc b/chromium/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc
index 6e0f7f98ca6..33f089c9b51 100644
--- a/chromium/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc
+++ b/chromium/components/segmentation_platform/internal/segmentation_platform_service_impl_unittest.cc
@@ -5,6 +5,7 @@
#include "components/segmentation_platform/internal/segmentation_platform_service_impl.h"
#include <string>
+#include <utility>
#include "base/bind.h"
#include "base/files/file_path.h"
@@ -41,13 +42,17 @@
#include "components/segmentation_platform/internal/signals/histogram_signal_handler.h"
#include "components/segmentation_platform/internal/signals/signal_filter_processor.h"
#include "components/segmentation_platform/internal/signals/user_action_signal_handler.h"
+#include "components/segmentation_platform/internal/ukm_data_manager.h"
#include "components/segmentation_platform/public/config.h"
+#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
#include "components/segmentation_platform/internal/execution/model_execution_manager_impl.h"
#endif // BUILDFLAG(BUILD_WITH_TFLITE_LIB
+using ::testing::_;
+
namespace segmentation_platform {
namespace {
@@ -94,6 +99,19 @@ std::vector<std::unique_ptr<Config>> CreateTestConfigs() {
} // namespace
+// A mock of the ServiceProxy::Observer.
+class MockServiceProxyObserver : public ServiceProxy::Observer {
+ public:
+ MockServiceProxyObserver() = default;
+ ~MockServiceProxyObserver() override = default;
+
+ MOCK_METHOD(void, OnServiceStatusChanged, (bool, int), (override));
+ MOCK_METHOD(void,
+ OnClientInfoAvailable,
+ (const std::vector<ServiceProxy::ClientInfo>& client_info),
+ (override));
+};
+
class SegmentationPlatformServiceImplTest : public testing::Test {
public:
SegmentationPlatformServiceImplTest() = default;
@@ -124,8 +142,11 @@ class SegmentationPlatformServiceImplTest : public testing::Test {
segmentation_platform_service_impl_ =
std::make_unique<SegmentationPlatformServiceImpl>(
std::move(segment_db), std::move(signal_db),
- std::move(segment_storage_config_db), &model_provider_,
- &pref_service_, task_runner_, &test_clock_, std::move(configs));
+ std::move(segment_storage_config_db), &ukm_data_manager_,
+ &model_provider_, &pref_service_, task_runner_, &test_clock_,
+ std::move(configs));
+ segmentation_platform_service_impl_->GetServiceProxy()->AddObserver(
+ &observer_);
}
void TearDown() override {
@@ -137,7 +158,7 @@ class SegmentationPlatformServiceImplTest : public testing::Test {
virtual void SetUpPrefs() {
DictionaryPrefUpdate update(&pref_service_, kSegmentationResultPref);
- base::DictionaryValue* dictionary = update.Get();
+ base::Value* dictionary = update.Get();
base::Value segmentation_result(base::Value::Type::DICTIONARY);
segmentation_result.SetIntKey(
@@ -175,6 +196,20 @@ class SegmentationPlatformServiceImplTest : public testing::Test {
loop.Run();
}
+ void AssertCachedSegment(
+ const std::string& segmentation_key,
+ bool is_ready,
+ OptimizationTarget expected =
+ OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN) {
+ SegmentSelectionResult result;
+ result.is_ready = is_ready;
+ if (is_ready)
+ result.segment = expected;
+ ASSERT_EQ(result,
+ segmentation_platform_service_impl_->GetCachedSegmentResult(
+ segmentation_key));
+ }
+
protected:
base::test::TaskEnvironment task_environment_{
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
@@ -190,12 +225,15 @@ class SegmentationPlatformServiceImplTest : public testing::Test {
optimization_guide::TestOptimizationGuideModelProvider model_provider_;
TestingPrefServiceSimple pref_service_;
base::SimpleTestClock test_clock_;
+ UkmDataManager ukm_data_manager_;
std::unique_ptr<SegmentationPlatformServiceImpl>
segmentation_platform_service_impl_;
+ MockServiceProxyObserver observer_;
};
TEST_F(SegmentationPlatformServiceImplTest, InitializationFlow) {
// Let the DB loading complete successfully.
+ EXPECT_CALL(observer_, OnServiceStatusChanged(true, 7));
segment_db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
signal_db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
segment_storage_config_db_->InitStatusCallback(
@@ -253,6 +291,16 @@ TEST_F(SegmentationPlatformServiceImplTest, InitializationFlow) {
OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
AssertSelectedSegment(kTestSegmentationKey2, false);
AssertSelectedSegment(kTestSegmentationKey3, false);
+ AssertCachedSegment(
+ kTestSegmentationKey1, true,
+ OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+ AssertCachedSegment(kTestSegmentationKey2, false);
+ AssertCachedSegment(kTestSegmentationKey3, false);
+
+ // ServiceProxy will load new segment info from the DB.
+ EXPECT_CALL(observer_, OnClientInfoAvailable(_));
+ task_environment_.RunUntilIdle();
+ segment_db_->LoadCallback(true);
mem_impl->OnSegmentationModelUpdated(
OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE, metadata);
@@ -266,6 +314,11 @@ TEST_F(SegmentationPlatformServiceImplTest, InitializationFlow) {
EXPECT_EQ(
2, histogram_tester.GetBucketCount(
"SegmentationPlatform.Signals.ListeningCount.HistogramValue", 1));
+
+ // ServiceProxy will load new segment info from the DB.
+ EXPECT_CALL(observer_, OnClientInfoAvailable(_));
+ task_environment_.RunUntilIdle();
+ segment_db_->LoadCallback(true);
#endif // BUILDFLAG(BUILD_WITH_TFLITE_LIB)
// Database maintenance tasks should try to cleanup the signals after a short
@@ -278,6 +331,11 @@ TEST_F(SegmentationPlatformServiceImplTest, InitializationFlow) {
OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
AssertSelectedSegment(kTestSegmentationKey2, false);
AssertSelectedSegment(kTestSegmentationKey3, false);
+ AssertCachedSegment(
+ kTestSegmentationKey1, true,
+ OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+ AssertCachedSegment(kTestSegmentationKey2, false);
+ AssertCachedSegment(kTestSegmentationKey3, false);
}
TEST_F(SegmentationPlatformServiceImplTest,
@@ -302,6 +360,7 @@ class SegmentationPlatformServiceImplEmptyConfigTest
TEST_F(SegmentationPlatformServiceImplEmptyConfigTest, InitializationFlow) {
// Let the DB loading complete successfully.
+ EXPECT_CALL(observer_, OnServiceStatusChanged(true, 7));
segment_db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
signal_db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
segment_storage_config_db_->InitStatusCallback(
@@ -317,7 +376,7 @@ class SegmentationPlatformServiceImplMultiClientTest
: public SegmentationPlatformServiceImplTest {
void SetUpPrefs() override {
DictionaryPrefUpdate update(&pref_service_, kSegmentationResultPref);
- base::DictionaryValue* dictionary = update.Get();
+ base::Value* dictionary = update.Get();
base::Value segmentation_result(base::Value::Type::DICTIONARY);
segmentation_result.SetIntKey(
@@ -335,6 +394,7 @@ class SegmentationPlatformServiceImplMultiClientTest
TEST_F(SegmentationPlatformServiceImplMultiClientTest, InitializationFlow) {
// Let the DB loading complete successfully.
+ EXPECT_CALL(observer_, OnServiceStatusChanged(true, 7));
segment_db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
signal_db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
segment_storage_config_db_->InitStatusCallback(
@@ -352,6 +412,13 @@ TEST_F(SegmentationPlatformServiceImplMultiClientTest, InitializationFlow) {
kTestSegmentationKey2, true,
OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE);
AssertSelectedSegment(kTestSegmentationKey3, false);
+ AssertCachedSegment(
+ kTestSegmentationKey1, true,
+ OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
+ AssertCachedSegment(
+ kTestSegmentationKey2, true,
+ OptimizationTarget::OPTIMIZATION_TARGET_SEGMENTATION_VOICE);
+ AssertCachedSegment(kTestSegmentationKey3, false);
}
} // namespace segmentation_platform