summaryrefslogtreecommitdiff
path: root/chromium/media/learning/impl/random_tree_trainer.h
blob: 2d4ecfcff3ccc206b70ba0b1b4c8e1db8516e18c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_
#define MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_

#include <limits>
#include <map>
#include <memory>
#include <set>
#include <vector>

#include "base/component_export.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/impl/random_number_generator.h"
#include "media/learning/impl/training_algorithm.h"

namespace media {
namespace learning {

// Trains RandomTree decision tree classifier / regressor.
//
// Decision trees, including RandomTree, classify instances as follows.  Each
// non-leaf node is marked with a feature number |i|.  The value of the |i|-th
// feature of the instance is then used to select which outgoing edge is
// traversed.  This repeats until arriving at a leaf, which has a distribution
// over target values that is the prediction.  The tree structure, including the
// feature index at each node and distribution at each leaf, is chosen once when
// the tree is trained.
//
// Training involves starting with a set of training examples, each of which has
// features and a target value.  The tree is constructed recursively, starting
// with the root.  For the node being constructed, the training algorithm is
// given the portion of the training set that would reach the node, if it were
// sent down the tree in a similar fashion as described above.  It then
// considers assigning each (unused) feature index as the index to split the
// training examples at this node.  For each index |i|, it groups the training
// set into subsets, each of which consists of those examples with the same
// value of the |i|-th feature.  It then computes a score for the split using
// the target values that ended up in each group.  The index with the best score
// is chosen for the split.
//
// For nominal features, we split the feature into all of its nominal values.
// This is somewhat nonstandard; one would normally convert to one-hot numeric
// features first.  See OneHotConverter if you'd like to do this.
//
// For numeric features, we choose a split point uniformly at random between its
// min and max values in the training data.  We do this because it's suitable
// for extra trees.  RandomForest trees want to select the best split point for
// each feature, rather than uniformly.  Either way, of course, we choose the
// best split among the (feature, split point) pairs we're considering.
//
// Also note that for one-hot (binary) features, these are the same thing, since
// a binary feature has only one meaningful split point.  So, this
// implementation is suitable for extra trees with numeric (possibly one hot)
// features, or RF with one-hot nominal features.  Note that non-one-hot nominal
// features probably work fine with RF too.  Numeric, non-binary features don't
// work with RF, unless one changes the split point selection.
//
// The training algorithm then recurses to build child nodes.  One child node is
// created for each observed value of the |i|-th feature in the training set.
// The child node is trained using the subset of the training set that shares
// that node's value for feature |i|.
//
// The above is a generic decision tree training algorithm.  A RandomTree
// differs from that mostly in how it selects the feature to split at each node
// during training.  Rather than computing a score for each feature, a
// RandomTree chooses a random subset of the features and only compares those.
//
// See https://en.wikipedia.org/wiki/Random_forest for information.  Note that
// this is just a single tree, not the whole forest.
//
// Note that this variant chooses split points randomly, as described by the
// ExtraTrees algorithm.  This is slightly different from RandomForest, which
// chooses split points to improve the split's score.
class COMPONENT_EXPORT(LEARNING_IMPL) RandomTreeTrainer
    : public TrainingAlgorithm,
      public HasRandomNumberGenerator {
 public:
  // |rng| is the random number generator used for feature and split-point
  // selection.  NOTE(review): nullptr presumably makes
  // HasRandomNumberGenerator fall back to a default generator -- confirm in
  // random_number_generator.h.
  explicit RandomTreeTrainer(RandomNumberGenerator* rng = nullptr);

  // Not copyable or assignable.
  RandomTreeTrainer(const RandomTreeTrainer&) = delete;
  RandomTreeTrainer& operator=(const RandomTreeTrainer&) = delete;

  ~RandomTreeTrainer() override;

  // TrainingAlgorithm implementation.
  // Train on all examples.  Calls |model_cb| with the trained model, which
  // won't happen before this returns.
  void Train(const LearningTask& task,
             const TrainingData& examples,
             TrainedModelCB model_cb) override;

 private:
  // Train on the subset of |examples| whose indices are listed in
  // |training_idx|, returning the trained model directly rather than through
  // a callback.
  std::unique_ptr<Model> Train(const LearningTask& task,
                               const TrainingData& examples,
                               const std::vector<size_t>& training_idx);

  // Set of feature indices.  Build() uses this to track which features were
  // already split on higher in the tree.
  using FeatureSet = std::set<int>;

  // Information about a proposed split, and the training sets that would result
  // from that split.  Move-only, since each branch owns its own vector of
  // training indices.
  struct Split {
    Split();
    // Creates a split on feature |index|.
    // NOTE(review): takes an int while |split_index| below is a size_t.
    explicit Split(int index);

    Split(const Split&) = delete;
    Split& operator=(const Split&) = delete;

    Split(Split&& rhs);

    ~Split();

    Split& operator=(Split&& rhs);

    // Feature index to split on.
    size_t split_index = 0;

    // For numeric splits, branch 0 is <= |split_point|, and 1 is > .
    // NOTE(review): FindSplitPoint_Nominal() also produces a FeatureValue; how
    // it is interpreted for nominal splits is not visible in this header.
    FeatureValue split_point;

    // Expected nats needed to compute the class, given that we're at this
    // node in the tree.
    // "nat" == entropy measured with natural log rather than base-2.
    // Defaults to +infinity, i.e. the worst possible score, until one of the
    // ComputeSplitScore_*() methods fills it in (presumably lower is better).
    double nats_remaining = std::numeric_limits<double>::infinity();

    // Per-branch (i.e. per-child node) information about this split.
    // Move-constructible only; copying and assignment are disabled.
    struct BranchInfo {
      explicit BranchInfo();
      BranchInfo(const BranchInfo& rhs) = delete;
      BranchInfo(BranchInfo&& rhs);
      ~BranchInfo();

      BranchInfo& operator=(const BranchInfo& rhs) = delete;
      BranchInfo& operator=(BranchInfo&& rhs) = delete;

      // Training set for this branch of the split.  |training_idx| holds the
      // indices that we're using out of our training data.
      std::vector<size_t> training_idx;

      // Histogram of the target values of the training examples that take
      // this branch of the split.
      // This is a flat_map since we're likely to have a very small (e.g.,
      // "true" / "false") number of targets.
      TargetHistogram target_histogram;
    };

    // [feature value at this split] = info about which examples take this
    // branch of the split.
    // NOTE(review): for numeric splits the key is presumably the branch
    // number (0 or 1, see |split_point|) rather than a raw feature value --
    // confirm in the .cc.
    std::map<FeatureValue, BranchInfo> branch_infos;
  };

  // Build this node from |training_data|.  |used_set| is the set of features
  // that we already used higher in the tree.  Only the examples whose indices
  // are in |training_idx| are considered.
  std::unique_ptr<Model> Build(const LearningTask& task,
                               const TrainingData& training_data,
                               const std::vector<size_t>& training_idx,
                               const FeatureSet& used_set);

  // Compute and return a split of the examples of |training_data| selected by
  // |training_idx| on the |index|-th feature.
  Split ConstructSplit(const LearningTask& task,
                       const TrainingData& training_data,
                       const std::vector<size_t>& training_idx,
                       int index);

  // Fill in |nats_remaining| for |split| for a nominal target.
  // |total_incoming_weight| is the total weight of all instances coming into
  // the node that we're splitting.
  void ComputeSplitScore_Nominal(Split* split, double total_incoming_weight);

  // Fill in |nats_remaining| for |split| for a numeric target.
  // |total_incoming_weight| is the total weight of all instances coming into
  // the node that we're splitting.
  void ComputeSplitScore_Numeric(Split* split, double total_incoming_weight);

  // Compute the split point on feature |index| for the examples of
  // |training_data| selected by |training_idx|, for a nominal feature.
  FeatureValue FindSplitPoint_Nominal(size_t index,
                                      const TrainingData& training_data,
                                      const std::vector<size_t>& training_idx);

  // Compute the split point on feature |index| for the examples of
  // |training_data| selected by |training_idx|, for a numeric feature.  Per
  // the class comment, this is chosen uniformly at random between the
  // feature's min and max values in that subset.
  FeatureValue FindSplitPoint_Numeric(size_t index,
                                      const TrainingData& training_data,
                                      const std::vector<size_t>& training_idx);
};

}  // namespace learning
}  // namespace media

#endif  // MEDIA_LEARNING_IMPL_RANDOM_TREE_TRAINER_H_