diff options
Diffstat (limited to 'chromium/components/assist_ranker/assist_ranker_service_impl.cc')
-rw-r--r-- | chromium/components/assist_ranker/assist_ranker_service_impl.cc | 28 |
1 files changed, 20 insertions, 8 deletions
diff --git a/chromium/components/assist_ranker/assist_ranker_service_impl.cc b/chromium/components/assist_ranker/assist_ranker_service_impl.cc index 255183a6e8c..5b870b8f578 100644 --- a/chromium/components/assist_ranker/assist_ranker_service_impl.cc +++ b/chromium/components/assist_ranker/assist_ranker_service_impl.cc @@ -3,7 +3,7 @@ // found in the LICENSE file. #include "components/assist_ranker/assist_ranker_service_impl.h" - +#include "base/memory/weak_ptr.h" #include "components/assist_ranker/binary_classifier_predictor.h" #include "components/assist_ranker/ranker_model_loader_impl.h" #include "net/url_request/url_request_context_getter.h" @@ -19,15 +19,27 @@ AssistRankerServiceImpl::AssistRankerServiceImpl( AssistRankerServiceImpl::~AssistRankerServiceImpl() {} -std::unique_ptr<BinaryClassifierPredictor> +base::WeakPtr<BinaryClassifierPredictor> AssistRankerServiceImpl::FetchBinaryClassifierPredictor( - GURL model_url, - const std::string& model_filename, - const std::string& uma_prefix) { + const PredictorConfig& config) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - return BinaryClassifierPredictor::Create(url_request_context_getter_.get(), - GetModelPath(model_filename), - model_url, uma_prefix); + const std::string& model_name = config.model_name; + auto predictor_it = predictor_map_.find(model_name); + if (predictor_it != predictor_map_.end()) { + DVLOG(1) << "Predictor " << model_name << " already initialized."; + return base::AsWeakPtr( + static_cast<BinaryClassifierPredictor*>(predictor_it->second.get())); + } + + // The predictor does not exist yet, so we create one. + DVLOG(1) << "Initializing predictor: " << model_name; + std::unique_ptr<BinaryClassifierPredictor> predictor = + BinaryClassifierPredictor::Create(config, GetModelPath(model_name), + url_request_context_getter_.get()); + base::WeakPtr<BinaryClassifierPredictor> weak_ptr = + base::AsWeakPtr(predictor.get()); + predictor_map_[model_name] = std::move(predictor); + return weak_ptr; } base::FilePath AssistRankerServiceImpl::GetModelPath( |