summaryrefslogtreecommitdiff
path: root/chromium/media/learning/impl/one_hot.cc
blob: 2436385dcd18b06946bdeb1707c6b7630ffe2c98 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "media/learning/impl/one_hot.h"

#include <set>

namespace media {
namespace learning {

// Builds a converter for |task|.  The converted task starts out as a copy of
// |task| whose feature descriptions are rebuilt: numeric features are carried
// over unchanged, and every other feature is expanded by ProcessOneFeature()
// using the values found in |training_data|.
OneHotConverter::OneHotConverter(const LearningTask& task,
                                 const TrainingData& training_data)
    : converted_task_(task) {
  // The feature descriptions of the converted task are re-populated below.
  converted_task_.feature_descriptions.clear();

  const auto& original_features = task.feature_descriptions;

  // Reserve one converter slot per original feature.  Slots belonging to
  // numeric features are never assigned.
  converters_.resize(original_features.size());

  for (size_t feature_index = 0; feature_index < original_features.size();
       ++feature_index) {
    const LearningTask::ValueDescription& description =
        original_features[feature_index];

    if (description.ordering != LearningTask::Ordering::kNumeric) {
      // Non-numeric feature: build its value map and append the expanded
      // descriptions to the converted task.
      ProcessOneFeature(feature_index, description, training_data);
      continue;
    }

    // Already numeric: copy the description through as-is.
    converted_task_.feature_descriptions.push_back(description);
  }
}

// Out-of-line defaulted destructor.
OneHotConverter::~OneHotConverter() = default;

// Returns a new TrainingData holding one example per example in
// |training_data|.  Each output example is a copy of the input example whose
// |features| have been replaced by the converted feature vector.
TrainingData OneHotConverter::Convert(const TrainingData& training_data) const {
  TrainingData result;

  for (const auto& original : training_data) {
    // Copy everything first, then overwrite only the features.
    LabelledExample converted(original);
    converted.features = Convert(original.features);
    result.push_back(converted);
  }

  return result;
}

// Converts a single feature vector.  Features without a converter are copied
// through unchanged.  A feature with a converter is replaced by one entry per
// value the converter knows about: all entries are 0 except the one mapped to
// the feature's value, which is 1.  If the value is not in the converter, all
// of the entries stay 0.
FeatureVector OneHotConverter::Convert(
    const FeatureVector& feature_vector) const {
  FeatureVector result;
  result.reserve(converted_task_.feature_descriptions.size());

  // Position in |feature_vector| of the feature handled by the current
  // converter slot.
  size_t input_index = 0;
  for (const auto& value_map : converters_) {
    const auto& input_value = feature_vector[input_index++];

    if (!value_map) {
      // No converter was recorded for this feature; pass the value through.
      result.push_back(input_value);
      continue;
    }

    // Append one zero entry for every value known to this converter.
    for (size_t remaining = value_map->size(); remaining > 0; --remaining)
      result.push_back(FeatureValue(0));

    // The map stores the position in the converted vector that should hold
    // the 1 for each known value.  Unknown values leave all entries at zero.
    const auto match = value_map->find(input_value);
    if (match != value_map->end())
      result[match->second] = FeatureValue(1);
  }

  return result;
}

// Builds the one-hot mapping for the |index|-th original feature.
//
// Collects every distinct value that feature takes in |training_data|, appends
// one numeric feature description (copied from |original_description|) to the
// converted task per distinct value, and records in |converters_[index]| which
// converted-feature position corresponds to each value.
void OneHotConverter::ProcessOneFeature(
    size_t index,
    const LearningTask::ValueDescription& original_description,
    const TrainingData& training_data) {
  // Collect all the distinct values for |index|.
  std::set<Value> values;
  for (const auto& example : training_data) {
    // |index| is used with operator[] below, so it must be strictly less than
    // the number of features in the example.
    DCHECK_GT(example.features.size(), index);
    values.insert(example.features[index]);
  }

  // We let the set's ordering be the one-hot value.  It doesn't really matter
  // as long as we don't change it once we pick it.
  ValueVectorIndexMap value_map;
  // Vector index that should be set to one for each distinct value.  This will
  // start at the next feature in the adjusted task.
  size_t next_vector_index = converted_task_.feature_descriptions.size();

  // Every generated feature shares the same description: the original one,
  // re-labelled as numeric.
  LearningTask::ValueDescription adjusted_description = original_description;
  adjusted_description.ordering = LearningTask::Ordering::kNumeric;

  // Add one feature for each value, and construct a map from value to the
  // feature index that should be 1 when the feature takes that value.
  for (const auto& value : values) {
    converted_task_.feature_descriptions.push_back(adjusted_description);
    // |value| will be converted into a 1 in the |next_vector_index|-th
    // feature.
    value_map[value] = next_vector_index++;
  }

  // Record the value map for the |index|-th original feature.
  converters_[index] = std::move(value_map);
}

// Takes ownership of |converter| and |model|.  PredictDistribution() runs
// instances through |converter| before handing them to |model|.
ConvertingModel::ConvertingModel(std::unique_ptr<OneHotConverter> converter,
                                 std::unique_ptr<Model> model)
    : converter_(std::move(converter)),
      model_(std::move(model)) {
}
// Out-of-line defaulted destructor.
ConvertingModel::~ConvertingModel() = default;

// Converts |instance| with the owned converter and returns the wrapped model's
// prediction for the converted feature vector.
TargetHistogram ConvertingModel::PredictDistribution(
    const FeatureVector& instance) {
  return model_->PredictDistribution(converter_->Convert(instance));
}

}  // namespace learning
}  // namespace media