// Copyright 2019 The Chromium Authors // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include #include "base/test/task_environment.h" #include "media/learning/impl/learning_task_controller_impl.h" using media::learning::FeatureValue; using media::learning::FeatureVector; using media::learning::LearningTask; using ValueDescription = media::learning::LearningTask::ValueDescription; using media::learning::LearningTaskControllerImpl; using media::learning::ObservationCompletion; using media::learning::TargetValue; ValueDescription ConsumeValueDescription(FuzzedDataProvider* provider) { ValueDescription desc; desc.name = provider->ConsumeRandomLengthString(100); desc.ordering = provider->ConsumeEnum(); desc.privacy_mode = provider->ConsumeEnum(); return desc; } double ConsumeDouble(FuzzedDataProvider* provider) { std::vector v = provider->ConsumeBytes(sizeof(double)); if (v.size() == sizeof(double)) return reinterpret_cast(v.data())[0]; return 0; } FeatureVector ConsumeFeatureVector(FuzzedDataProvider* provider) { FeatureVector features; int n = provider->ConsumeIntegralInRange(0, 100); while (n-- > 0) features.push_back(FeatureValue(ConsumeDouble(provider))); return features; } extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { base::test::TaskEnvironment task_environment; FuzzedDataProvider provider(data, size); LearningTask task; task.name = provider.ConsumeRandomLengthString(100); task.model = provider.ConsumeEnum(); task.use_one_hot_conversion = provider.ConsumeBool(); task.uma_hacky_aggregate_confusion_matrix = provider.ConsumeBool(); task.uma_hacky_by_training_weight_confusion_matrix = provider.ConsumeBool(); task.uma_hacky_by_feature_subset_confusion_matrix = provider.ConsumeBool(); int n_features = provider.ConsumeIntegralInRange(0, 100); int subset_size = provider.ConsumeIntegralInRange(0, n_features); if (subset_size) task.feature_subset_size = subset_size; for (int i = 0; i < n_features; i++) task.feature_descriptions.push_back(ConsumeValueDescription(&provider)); task.target_description = ConsumeValueDescription(&provider); LearningTaskControllerImpl controller(task); // Build random examples. while (provider.remaining_bytes() > 0) { base::UnguessableToken id = base::UnguessableToken::Create(); absl::optional default_target; if (provider.ConsumeBool()) default_target = TargetValue(ConsumeDouble(&provider)); controller.BeginObservation(id, ConsumeFeatureVector(&provider), default_target, absl::nullopt); controller.CompleteObservation( id, ObservationCompletion(TargetValue(ConsumeDouble(&provider)), ConsumeDouble(&provider))); task_environment.RunUntilIdle(); } return 0; }