Diffstat (limited to 'chromium/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc')
-rw-r--r--  chromium/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc | 181
1 file changed, 32 insertions, 149 deletions
diff --git a/chromium/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc b/chromium/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc
index ab3fbf38619..04f06f24498 100644
--- a/chromium/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc
+++ b/chromium/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc
@@ -21,7 +21,7 @@
#include "components/optimization_guide/proto/models.pb.h"
#include "components/segmentation_platform/internal/database/metadata_utils.h"
#include "components/segmentation_platform/internal/database/signal_database.h"
-#include "components/segmentation_platform/internal/execution/feature_aggregator.h"
+#include "components/segmentation_platform/internal/execution/feature_list_query_processor.h"
#include "components/segmentation_platform/internal/execution/model_execution_manager.h"
#include "components/segmentation_platform/internal/execution/model_execution_status.h"
#include "components/segmentation_platform/internal/execution/segmentation_model_handler.h"
@@ -29,6 +29,7 @@
#include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "components/segmentation_platform/internal/proto/types.pb.h"
+#include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
#include "components/segmentation_platform/internal/stats.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
#include "third_party/perfetto/include/perfetto/tracing/track.h"
@@ -75,10 +76,7 @@ struct ModelExecutionManagerImpl::ExecutionState {
OptimizationTarget segment_id;
raw_ptr<SegmentationModelHandler> model_handler = nullptr;
ModelExecutionCallback callback;
- base::TimeDelta bucket_duration;
- std::deque<proto::Feature> features;
std::vector<float> input_tensor;
- base::Time end_time;
base::Time total_execution_start_time;
base::Time model_execution_start_time;
};
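
After this hunk, ExecutionState keeps only what the manager still needs once feature processing is delegated elsewhere: the per-feature bookkeeping (bucket_duration, the pending feature deque, the shared end_time) leaves with it. A sketch of the resulting struct, assembled from the surviving context lines above (the comments are added here for illustration and are not part of the diff):

    struct ModelExecutionManagerImpl::ExecutionState {
      OptimizationTarget segment_id;
      raw_ptr<SegmentationModelHandler> model_handler = nullptr;
      ModelExecutionCallback callback;
      std::vector<float> input_tensor;        // Filled from the feature list query processor.
      base::Time total_execution_start_time;  // For end-to-end duration metrics.
      base::Time model_execution_start_time;  // For model-only duration metrics.
    };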
@@ -97,34 +95,19 @@ ModelExecutionManagerImpl::ModelExecutionTraceEvent::
perfetto::Track::FromPointer(&state));
}
-struct ModelExecutionManagerImpl::FeatureState {
- FeatureState() = default;
- ~FeatureState() = default;
-
- // Disallow copy/assign.
- FeatureState(const FeatureState&) = delete;
- FeatureState& operator=(const FeatureState&) = delete;
-
- proto::SignalType signal_type;
- proto::Aggregation aggregation;
- absl::optional<std::vector<int32_t>> accepted_enum_ids;
- uint64_t bucket_count;
- uint64_t tensor_length;
-};
-
ModelExecutionManagerImpl::ModelExecutionManagerImpl(
const base::flat_set<OptimizationTarget>& segment_ids,
ModelHandlerCreator model_handler_creator,
base::Clock* clock,
SegmentInfoDatabase* segment_database,
SignalDatabase* signal_database,
- std::unique_ptr<FeatureAggregator> feature_aggregator,
+ FeatureListQueryProcessor* feature_list_query_processor,
const SegmentationModelUpdatedCallback& model_updated_callback)
: clock_(clock),
segment_database_(segment_database),
signal_database_(signal_database),
- feature_aggregator_(std::move(feature_aggregator)),
model_updated_callback_(model_updated_callback) {
+ feature_list_query_processor_ = feature_list_query_processor;
for (OptimizationTarget segment_id : segment_ids) {
model_handlers_.emplace(std::make_pair(
segment_id,
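
The manager now borrows the FeatureListQueryProcessor through a raw pointer instead of owning a FeatureAggregator via unique_ptr, so construction becomes plain dependency injection. A minimal wiring sketch, assuming a hypothetical owner that outlives the manager (the processor's own constructor arguments are elided, since they are not part of this diff):

    // Hypothetical owner code; only the ModelExecutionManagerImpl argument
    // list below is taken from the diff.
    auto processor = std::make_unique<FeatureListQueryProcessor>(/* ... */);
    auto manager = std::make_unique<ModelExecutionManagerImpl>(
        segment_ids, model_handler_creator, clock, segment_database,
        signal_database,
        processor.get(),  // Raw pointer: no ownership transfer.
        model_updated_callback);
    // |processor| must stay alive at least as long as |manager|.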
@@ -177,135 +160,28 @@ void ModelExecutionManagerImpl::OnSegmentInfoFetchedForExecution(
return;
}
- // The total bucket duration is defined by product of the bucket_duration
- // value and the length of related time_unit field, e.g. 28 * length(DAY).
- const auto& model_metadata = segment_info->model_metadata();
- uint64_t bucket_duration = model_metadata.bucket_duration();
- base::TimeDelta time_unit_len = metadata_utils::GetTimeUnit(model_metadata);
- state->bucket_duration = bucket_duration * time_unit_len;
-
- // Now that we have just fetched the metadata, set the end_time to be shared
- // across all features, so we get a consistent picture.
- state->end_time = clock_->Now();
-
- // Grab the metadata for all the features, which will be processed one at a
- // time, before executing the model.
- for (int i = 0; i < model_metadata.features_size(); ++i)
- state->features.emplace_back(model_metadata.features(i));
-
- // Process all the features in-order, starting with the first feature.
- ProcessFeatures(std::move(state));
-}
-
-void ModelExecutionManagerImpl::ProcessFeatures(
- std::unique_ptr<ExecutionState> state) {
- ModelExecutionTraceEvent trace_event(
- "ModelExecutionManagerImpl::ProcessFeatures", *state);
- // When there are no more features to process, we are done, so we execute the
- // model.
- if (state->features.empty()) {
- ExecuteModel(std::move(state));
- return;
- }
-
- proto::Feature feature;
- do {
- // Copy and pop the next feature.
- feature = state->features.front();
- state->features.pop_front();
-
- // Validate the proto::Feature metadata.
- if (metadata_utils::ValidateMetadataFeature(feature) !=
- metadata_utils::ValidationResult::kValidationSuccess) {
- RunModelExecutionCallback(std::move(state), 0,
- ModelExecutionStatus::kInvalidMetadata);
- return;
- }
- } while (feature.bucket_count() == 0); // Skip collection-only features.
-
- // Capture all relevant metadata for the current proto::Feature into the
- // FeatureState.
- auto feature_state = std::make_unique<FeatureState>();
- feature_state->signal_type = feature.type();
- feature_state->aggregation = feature.aggregation();
- feature_state->bucket_count = feature.bucket_count();
- feature_state->tensor_length = feature.tensor_length();
-
- auto name_hash = feature.name_hash();
-
- // Enum histograms can optionally only accept some of the enum values.
- // While the proto::Feature is available, capture a vector of the accepted
- // enum values. An empty vector is ignored (all values are considered
- // accepted).
- if (feature_state->signal_type == proto::SignalType::HISTOGRAM_ENUM) {
- std::vector<int32_t> accepted_enum_ids{};
- for (int i = 0; i < feature.enum_ids_size(); ++i)
- accepted_enum_ids.emplace_back(feature.enum_ids(i));
-
- feature_state->accepted_enum_ids = absl::make_optional(accepted_enum_ids);
- }
-
- // Only fetch data that is relevant for the current proto::Feature, since
- // the FeatureAggregator assumes that only relevant data is given to it.
- base::TimeDelta duration =
- state->bucket_duration * feature_state->bucket_count;
- base::Time start_time = state->end_time - duration;
-
- // Fetch the relevant samples for the current proto::Feature. Once the result
- // has come back, it will be processed and inserted into the
- // ExecutorState::input_tensor and will then invoke ProcessFeatures(...)
- // again to ensure we continue until all features have been processed.
- // Note: All parameters from the ExecutorState need to be captured locally
- // before invoking GetSamples, because the state is moved with the callback,
- // and the order of the move and accessing the members while invoking
- // GetSamples is not guaranteed.
- auto signal_type = feature_state->signal_type;
- auto end_time = state->end_time;
- signal_database_->GetSamples(
- signal_type, name_hash, start_time, end_time,
- base::BindOnce(&ModelExecutionManagerImpl::OnGetSamplesForFeature,
- weak_ptr_factory_.GetWeakPtr(), std::move(state),
- std::move(feature_state)));
+ OptimizationTarget segment_id = state->segment_id;
+ feature_list_query_processor_->ProcessFeatureList(
+ segment_info->model_metadata(), segment_id, clock_->Now(),
+ base::BindOnce(
+ &ModelExecutionManagerImpl::OnProcessingFeatureListComplete,
+ weak_ptr_factory_.GetWeakPtr(), std::move(state)));
}
-void ModelExecutionManagerImpl::OnGetSamplesForFeature(
+void ModelExecutionManagerImpl::OnProcessingFeatureListComplete(
std::unique_ptr<ExecutionState> state,
- std::unique_ptr<FeatureState> feature_state,
- std::vector<SignalDatabase::Sample> samples) {
- ModelExecutionTraceEvent trace_event(
- "ModelExecutionManagerImpl::OnGetSamplesForFeature", *state);
- base::Time process_start_time = clock_->Now();
- // HISTOGRAM_ENUM features might require us to filter out the result to only
- // keep enum values that match the accepted list. If the accepted list is
- // empty, all histogram enum values are kept.
- // The SignalDatabase does not currently support this type of data filter,
- // so instead we are doing this here.
- if (feature_state->signal_type == proto::SignalType::HISTOGRAM_ENUM) {
- DCHECK(feature_state->accepted_enum_ids.has_value());
- feature_aggregator_->FilterEnumSamples(*feature_state->accepted_enum_ids,
- samples);
+ bool error,
+ const std::vector<float>& input_tensor) {
+ if (error) {
+ // Validation error occurred on model's metadata.
+ RunModelExecutionCallback(std::move(state), 0,
+ ModelExecutionStatus::kInvalidMetadata);
+ return;
}
+ state->input_tensor.insert(state->input_tensor.end(), input_tensor.begin(),
+ input_tensor.end());
- // We now have all the data required to process a single feature, so we can
- // process it synchronously, and insert it into the
- // ExecutorState::input_tensor so we can later pass it to the ML model
- // executor.
- std::vector<float> feature_data = feature_aggregator_->Process(
- feature_state->signal_type, feature_state->aggregation,
- feature_state->bucket_count, state->end_time, state->bucket_duration,
- samples);
- DCHECK_EQ(feature_state->tensor_length, feature_data.size());
- state->input_tensor.insert(state->input_tensor.end(), feature_data.begin(),
- feature_data.end());
-
- stats::RecordModelExecutionDurationFeatureProcessing(
- state->segment_id, clock_->Now() - process_start_time);
-
- // Continue with the rest of the features.
- base::SequencedTaskRunnerHandle::Get()->PostTask(
- FROM_HERE,
- base::BindOnce(&ModelExecutionManagerImpl::ProcessFeatures,
- weak_ptr_factory_.GetWeakPtr(), std::move(state)));
+ ExecuteModel(std::move(state));
}
void ModelExecutionManagerImpl::ExecuteModel(
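
The old ProcessFeatures / OnGetSamplesForFeature loop, which fetched and aggregated one feature at a time, is replaced by a single asynchronous call into the shared FeatureListQueryProcessor, which now owns the bucket-duration math, the shared end time, enum filtering, and aggregation. The interface sketched below is inferred from the call site above, not copied from feature_list_query_processor.h, and may differ in detail:

    // Inferred shape of the dependency.
    class FeatureListQueryProcessor {
     public:
      using FeatureProcessorCallback =
          base::OnceCallback<void(bool /*error*/,
                                  const std::vector<float>& /*input_tensor*/)>;

      // Builds one flat float tensor for every feature in |model_metadata|,
      // using |prediction_time| as the shared end time for all buckets.
      void ProcessFeatureList(
          const proto::SegmentationModelMetadata& model_metadata,
          OptimizationTarget segment_id,
          base::Time prediction_time,
          FeatureProcessorCallback callback);
    };

On completion the manager either reports kInvalidMetadata (error == true) or appends the tensor to ExecutionState::input_tensor and calls ExecuteModel().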
@@ -349,6 +225,11 @@ void ModelExecutionManagerImpl::OnModelExecutionComplete(
if (result.has_value()) {
VLOG(1) << "Segmentation model result: " << *result;
stats::RecordModelExecutionResult(state->segment_id, result.value());
+ if (state->model_handler->GetModelInfo()) {
+ SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
+ state->segment_id, state->model_handler->GetModelInfo()->GetVersion(),
+ state->input_tensor, result.value());
+ }
RunModelExecutionCallback(std::move(state), *result,
ModelExecutionStatus::kSuccess);
} else {
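
The added block above reports each successful execution to UKM, guarded on GetModelInfo() so nothing is logged unless a model (and therefore a version) is actually loaded; retaining input_tensor on ExecutionState is what makes the tensor available for logging at this point. The same call in isolation, with comments added for illustration:

    if (state->model_handler->GetModelInfo()) {
      SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
          state->segment_id,
          state->model_handler->GetModelInfo()->GetVersion(),  // Model version.
          state->input_tensor,  // Tensor assembled by the query processor.
          result.value());      // Raw model score.
    }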
@@ -380,7 +261,8 @@ void ModelExecutionManagerImpl::OnSegmentationModelUpdated(
return;
}
- // Set or overwrite name hashes for metadata features based on the name field.
+ // Set or overwrite name hashes for metadata features based on the name
+ // field.
metadata_utils::SetFeatureNameHashesFromName(&metadata);
auto validation = metadata_utils::ValidateMetadataAndFeatures(metadata);
@@ -442,8 +324,8 @@ void ModelExecutionManagerImpl::OnSegmentInfoFetchedForModelUpdate(
stats::RecordModelDeliveryMetadataFeatureCount(
segment_id, new_segment_info.model_metadata().features_size());
- // Now that we've merged the old and the new SegmentInfo, we want to store the
- // new version in the database.
+ // Now that we've merged the old and the new SegmentInfo, we want to store
+ // the new version in the database.
segment_database_->UpdateSegment(
segment_id, absl::make_optional(new_segment_info),
base::BindOnce(&ModelExecutionManagerImpl::OnUpdatedSegmentInfoStored,
@@ -459,7 +341,8 @@ void ModelExecutionManagerImpl::OnUpdatedSegmentInfoStored(
if (!success)
return;
- // We are now ready to receive requests for execution, so invoke the callback.
+ // We are now ready to receive requests for execution, so invoke the
+ // callback.
model_updated_callback_.Run(std::move(segment_info));
}