summaryrefslogtreecommitdiff
path: root/chromium/components/assist_ranker/assist_ranker_service_impl.cc
diff options
context:
space:
mode:
Diffstat (limited to 'chromium/components/assist_ranker/assist_ranker_service_impl.cc')
-rw-r--r--chromium/components/assist_ranker/assist_ranker_service_impl.cc28
1 files changed, 20 insertions, 8 deletions
diff --git a/chromium/components/assist_ranker/assist_ranker_service_impl.cc b/chromium/components/assist_ranker/assist_ranker_service_impl.cc
index 255183a6e8c..5b870b8f578 100644
--- a/chromium/components/assist_ranker/assist_ranker_service_impl.cc
+++ b/chromium/components/assist_ranker/assist_ranker_service_impl.cc
@@ -3,7 +3,7 @@
// found in the LICENSE file.
#include "components/assist_ranker/assist_ranker_service_impl.h"
-
+#include "base/memory/weak_ptr.h"
#include "components/assist_ranker/binary_classifier_predictor.h"
#include "components/assist_ranker/ranker_model_loader_impl.h"
#include "net/url_request/url_request_context_getter.h"
@@ -19,15 +19,27 @@ AssistRankerServiceImpl::AssistRankerServiceImpl(
AssistRankerServiceImpl::~AssistRankerServiceImpl() {}
-std::unique_ptr<BinaryClassifierPredictor>
+base::WeakPtr<BinaryClassifierPredictor>
AssistRankerServiceImpl::FetchBinaryClassifierPredictor(
- GURL model_url,
- const std::string& model_filename,
- const std::string& uma_prefix) {
+ const PredictorConfig& config) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- return BinaryClassifierPredictor::Create(url_request_context_getter_.get(),
- GetModelPath(model_filename),
- model_url, uma_prefix);
+ const std::string& model_name = config.model_name;
+ auto predictor_it = predictor_map_.find(model_name);
+ if (predictor_it != predictor_map_.end()) {
+ DVLOG(1) << "Predictor " << model_name << " already initialized.";
+ return base::AsWeakPtr(
+ static_cast<BinaryClassifierPredictor*>(predictor_it->second.get()));
+ }
+
+ // The predictor does not exist yet, so we create one.
+ DVLOG(1) << "Initializing predictor: " << model_name;
+ std::unique_ptr<BinaryClassifierPredictor> predictor =
+ BinaryClassifierPredictor::Create(config, GetModelPath(model_name),
+ url_request_context_getter_.get());
+ base::WeakPtr<BinaryClassifierPredictor> weak_ptr =
+ base::AsWeakPtr(predictor.get());
+ predictor_map_[model_name] = std::move(predictor);
+ return weak_ptr;
}
base::FilePath AssistRankerServiceImpl::GetModelPath(