diff options
Diffstat (limited to 'chromium/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc')
-rw-r--r-- | chromium/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc | 181 |
1 files changed, 32 insertions, 149 deletions
diff --git a/chromium/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc b/chromium/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc index ab3fbf38619..04f06f24498 100644 --- a/chromium/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc +++ b/chromium/components/segmentation_platform/internal/execution/model_execution_manager_impl.cc @@ -21,7 +21,7 @@ #include "components/optimization_guide/proto/models.pb.h" #include "components/segmentation_platform/internal/database/metadata_utils.h" #include "components/segmentation_platform/internal/database/signal_database.h" -#include "components/segmentation_platform/internal/execution/feature_aggregator.h" +#include "components/segmentation_platform/internal/execution/feature_list_query_processor.h" #include "components/segmentation_platform/internal/execution/model_execution_manager.h" #include "components/segmentation_platform/internal/execution/model_execution_status.h" #include "components/segmentation_platform/internal/execution/segmentation_model_handler.h" @@ -29,6 +29,7 @@ #include "components/segmentation_platform/internal/proto/model_metadata.pb.h" #include "components/segmentation_platform/internal/proto/model_prediction.pb.h" #include "components/segmentation_platform/internal/proto/types.pb.h" +#include "components/segmentation_platform/internal/segmentation_ukm_helper.h" #include "components/segmentation_platform/internal/stats.h" #include "third_party/abseil-cpp/absl/types/optional.h" #include "third_party/perfetto/include/perfetto/tracing/track.h" @@ -75,10 +76,7 @@ struct ModelExecutionManagerImpl::ExecutionState { OptimizationTarget segment_id; raw_ptr<SegmentationModelHandler> model_handler = nullptr; ModelExecutionCallback callback; - base::TimeDelta bucket_duration; - std::deque<proto::Feature> features; std::vector<float> input_tensor; - base::Time end_time; base::Time total_execution_start_time; base::Time model_execution_start_time; }; @@ -97,34 +95,19 @@ ModelExecutionManagerImpl::ModelExecutionTraceEvent:: perfetto::Track::FromPointer(&state)); } -struct ModelExecutionManagerImpl::FeatureState { - FeatureState() = default; - ~FeatureState() = default; - - // Disallow copy/assign. - FeatureState(const FeatureState&) = delete; - FeatureState& operator=(const FeatureState&) = delete; - - proto::SignalType signal_type; - proto::Aggregation aggregation; - absl::optional<std::vector<int32_t>> accepted_enum_ids; - uint64_t bucket_count; - uint64_t tensor_length; -}; - ModelExecutionManagerImpl::ModelExecutionManagerImpl( const base::flat_set<OptimizationTarget>& segment_ids, ModelHandlerCreator model_handler_creator, base::Clock* clock, SegmentInfoDatabase* segment_database, SignalDatabase* signal_database, - std::unique_ptr<FeatureAggregator> feature_aggregator, + FeatureListQueryProcessor* feature_list_query_processor, const SegmentationModelUpdatedCallback& model_updated_callback) : clock_(clock), segment_database_(segment_database), signal_database_(signal_database), - feature_aggregator_(std::move(feature_aggregator)), model_updated_callback_(model_updated_callback) { + feature_list_query_processor_ = feature_list_query_processor; for (OptimizationTarget segment_id : segment_ids) { model_handlers_.emplace(std::make_pair( segment_id, @@ -177,135 +160,28 @@ void ModelExecutionManagerImpl::OnSegmentInfoFetchedForExecution( return; } - // The total bucket duration is defined by product of the bucket_duration - // value and the length of related time_unit field, e.g. 28 * length(DAY). - const auto& model_metadata = segment_info->model_metadata(); - uint64_t bucket_duration = model_metadata.bucket_duration(); - base::TimeDelta time_unit_len = metadata_utils::GetTimeUnit(model_metadata); - state->bucket_duration = bucket_duration * time_unit_len; - - // Now that we have just fetched the metadata, set the end_time to be shared - // across all features, so we get a consistent picture. - state->end_time = clock_->Now(); - - // Grab the metadata for all the features, which will be processed one at a - // time, before executing the model. - for (int i = 0; i < model_metadata.features_size(); ++i) - state->features.emplace_back(model_metadata.features(i)); - - // Process all the features in-order, starting with the first feature. - ProcessFeatures(std::move(state)); -} - -void ModelExecutionManagerImpl::ProcessFeatures( - std::unique_ptr<ExecutionState> state) { - ModelExecutionTraceEvent trace_event( - "ModelExecutionManagerImpl::ProcessFeatures", *state); - // When there are no more features to process, we are done, so we execute the - // model. - if (state->features.empty()) { - ExecuteModel(std::move(state)); - return; - } - - proto::Feature feature; - do { - // Copy and pop the next feature. - feature = state->features.front(); - state->features.pop_front(); - - // Validate the proto::Feature metadata. - if (metadata_utils::ValidateMetadataFeature(feature) != - metadata_utils::ValidationResult::kValidationSuccess) { - RunModelExecutionCallback(std::move(state), 0, - ModelExecutionStatus::kInvalidMetadata); - return; - } - } while (feature.bucket_count() == 0); // Skip collection-only features. - - // Capture all relevant metadata for the current proto::Feature into the - // FeatureState. - auto feature_state = std::make_unique<FeatureState>(); - feature_state->signal_type = feature.type(); - feature_state->aggregation = feature.aggregation(); - feature_state->bucket_count = feature.bucket_count(); - feature_state->tensor_length = feature.tensor_length(); - - auto name_hash = feature.name_hash(); - - // Enum histograms can optionally only accept some of the enum values. - // While the proto::Feature is available, capture a vector of the accepted - // enum values. An empty vector is ignored (all values are considered - // accepted). - if (feature_state->signal_type == proto::SignalType::HISTOGRAM_ENUM) { - std::vector<int32_t> accepted_enum_ids{}; - for (int i = 0; i < feature.enum_ids_size(); ++i) - accepted_enum_ids.emplace_back(feature.enum_ids(i)); - - feature_state->accepted_enum_ids = absl::make_optional(accepted_enum_ids); - } - - // Only fetch data that is relevant for the current proto::Feature, since - // the FeatureAggregator assumes that only relevant data is given to it. - base::TimeDelta duration = - state->bucket_duration * feature_state->bucket_count; - base::Time start_time = state->end_time - duration; - - // Fetch the relevant samples for the current proto::Feature. Once the result - // has come back, it will be processed and inserted into the - // ExecutorState::input_tensor and will then invoke ProcessFeatures(...) - // again to ensure we continue until all features have been processed. - // Note: All parameters from the ExecutorState need to be captured locally - // before invoking GetSamples, because the state is moved with the callback, - // and the order of the move and accessing the members while invoking - // GetSamples is not guaranteed. - auto signal_type = feature_state->signal_type; - auto end_time = state->end_time; - signal_database_->GetSamples( - signal_type, name_hash, start_time, end_time, - base::BindOnce(&ModelExecutionManagerImpl::OnGetSamplesForFeature, - weak_ptr_factory_.GetWeakPtr(), std::move(state), - std::move(feature_state))); + OptimizationTarget segment_id = state->segment_id; + feature_list_query_processor_->ProcessFeatureList( + segment_info->model_metadata(), segment_id, clock_->Now(), + base::BindOnce( + &ModelExecutionManagerImpl::OnProcessingFeatureListComplete, + weak_ptr_factory_.GetWeakPtr(), std::move(state))); } -void ModelExecutionManagerImpl::OnGetSamplesForFeature( +void ModelExecutionManagerImpl::OnProcessingFeatureListComplete( std::unique_ptr<ExecutionState> state, - std::unique_ptr<FeatureState> feature_state, - std::vector<SignalDatabase::Sample> samples) { - ModelExecutionTraceEvent trace_event( - "ModelExecutionManagerImpl::OnGetSamplesForFeature", *state); - base::Time process_start_time = clock_->Now(); - // HISTOGRAM_ENUM features might require us to filter out the result to only - // keep enum values that match the accepted list. If the accepted list is' - // empty, all histogram enum values are kept. - // The SignalDatabase does not currently support this type of data filter, - // so instead we are doing this here. - if (feature_state->signal_type == proto::SignalType::HISTOGRAM_ENUM) { - DCHECK(feature_state->accepted_enum_ids.has_value()); - feature_aggregator_->FilterEnumSamples(*feature_state->accepted_enum_ids, - samples); + bool error, + const std::vector<float>& input_tensor) { + if (error) { + // Validation error occurred on model's metadata. + RunModelExecutionCallback(std::move(state), 0, + ModelExecutionStatus::kInvalidMetadata); + return; } + state->input_tensor.insert(state->input_tensor.end(), input_tensor.begin(), + input_tensor.end()); - // We now have all the data required to process a single feature, so we can - // process it synchronously, and insert it into the - // ExecutorState::input_tensor so we can later pass it to the ML model - // executor. - std::vector<float> feature_data = feature_aggregator_->Process( - feature_state->signal_type, feature_state->aggregation, - feature_state->bucket_count, state->end_time, state->bucket_duration, - samples); - DCHECK_EQ(feature_state->tensor_length, feature_data.size()); - state->input_tensor.insert(state->input_tensor.end(), feature_data.begin(), - feature_data.end()); - - stats::RecordModelExecutionDurationFeatureProcessing( - state->segment_id, clock_->Now() - process_start_time); - - // Continue with the rest of the features. - base::SequencedTaskRunnerHandle::Get()->PostTask( - FROM_HERE, - base::BindOnce(&ModelExecutionManagerImpl::ProcessFeatures, - weak_ptr_factory_.GetWeakPtr(), std::move(state))); + ExecuteModel(std::move(state)); } void ModelExecutionManagerImpl::ExecuteModel( @@ -349,6 +225,11 @@ void ModelExecutionManagerImpl::OnModelExecutionComplete( if (result.has_value()) { VLOG(1) << "Segmentation model result: " << *result; stats::RecordModelExecutionResult(state->segment_id, result.value()); + if (state->model_handler->GetModelInfo()) { + SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult( + state->segment_id, state->model_handler->GetModelInfo()->GetVersion(), + state->input_tensor, result.value()); + } RunModelExecutionCallback(std::move(state), *result, ModelExecutionStatus::kSuccess); } else { @@ -380,7 +261,8 @@ void ModelExecutionManagerImpl::OnSegmentationModelUpdated( return; } - // Set or overwrite name hashes for metadata features based on the name field. + // Set or overwrite name hashes for metadata features based on the name + // field. metadata_utils::SetFeatureNameHashesFromName(&metadata); auto validation = metadata_utils::ValidateMetadataAndFeatures(metadata); @@ -442,8 +324,8 @@ void ModelExecutionManagerImpl::OnSegmentInfoFetchedForModelUpdate( stats::RecordModelDeliveryMetadataFeatureCount( segment_id, new_segment_info.model_metadata().features_size()); - // Now that we've merged the old and the new SegmentInfo, we want to store the - // new version in the database. + // Now that we've merged the old and the new SegmentInfo, we want to store + // the new version in the database. segment_database_->UpdateSegment( segment_id, absl::make_optional(new_segment_info), base::BindOnce(&ModelExecutionManagerImpl::OnUpdatedSegmentInfoStored, @@ -459,7 +341,8 @@ void ModelExecutionManagerImpl::OnUpdatedSegmentInfoStored( if (!success) return; - // We are now ready to receive requests for execution, so invoke the callback. + // We are now ready to receive requests for execution, so invoke the + // callback. model_updated_callback_.Run(std::move(segment_info)); } |