summaryrefslogtreecommitdiff
path: root/chromium/components/assist_ranker/binary_classifier_predictor.cc
blob: 402aa5931c6a9f59a12cabebf2c0719cf01a8b05 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/assist_ranker/binary_classifier_predictor.h"

#include <memory>

#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/files/file_path.h"
#include "components/assist_ranker/generic_logistic_regression_inference.h"
#include "components/assist_ranker/proto/ranker_model.pb.h"
#include "components/assist_ranker/ranker_model.h"
#include "components/assist_ranker/ranker_model_loader_impl.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"

namespace assist_ranker {

// Constructs a predictor for |config|. No model is loaded here; Create()
// attaches a model loader via LoadModel() when querying is enabled.
BinaryClassifierPredictor::BinaryClassifierPredictor(
    const PredictorConfig& config)
    : BasePredictor(config) {}

BinaryClassifierPredictor::~BinaryClassifierPredictor() = default;

// static
// Factory for BinaryClassifierPredictor.
//
// |config| configures the predictor. |model_path| and |url_loader_factory|
// are handed to the RankerModelLoaderImpl, together with the predictor's model
// URL and |config.uma_prefix|.
//
// When querying is disabled, the predictor is returned without any model
// loader attached. Otherwise a loader is created and passed to LoadModel(),
// and the predictor is returned.
std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create(
    const PredictorConfig& config,
    const base::FilePath& model_path,
    scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
  // Plain `new` rather than std::make_unique: presumably the constructor is
  // not publicly accessible (TODO confirm in the header).
  std::unique_ptr<BinaryClassifierPredictor> predictor(
      new BinaryClassifierPredictor(config));
  if (!predictor->is_query_enabled()) {
    DVLOG(1) << "Query disabled, bypassing model loading.";
    return predictor;
  }
  const GURL& model_url = predictor->GetModelUrl();
  DVLOG(1) << "Creating predictor instance for " << predictor->GetModelName();
  DVLOG(1) << "Model URL: " << model_url;
  DVLOG(1) << "Using predict threshold replacement: "
           << predictor->GetPredictThresholdReplacement();
  // base::Unretained assumes |predictor| keeps the loader (and therefore this
  // callback) alive no longer than itself once LoadModel() takes it.
  // NOTE(review): confirm ownership in BasePredictor::LoadModel.
  // |url_loader_factory| is not used after this point. It is moved into the
  // loader to avoid a redundant ref-count increment and decrement.
  auto model_loader = std::make_unique<RankerModelLoaderImpl>(
      base::BindRepeating(&BinaryClassifierPredictor::ValidateModel),
      base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable,
                          base::Unretained(predictor.get())),
      std::move(url_loader_factory), model_path, model_url, config.uma_prefix);
  predictor->LoadModel(std::move(model_loader));
  return predictor;
}

bool BinaryClassifierPredictor::Predict(const RankerExample& example,
                                        bool* prediction) {
  if (!IsReady()) {
    DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction.";
    return false;
  }

  float predict_threshold_replacement = GetPredictThresholdReplacement();
  if (predict_threshold_replacement != kNoPredictThresholdReplacement) {
    *prediction = inference_module_->PredictScore(PreprocessExample(example)) >=
                  predict_threshold_replacement;
  } else {
    *prediction = inference_module_->Predict(PreprocessExample(example));
  }
  DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << *prediction;
  return true;
}

// Computes the raw model score for |example| and stores it in |*prediction|.
// Unlike Predict(), no threshold is applied. Returns false, leaving
// |*prediction| unwritten, if the predictor is not ready. Returns true
// otherwise.
bool BinaryClassifierPredictor::PredictScore(const RankerExample& example,
                                             float* prediction) {
  if (!IsReady()) {
    DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction.";
    return false;
  }
  *prediction = inference_module_->PredictScore(PreprocessExample(example));
  // Log the score itself. Streaming |prediction| would print the pointer.
  DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << *prediction;
  return true;
}

// static
// Checks that |model| can be served by this predictor. Only logistic
// regression models are accepted, and their weight fields must match the
// is_preprocessed_model flag:
//  - Preprocessed models must use fullname_weights only (no weights) and must
//    not specify preprocessor feature_indices.
//  - Non-preprocessed models must use weights only (no fullname_weights).
// Returns RankerModelStatus::OK on success and INCOMPATIBLE otherwise.
RankerModelStatus BinaryClassifierPredictor::ValidateModel(
    const RankerModel& model) {
  const auto& proto = model.proto();
  if (proto.model_case() != RankerModelProto::kLogisticRegression) {
    DVLOG(0) << "Model is incompatible.";
    return RankerModelStatus::INCOMPATIBLE;
  }

  const auto& glr = proto.logistic_regression();
  const bool has_weights = !glr.weights().empty();
  const bool has_fullname_weights = !glr.fullname_weights().empty();

  if (!glr.is_preprocessed_model()) {
    if (has_fullname_weights || !has_weights) {
      DVLOG(0) << "Model is incompatible. Non-preprocessed model should use "
                  "weights.";
      return RankerModelStatus::INCOMPATIBLE;
    }
    return RankerModelStatus::OK;
  }

  if (!has_fullname_weights || has_weights) {
    DVLOG(0) << "Model is incompatible. Preprocessed model should use "
                "fullname_weights.";
    return RankerModelStatus::INCOMPATIBLE;
  }
  if (!glr.preprocessor_config().feature_indices().empty()) {
    DVLOG(0) << "Preprocessed model doesn't need feature indices.";
    return RankerModelStatus::INCOMPATIBLE;
  }
  return RankerModelStatus::OK;
}

// Builds |inference_module_| from the loaded |ranker_model_|. Only logistic
// regression models are supported. Returns false, without creating a module,
// for any other model type.
bool BinaryClassifierPredictor::Initialize() {
  const auto& proto = ranker_model_->proto();
  if (proto.model_case() != RankerModelProto::kLogisticRegression) {
    DVLOG(0) << "Could not initialize inference module.";
    return false;
  }

  inference_module_ = std::make_unique<GenericLogisticRegressionInference>(
      proto.logistic_regression());
  return true;
}

}  // namespace assist_ranker