summaryrefslogtreecommitdiff
path: root/chromium/third_party/blink/renderer/platform/graphics/dark_mode_generic_classifier.cc
blob: f14ba3721b8c2303774e46208927de20c6f6656f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "third_party/blink/renderer/platform/graphics/dark_mode_generic_classifier.h"

#include "third_party/blink/renderer/platform/graphics/darkmode/darkmode_classifier.h"
#include "third_party/blink/renderer/platform/graphics/image.h"

namespace blink {
namespace {

// Color-bucket-ratio thresholds used by the decision tree. Both tables are
// indexed by Features::is_colorful: element 0 applies to grayscale images and
// element 1 to color images. A ratio below the "low" entry means "apply the
// filter"; a ratio above the "high" entry means "do not apply the filter".
constexpr float kLowColorCountThreshold[2] = {0.8125, 0.015137};
constexpr float kHighColorCountThreshold[2] = {1, 0.025635};

// Cheap first-pass classifier that looks only at the color-bucket ratio and
// at whether the image is colorful.
//
// Returns kApplyFilter when the ratio is strictly below the low threshold for
// the image's color class, and kDoNotApplyFilter when it is strictly above the
// high threshold. Otherwise it returns kNotClassified, which includes a ratio
// exactly equal to either threshold.
DarkModeClassification ClassifyUsingDecisionTree(
    const DarkModeImageClassifier::Features& features) {
  // |is_colorful| picks the table entry: 0 for grayscale, 1 for color.
  const auto ratio = features.color_buckets_ratio;
  const float min_threshold = kLowColorCountThreshold[features.is_colorful];
  const float max_threshold = kHighColorCountThreshold[features.is_colorful];

  if (ratio < min_threshold) {
    // Only a handful of distinct colors: unlikely to be a photo, so darken it.
    return DarkModeClassification::kApplyFilter;
  }

  if (ratio > max_threshold) {
    // A rich palette suggests a photorealistic image; leave it untouched.
    return DarkModeClassification::kDoNotApplyFilter;
  }

  // Between the two thresholds the tree has no confident answer.
  return DarkModeClassification::kNotClassified;
}

// Flattens |features| into the input layout expected by the neural network:
//   [0] is_colorful
//   [1] color_buckets_ratio
//   [2] transparency_ratio
//   [3] background_ratio
//   [4] is_svg
// The network depends on this exact order. Do not reorder these entries
// without also changing the neural network code!
Vector<float> ToVector(const DarkModeImageClassifier::Features& features) {
  Vector<float> nn_input = {
      static_cast<float>(features.is_colorful),
      static_cast<float>(features.color_buckets_ratio),
      static_cast<float>(features.transparency_ratio),
      static_cast<float>(features.background_ratio),
      static_cast<float>(features.is_svg),
  };
  return nn_input;
}

}  // namespace

// Nothing to initialize beyond default member construction.
DarkModeGenericClassifier::DarkModeGenericClassifier() = default;

// Classifies an image from its precomputed |features|.
//
// The decision tree runs first. If it reaches a verdict, that verdict is
// returned as-is. If it reports kNotClassified, the full feature vector is
// passed to darkmode_tfnative_model::Inference. A strictly positive network
// output maps to kApplyFilter; anything else maps to kDoNotApplyFilter.
DarkModeClassification DarkModeGenericClassifier::ClassifyWithFeatures(
    const Features& features) {
  const DarkModeClassification tree_result =
      ClassifyUsingDecisionTree(features);
  if (tree_result != DarkModeClassification::kNotClassified)
    return tree_result;

  // The tree was undecided: let the neural network weigh every feature,
  // not just the color-bucket ratio.
  Vector<float> nn_input = ToVector(features);
  darkmode_tfnative_model::FixedAllocations nn_allocations;
  float nn_score = 0.0f;
  darkmode_tfnative_model::Inference(nn_input.data(), &nn_score,
                                     &nn_allocations);

  return nn_score > 0 ? DarkModeClassification::kApplyFilter
                      : DarkModeClassification::kDoNotApplyFilter;
}

// Test-only hook that exposes the file-local decision tree, so tests can
// check it without the neural-network fallback used by ClassifyWithFeatures.
DarkModeClassification
DarkModeGenericClassifier::ClassifyUsingDecisionTreeForTesting(
    const Features& features) {
  const DarkModeClassification tree_verdict =
      ClassifyUsingDecisionTree(features);
  return tree_verdict;
}

}  // namespace blink