diff options
author | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2020-01-20 13:40:20 +0100 |
---|---|---|
committer | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2020-01-22 12:41:23 +0000 |
commit | 7961cea6d1041e3e454dae6a1da660b453efd238 (patch) | |
tree | c0eeb4a9ff9ba32986289c1653d9608e53ccb444 /chromium/third_party/openscreen | |
parent | b7034d0803538058e5c9d904ef03cf5eab34f6ef (diff) | |
download | qtwebengine-chromium-7961cea6d1041e3e454dae6a1da660b453efd238.tar.gz |
BASELINE: Update Chromium to 78.0.3904.130
Change-Id: If185e0c0061b3437531c97c9c8c78f239352a68b
Reviewed-by: Allan Sandfeld Jensen <allan.jensen@qt.io>
Diffstat (limited to 'chromium/third_party/openscreen')
159 files changed, 4807 insertions, 1787 deletions
diff --git a/chromium/third_party/openscreen/src/BUILD.gn b/chromium/third_party/openscreen/src/BUILD.gn index 7fe55161abc..68d5b26293f 100644 --- a/chromium/third_party/openscreen/src/BUILD.gn +++ b/chromium/third_party/openscreen/src/BUILD.gn @@ -39,6 +39,7 @@ executable("openscreen_unittests") { testonly = true deps = [ "cast/common:mdns_unittests", + "cast/common/certificate:unittests", "osp:osp_unittests", "osp/impl/discovery/mdns:mdns_unittests", "osp/msgs:unittests", diff --git a/chromium/third_party/openscreen/src/DEPS b/chromium/third_party/openscreen/src/DEPS index 686d060bdc2..627d67f0308 100644 --- a/chromium/third_party/openscreen/src/DEPS +++ b/chromium/third_party/openscreen/src/DEPS @@ -149,6 +149,10 @@ include_rules = [ '+absl/types/span.h', '+absl/types/variant.h', + # Similar to abseil, don't include boringssl using root path. Instead, + # explicitly allow 'openssl' where needed. + '-third_party/boringssl', + # Test framework includes. "-third_party/googletest", "+gtest", diff --git a/chromium/third_party/openscreen/src/PRESUBMIT.sh b/chromium/third_party/openscreen/src/PRESUBMIT.sh index 790841d1e45..9b911ab331b 100755 --- a/chromium/third_party/openscreen/src/PRESUBMIT.sh +++ b/chromium/third_party/openscreen/src/PRESUBMIT.sh @@ -18,7 +18,7 @@ function check_clang_format() { function check_include_guard() { # Replace all folder slashes with underscores, and add "_" suffix. - guard_name=${1//[\/\.]/_}_ + guard_name=${1//[\/\.\-]/_}_ # This to-uppercase syntax is available in bash 4.0+ guard_name=${guard_name^^} diff --git a/chromium/third_party/openscreen/src/build/config/BUILD.gn b/chromium/third_party/openscreen/src/build/config/BUILD.gn index eb9bf82aefd..a4c7a1f6442 100644 --- a/chromium/third_party/openscreen/src/build/config/BUILD.gn +++ b/chromium/third_party/openscreen/src/build/config/BUILD.gn @@ -8,6 +8,12 @@ declare_args() { # Enable address sanitizer. is_asan = false + + # Enable thread sanitizer. + is_tsan = false + + # Enable trace logging. + enable_trace_logging = false } config("compiler_defaults") { @@ -110,6 +116,9 @@ config("openscreen_code") { if (dcheck_always_on) { defines += [ "DCHECK_ALWAYS_ON" ] } + if (enable_trace_logging) { + defines += [ "ENABLE_TRACE_LOGGING" ] + } } config("default_optimization") { @@ -140,10 +149,17 @@ config("symbol_visibility_default") { } config("default_sanitizers") { + # NOTE: This is not an artificial restriction; clang doesn't allow these to be + # used together. + assert(!is_asan || !is_tsan) + cflags = [] ldflags = [] if (is_asan) { cflags += [ "-fsanitize=address" ] ldflags += [ "-fsanitize=address" ] + } else if (is_tsan) { + cflags += [ "-fsanitize=thread" ] + ldflags += [ "-fsanitize=thread" ] } } diff --git a/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn b/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn index 64ab0921f89..5a8724da96d 100644 --- a/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn +++ b/chromium/third_party/openscreen/src/build/config/BUILDCONFIG.gn @@ -85,14 +85,16 @@ if (is_clang) { # ============================================================================== # # Here we set the default toolchain. Currently only Mac and POSIX are defined. +host_toolchain = "" if (current_os == "chromeos" || current_os == "linux") { - set_default_toolchain("//build/toolchain/linux:linux") + host_toolchain = "//build/toolchain/linux:linux" } else if (current_os == "mac") { - set_default_toolchain("//build/toolchain/mac:clang") + host_toolchain = "//build/toolchain/mac:clang" } else { # TODO(miu): Windows, and others. assert(false, "Toolchain for current_os is not defined.") } +set_default_toolchain(host_toolchain) # ============================================================================= # OS DEFINITIONS diff --git a/chromium/third_party/openscreen/src/cast/common/BUILD.gn b/chromium/third_party/openscreen/src/cast/common/BUILD.gn index e05f97baacc..34a3f159033 100644 --- a/chromium/third_party/openscreen/src/cast/common/BUILD.gn +++ b/chromium/third_party/openscreen/src/cast/common/BUILD.gn @@ -9,6 +9,8 @@ source_set("mdns") { "mdns/mdns_constants.h", "mdns/mdns_reader.cc", "mdns/mdns_reader.h", + "mdns/mdns_receiver.cc", + "mdns/mdns_receiver.h", "mdns/mdns_records.cc", "mdns/mdns_records.h", "mdns/mdns_sender.cc", @@ -31,6 +33,7 @@ source_set("mdns_unittests") { sources = [ "mdns/mdns_reader_unittest.cc", + "mdns/mdns_receiver_unittest.cc", "mdns/mdns_records_unittest.cc", "mdns/mdns_sender_unittest.cc", "mdns/mdns_writer_unittest.cc", diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/BUILD.gn b/chromium/third_party/openscreen/src/cast/common/certificate/BUILD.gn new file mode 100644 index 00000000000..fab645ce18d --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/certificate/BUILD.gn @@ -0,0 +1,27 @@ +# Copyright 2019 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +source_set("certificate") { + sources = [ + "cast_cert_validator.cc", + "cast_cert_validator.h", + "cast_cert_validator_internal.h", + ] + public_deps = [ + "../../../third_party/boringssl", + ] +} + +source_set("unittests") { + testonly = true + sources = [ + "cast_cert_validator_unittest.cc", + ] + + deps = [ + ":certificate", + "../../../third_party/boringssl", + "../../../third_party/googletest:gtest", + ] +} diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/DEPS b/chromium/third_party/openscreen/src/cast/common/certificate/DEPS new file mode 100644 index 00000000000..d6b038c4910 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/certificate/DEPS @@ -0,0 +1,7 @@ +# Copyright 2019 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +include_rules = [ + '+openssl' +] diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.cc b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.cc new file mode 100644 index 00000000000..b5e30d37a78 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.cc @@ -0,0 +1,718 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/certificate/cast_cert_validator.h" + +#include <openssl/digest.h> +#include <openssl/x509.h> +#include <openssl/x509v3.h> +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#include <algorithm> +#include <memory> +#include <utility> + +#include "cast/common/certificate/cast_cert_validator_internal.h" + +namespace cast { +namespace certificate { +namespace { + +using CastCertError = openscreen::Error::Code; + +// ------------------------------------------------------------------------- +// Cast trust anchors. +// ------------------------------------------------------------------------- + +// There are two trusted roots for Cast certificate chains: +// +// (1) CN=Cast Root CA (kCastRootCaDer) +// (2) CN=Eureka Root CA (kEurekaRootCaDer) +// +// These constants are defined by the files included next: + +#include "cast/common/certificate/cast_root_ca_cert_der-inc.h" +#include "cast/common/certificate/eureka_root_ca_der-inc.h" + +constexpr static int32_t kMinRsaModulusLengthBits = 2048; + +// Adds a trust anchor given a DER-encoded certificate from static +// storage. +template <size_t N> +bssl::UniquePtr<X509> MakeTrustAnchor(const uint8_t (&data)[N]) { + const uint8_t* dptr = data; + return bssl::UniquePtr<X509>{d2i_X509(nullptr, &dptr, N)}; +} + +// Stores intermediate state while attempting to find a valid certificate chain +// from a set of trusted certificates to a target certificate. Together, a +// sequence of these forms a certificate chain to be verified as well as a stack +// that can be unwound for searching more potential paths. +struct CertPathStep { + X509* cert; + + // The next index that can be checked in |trust_store| if the choice |cert| on + // the path needs to be reverted. + uint32_t trust_store_index; + + // The next index that can be checked in |intermediate_certs| if the choice + // |cert| on the path needs to be reverted. + uint32_t intermediate_cert_index; +}; + +// These values are bit positions from RFC 5280 4.2.1.3 and will be passed to +// ASN1_BIT_STRING_get_bit. +enum KeyUsageBits { + kDigitalSignature = 0, + kKeyCertSign = 5, +}; + +bool VerifySignedData(const EVP_MD* digest, + EVP_PKEY* public_key, + const ConstDataSpan& data, + const ConstDataSpan& signature) { + // This code assumes the signature algorithm was RSASSA PKCS#1 v1.5 with + // |digest|. + bssl::ScopedEVP_MD_CTX ctx; + if (!EVP_DigestVerifyInit(ctx.get(), nullptr, digest, nullptr, public_key)) { + return false; + } + return (EVP_DigestVerify(ctx.get(), signature.data, signature.length, + data.data, data.length) == 1); +} + +// Returns the OID for the Audio-Only Cast policy +// (1.3.6.1.4.1.11129.2.5.2) in DER form. +const ConstDataSpan& AudioOnlyPolicyOid() { + static const uint8_t kAudioOnlyPolicy[] = {0x2B, 0x06, 0x01, 0x04, 0x01, + 0xD6, 0x79, 0x02, 0x05, 0x02}; + static ConstDataSpan kPolicySpan{kAudioOnlyPolicy, sizeof(kAudioOnlyPolicy)}; + return kPolicySpan; +} + +class CertVerificationContextImpl final : public CertVerificationContext { + public: + CertVerificationContextImpl(bssl::UniquePtr<EVP_PKEY>&& cert, + std::string&& common_name) + : public_key_{std::move(cert)}, common_name_(std::move(common_name)) {} + + ~CertVerificationContextImpl() override = default; + + bool VerifySignatureOverData( + const ConstDataSpan& signature, + const ConstDataSpan& data, + DigestAlgorithm digest_algorithm) const override { + const EVP_MD* digest; + switch (digest_algorithm) { + case DigestAlgorithm::kSha1: + digest = EVP_sha1(); + break; + case DigestAlgorithm::kSha256: + digest = EVP_sha256(); + break; + case DigestAlgorithm::kSha384: + digest = EVP_sha384(); + break; + case DigestAlgorithm::kSha512: + digest = EVP_sha512(); + break; + default: + return false; + } + + return VerifySignedData(digest, public_key_.get(), data, signature); + } + + const std::string& GetCommonName() const override { return common_name_; } + + private: + const bssl::UniquePtr<EVP_PKEY> public_key_; + const std::string common_name_; +}; + +bool CertInPath(X509_NAME* name, + const std::vector<CertPathStep>& steps, + uint32_t start, + uint32_t stop) { + for (uint32_t i = start; i < stop; ++i) { + if (X509_NAME_cmp(name, X509_get_subject_name(steps[i].cert)) == 0) { + return true; + } + } + return false; +} + +uint8_t ParseAsn1TimeDoubleDigit(ASN1_GENERALIZEDTIME* time, int index) { + return (time->data[index] - '0') * 10 + (time->data[index + 1] - '0'); +} + +// Parses DateTime with additional restrictions laid out by RFC 5280 +// 4.1.2.5.2. +bool ParseAsn1GeneralizedTime(ASN1_GENERALIZEDTIME* time, DateTime* out) { + static constexpr uint8_t kDaysPerMonth[] = { + 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31, + }; + + if (time->length != 15) { + return false; + } + if (time->data[14] != 'Z') { + return false; + } + for (int i = 0; i < 14; ++i) { + if (time->data[i] < '0' || time->data[i] > '9') { + return false; + } + } + out->year = ParseAsn1TimeDoubleDigit(time, 0) * 100 + + ParseAsn1TimeDoubleDigit(time, 2); + out->month = ParseAsn1TimeDoubleDigit(time, 4); + out->day = ParseAsn1TimeDoubleDigit(time, 6); + out->hour = ParseAsn1TimeDoubleDigit(time, 8); + out->minute = ParseAsn1TimeDoubleDigit(time, 10); + out->second = ParseAsn1TimeDoubleDigit(time, 12); + if (out->month == 0 || out->month > 12) { + return false; + } + int days_per_month = kDaysPerMonth[out->month - 1]; + if (out->month == 2) { + if (out->year % 4 == 0 && (out->year % 100 != 0 || out->year % 400 == 0)) { + days_per_month = 29; + } else { + days_per_month = 28; + } + } + if (out->day == 0 || out->day > days_per_month) { + return false; + } + if (out->hour > 23) { + return false; + } + if (out->minute > 59) { + return false; + } + // Leap seconds are allowed. + if (out->second > 60) { + return false; + } + return true; +} + +bool IsDateTimeBefore(const DateTime& a, const DateTime& b) { + if (a.year < b.year) { + return true; + } else if (a.year > b.year) { + return false; + } + if (a.month < b.month) { + return true; + } else if (a.month > b.month) { + return false; + } + if (a.day < b.day) { + return true; + } else if (a.day > b.day) { + return false; + } + if (a.hour < b.hour) { + return true; + } else if (a.hour > b.hour) { + return false; + } + if (a.minute < b.minute) { + return true; + } else if (a.minute > b.minute) { + return false; + } + if (a.second < b.second) { + return true; + } else if (a.second > b.second) { + return false; + } + return false; +} + +CastCertError VerifyCertTime(X509* cert, const DateTime& time) { + ASN1_GENERALIZEDTIME* not_before_asn1 = ASN1_TIME_to_generalizedtime( + cert->cert_info->validity->notBefore, nullptr); + ASN1_GENERALIZEDTIME* not_after_asn1 = ASN1_TIME_to_generalizedtime( + cert->cert_info->validity->notAfter, nullptr); + if (!not_before_asn1 || !not_after_asn1) { + return CastCertError::kErrCertsVerifyGeneric; + } + DateTime not_before; + DateTime not_after; + bool times_valid = ParseAsn1GeneralizedTime(not_before_asn1, ¬_before) && + ParseAsn1GeneralizedTime(not_after_asn1, ¬_after); + ASN1_GENERALIZEDTIME_free(not_before_asn1); + ASN1_GENERALIZEDTIME_free(not_after_asn1); + if (!times_valid) { + return CastCertError::kErrCertsVerifyGeneric; + } + if (IsDateTimeBefore(time, not_before) || IsDateTimeBefore(not_after, time)) { + return CastCertError::kErrCertsDateInvalid; + } + return CastCertError::kNone; +} + +bool VerifyPublicKeyLength(EVP_PKEY* public_key) { + return EVP_PKEY_bits(public_key) >= kMinRsaModulusLengthBits; +} + +bssl::UniquePtr<ASN1_BIT_STRING> GetKeyUsage(X509* cert) { + int pos = X509_get_ext_by_NID(cert, NID_key_usage, -1); + if (pos == -1) { + return nullptr; + } + X509_EXTENSION* key_usage = X509_get_ext(cert, pos); + const uint8_t* value = key_usage->value->data; + ASN1_BIT_STRING* key_usage_bit_string = nullptr; + if (!d2i_ASN1_BIT_STRING(&key_usage_bit_string, &value, + key_usage->value->length)) { + return nullptr; + } + return bssl::UniquePtr<ASN1_BIT_STRING>{key_usage_bit_string}; +} + +CastCertError VerifyCertificateChain(const std::vector<CertPathStep>& path, + uint32_t step_index, + const DateTime& time) { + // Default max path length is the number of intermediate certificates. + int max_pathlen = path.size() - 2; + + std::vector<NAME_CONSTRAINTS*> path_name_constraints; + CastCertError error = CastCertError::kNone; + uint32_t i = step_index; + for (; i < path.size() - 1; ++i) { + X509* subject = path[i + 1].cert; + X509* issuer = path[i].cert; + bool is_root = (i == step_index); + if (!is_root) { + if ((error = VerifyCertTime(issuer, time)) != CastCertError::kNone) { + return error; + } + if (X509_NAME_cmp(X509_get_subject_name(issuer), + X509_get_issuer_name(issuer)) != 0) { + if (max_pathlen == 0) { + return CastCertError::kErrCertsPathlen; + } + --max_pathlen; + } else { + issuer->ex_flags |= EXFLAG_SI; + } + } else { + issuer->ex_flags |= EXFLAG_SI; + } + + bssl::UniquePtr<ASN1_BIT_STRING> key_usage = GetKeyUsage(issuer); + if (key_usage) { + const int bit = + ASN1_BIT_STRING_get_bit(key_usage.get(), KeyUsageBits::kKeyCertSign); + if (bit == 0) { + return CastCertError::kErrCertsVerifyGeneric; + } + } + + // Check that basicConstraints is present, specifies the CA bit, and use + // pathLenConstraint if present. + const int basic_constraints_index = + X509_get_ext_by_NID(issuer, NID_basic_constraints, -1); + if (basic_constraints_index == -1) { + return CastCertError::kErrCertsVerifyGeneric; + } + X509_EXTENSION* const basic_constraints_extension = + X509_get_ext(issuer, basic_constraints_index); + bssl::UniquePtr<BASIC_CONSTRAINTS> basic_constraints{ + reinterpret_cast<BASIC_CONSTRAINTS*>( + X509V3_EXT_d2i(basic_constraints_extension))}; + + if (!basic_constraints || !basic_constraints->ca) { + return CastCertError::kErrCertsVerifyGeneric; + } + + if (basic_constraints->pathlen) { + if (basic_constraints->pathlen->length != 1) { + return CastCertError::kErrCertsVerifyGeneric; + } else { + const int pathlen = *basic_constraints->pathlen->data; + if (pathlen < 0) { + return CastCertError::kErrCertsVerifyGeneric; + } + if (pathlen < max_pathlen) { + max_pathlen = pathlen; + } + } + } + + if (X509_ALGOR_cmp(issuer->sig_alg, issuer->cert_info->signature) != 0) { + return CastCertError::kErrCertsVerifyGeneric; + } + + bssl::UniquePtr<EVP_PKEY> public_key{X509_get_pubkey(issuer)}; + if (!VerifyPublicKeyLength(public_key.get())) { + return CastCertError::kErrCertsVerifyGeneric; + } + + // NOTE: (!self-issued || target) -> verify name constraints. Target case + // is after the loop. + const bool is_self_issued = issuer->ex_flags & EXFLAG_SI; + if (!is_self_issued) { + for (NAME_CONSTRAINTS* name_constraints : path_name_constraints) { + if (NAME_CONSTRAINTS_check(subject, name_constraints) != X509_V_OK) { + return CastCertError::kErrCertsVerifyGeneric; + } + } + } + + if (issuer->nc) { + path_name_constraints.push_back(issuer->nc); + } else { + const int index = X509_get_ext_by_NID(issuer, NID_name_constraints, -1); + if (index != -1) { + X509_EXTENSION* ext = X509_get_ext(issuer, index); + auto* nc = reinterpret_cast<NAME_CONSTRAINTS*>(X509V3_EXT_d2i(ext)); + if (nc) { + issuer->nc = nc; + path_name_constraints.push_back(nc); + } else { + return CastCertError::kErrCertsVerifyGeneric; + } + } + } + + // Check that any policy mappings present are _not_ the anyPolicy OID. Even + // though we don't otherwise handle policies, this is required by RFC 5280 + // 6.1.4(a). + const int policy_mappings_index = + X509_get_ext_by_NID(issuer, NID_policy_mappings, -1); + if (policy_mappings_index != -1) { + X509_EXTENSION* policy_mappings_extension = + X509_get_ext(issuer, policy_mappings_index); + auto* policy_mappings = reinterpret_cast<POLICY_MAPPINGS*>( + X509V3_EXT_d2i(policy_mappings_extension)); + const uint32_t policy_mapping_count = + sk_POLICY_MAPPING_num(policy_mappings); + const ASN1_OBJECT* any_policy = OBJ_nid2obj(NID_any_policy); + for (uint32_t i = 0; i < policy_mapping_count; ++i) { + POLICY_MAPPING* policy_mapping = + sk_POLICY_MAPPING_value(policy_mappings, i); + const bool either_matches = + ((OBJ_cmp(policy_mapping->issuerDomainPolicy, any_policy) == 0) || + (OBJ_cmp(policy_mapping->subjectDomainPolicy, any_policy) == 0)); + if (either_matches) { + error = CastCertError::kErrCertsVerifyGeneric; + break; + } + } + sk_POLICY_MAPPING_free(policy_mappings); + if (error != CastCertError::kNone) { + return error; + } + } + + // Check that we don't have any unhandled extensions marked as critical. + int extension_count = X509_get_ext_count(issuer); + for (int i = 0; i < extension_count; ++i) { + X509_EXTENSION* extension = X509_get_ext(issuer, i); + if (extension->critical > 0) { + const int nid = OBJ_obj2nid(extension->object); + if (nid != NID_name_constraints && nid != NID_basic_constraints && + nid != NID_key_usage) { + return CastCertError::kErrCertsVerifyGeneric; + } + } + } + + int nid = OBJ_obj2nid(subject->sig_alg->algorithm); + const EVP_MD* digest; + switch (nid) { + case NID_sha1WithRSAEncryption: + digest = EVP_sha1(); + break; + case NID_sha256WithRSAEncryption: + digest = EVP_sha256(); + break; + case NID_sha384WithRSAEncryption: + digest = EVP_sha384(); + break; + case NID_sha512WithRSAEncryption: + digest = EVP_sha512(); + break; + default: + return CastCertError::kErrCertsVerifyGeneric; + } + if (!VerifySignedData( + digest, public_key.get(), + {subject->cert_info->enc.enc, + static_cast<uint32_t>(subject->cert_info->enc.len)}, + {subject->signature->data, + static_cast<uint32_t>(subject->signature->length)})) { + return CastCertError::kErrCertsVerifyGeneric; + } + } + // NOTE: Other half of ((!self-issued || target) -> check name constraints). + for (NAME_CONSTRAINTS* name_constraints : path_name_constraints) { + if (NAME_CONSTRAINTS_check(path.back().cert, name_constraints) != + X509_V_OK) { + return CastCertError::kErrCertsVerifyGeneric; + } + } + return error; +} + +X509* ParseX509Der(const std::string& der) { + const uint8_t* data = reinterpret_cast<const uint8_t*>(der.data()); + return d2i_X509(nullptr, &data, der.size()); +} + +CastDeviceCertPolicy GetAudioPolicy(const std::vector<CertPathStep>& path, + uint32_t path_index) { + // Cast device certificates use the policy 1.3.6.1.4.1.11129.2.5.2 to indicate + // it is *restricted* to an audio-only device whereas the absence of a policy + // means it is unrestricted. + // + // This is somewhat different than RFC 5280's notion of policies, so policies + // are checked separately outside of path building. + // + // See the unit-tests VerifyCastDeviceCertTest.Policies* for some + // concrete examples of how this works. + // + // Iterate over all the certificates, including the root certificate. If any + // certificate contains the audio-only policy, the whole chain is considered + // constrained to audio-only device certificates. + // + // Policy mappings are not accounted for. The expectation is that top-level + // intermediates issued with audio-only will have no mappings. If subsequent + // certificates in the chain do, it won't matter as the chain is already + // restricted to being audio-only. + CastDeviceCertPolicy policy = CastDeviceCertPolicy::kUnrestricted; + for (uint32_t i = path_index; + i < path.size() && policy != CastDeviceCertPolicy::kAudioOnly; ++i) { + X509* cert = path[path.size() - 1 - i].cert; + int pos = X509_get_ext_by_NID(cert, NID_certificate_policies, -1); + if (pos != -1) { + X509_EXTENSION* policies_extension = X509_get_ext(cert, pos); + const uint8_t* in = policies_extension->value->data; + CERTIFICATEPOLICIES* policies = d2i_CERTIFICATEPOLICIES( + nullptr, &in, policies_extension->value->length); + + if (policies) { + // Check for |audio_only_policy_oid| in the set of policies. + uint32_t policy_count = sk_POLICYINFO_num(policies); + for (uint32_t i = 0; i < policy_count; ++i) { + POLICYINFO* info = sk_POLICYINFO_value(policies, i); + const ConstDataSpan& audio_only_policy_oid = AudioOnlyPolicyOid(); + if (info->policyid->length == + static_cast<int>(audio_only_policy_oid.length) && + memcmp(info->policyid->data, audio_only_policy_oid.data, + audio_only_policy_oid.length) == 0) { + policy = CastDeviceCertPolicy::kAudioOnly; + break; + } + } + CERTIFICATEPOLICIES_free(policies); + } + } + } + return policy; +} + +} // namespace + +// static +CastTrustStore* CastTrustStore::GetInstance() { + static CastTrustStore* store = new CastTrustStore(); + return store; +} + +CastTrustStore::CastTrustStore() : trust_store_(new TrustStore()) { + trust_store_->certs.emplace_back(MakeTrustAnchor(kCastRootCaDer)); + trust_store_->certs.emplace_back(MakeTrustAnchor(kEurekaRootCaDer)); +} + +CastTrustStore::~CastTrustStore() = default; + +openscreen::Error VerifyDeviceCert( + const std::vector<std::string>& der_certs, + const DateTime& time, + std::unique_ptr<CertVerificationContext>* context, + CastDeviceCertPolicy* policy, + const CastCRL* crl, + CRLPolicy crl_policy, + TrustStore* trust_store) { + if (!trust_store) { + trust_store = CastTrustStore::GetInstance()->trust_store(); + } + + if (der_certs.empty()) { + return CastCertError::kErrCertsMissing; + } + + // Fail early if CRL is required but not provided. + if (!crl && crl_policy == CRLPolicy::kCrlRequired) { + return CastCertError::kErrCrlInvalid; + } + + bssl::UniquePtr<X509> target_cert; + std::vector<bssl::UniquePtr<X509>> intermediate_certs; + target_cert.reset(ParseX509Der(der_certs[0])); + if (!target_cert) { + return CastCertError::kErrCertsParse; + } + for (size_t i = 1; i < der_certs.size(); ++i) { + intermediate_certs.emplace_back(ParseX509Der(der_certs[i])); + if (!intermediate_certs.back()) { + return CastCertError::kErrCertsParse; + } + } + + // Basic checks on the target certificate. + CastCertError error = VerifyCertTime(target_cert.get(), time); + if (error != CastCertError::kNone) { + return error; + } + bssl::UniquePtr<EVP_PKEY> public_key{X509_get_pubkey(target_cert.get())}; + if (!VerifyPublicKeyLength(public_key.get())) { + return CastCertError::kErrCertsVerifyGeneric; + } + if (X509_ALGOR_cmp(target_cert.get()->sig_alg, + target_cert.get()->cert_info->signature) != 0) { + return CastCertError::kErrCertsVerifyGeneric; + } + bssl::UniquePtr<ASN1_BIT_STRING> key_usage = GetKeyUsage(target_cert.get()); + if (!key_usage) { + return CastCertError::kErrCertsRestrictions; + } + int bit = + ASN1_BIT_STRING_get_bit(key_usage.get(), KeyUsageBits::kDigitalSignature); + if (bit == 0) { + return CastCertError::kErrCertsRestrictions; + } + + X509* path_head = target_cert.get(); + std::vector<CertPathStep> path; + + // This vector isn't used as resizable, so instead we allocate the largest + // possible single path up front. This would be a single trusted cert, all + // the intermediate certs used once, and the target cert. + path.resize(1 + intermediate_certs.size() + 1); + + // Additionally, the path is slightly simpler to deal with if the list is + // sorted from trust->target, so the path is actually built starting from the + // end. + uint32_t first_index = path.size() - 1; + path[first_index].cert = path_head; + + // Index into |path| of the current frontier of path construction. + uint32_t path_index = first_index; + + // Whether |path| has reached a certificate in |trust_store| and is ready for + // verification. + bool path_cert_in_trust_store = false; + + // Attempt to build a valid certificate chain from |target_cert| to a + // certificate in |trust_store|. This loop tries all possible paths in a + // depth-first-search fashion. If no valid paths are found, the error + // returned is whatever the last error was from the last path tried. + uint32_t trust_store_index = 0; + uint32_t intermediate_cert_index = 0; + CastCertError last_error = CastCertError::kNone; + for (;;) { + X509_NAME* target_issuer_name = X509_get_issuer_name(path_head); + + // The next issuer certificate to add to the current path. + X509* next_issuer = nullptr; + + for (uint32_t i = trust_store_index; i < trust_store->certs.size(); ++i) { + X509* trust_store_cert = trust_store->certs[i].get(); + X509_NAME* trust_store_cert_name = + X509_get_subject_name(trust_store_cert); + if (X509_NAME_cmp(trust_store_cert_name, target_issuer_name) == 0) { + CertPathStep& next_step = path[--path_index]; + next_step.cert = trust_store_cert; + next_step.trust_store_index = i + 1; + next_step.intermediate_cert_index = 0; + next_issuer = trust_store_cert; + path_cert_in_trust_store = true; + break; + } + } + trust_store_index = 0; + if (!next_issuer) { + for (uint32_t i = intermediate_cert_index; i < intermediate_certs.size(); + ++i) { + X509* intermediate_cert = intermediate_certs[i].get(); + X509_NAME* intermediate_cert_name = + X509_get_subject_name(intermediate_cert); + if (X509_NAME_cmp(intermediate_cert_name, target_issuer_name) == 0 && + !CertInPath(intermediate_cert_name, path, path_index, + first_index)) { + CertPathStep& next_step = path[--path_index]; + next_step.cert = intermediate_cert; + next_step.trust_store_index = trust_store->certs.size(); + next_step.intermediate_cert_index = i + 1; + next_issuer = intermediate_cert; + break; + } + } + } + intermediate_cert_index = 0; + if (!next_issuer) { + if (path_index == first_index) { + // There are no more paths to try. Ensure an error is returned. + if (last_error == CastCertError::kNone) { + return CastCertError::kErrCertsVerifyGeneric; + } + return last_error; + } else { + CertPathStep& last_step = path[path_index++]; + trust_store_index = last_step.trust_store_index; + intermediate_cert_index = last_step.intermediate_cert_index; + continue; + } + } + + // TODO(btolsch): Check against revocation list + if (path_cert_in_trust_store) { + last_error = VerifyCertificateChain(path, path_index, time); + if (last_error != CastCertError::kNone) { + CertPathStep& last_step = path[path_index++]; + trust_store_index = last_step.trust_store_index; + intermediate_cert_index = last_step.intermediate_cert_index; + path_cert_in_trust_store = false; + } else { + break; + } + } + path_head = next_issuer; + } + + if (last_error != CastCertError::kNone) { + return last_error; + } + + *policy = GetAudioPolicy(path, path_index); + + // Finally, make sure there is a common name to give to + // CertVerificationContextImpl. + X509_NAME* target_subject = X509_get_subject_name(target_cert.get()); + std::string common_name(target_subject->canon_enclen, 0); + int len = X509_NAME_get_text_by_NID(target_subject, NID_commonName, + &common_name[0], common_name.size()); + if (len == 0) { + return CastCertError::kErrCertsRestrictions; + } + common_name.resize(len); + + context->reset(new CertVerificationContextImpl( + bssl::UniquePtr<EVP_PKEY>{X509_get_pubkey(target_cert.get())}, + std::move(common_name))); + + return CastCertError::kNone; +} + +} // namespace certificate +} // namespace cast diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.h b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.h new file mode 100644 index 00000000000..7f0da1d0de1 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator.h @@ -0,0 +1,147 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_COMMON_CERTIFICATE_CAST_CERT_VALIDATOR_H_ +#define CAST_COMMON_CERTIFICATE_CAST_CERT_VALIDATOR_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "platform/base/error.h" +#include "platform/base/macros.h" + +namespace cast { +namespace certificate { + +class CastCRL; + +// Describes the policy for a Device certificate. +enum class CastDeviceCertPolicy { + // The device certificate is unrestricted. + kUnrestricted, + + // The device certificate is for an audio-only device. + kAudioOnly, +}; + +enum class CRLPolicy { + // Revocation is only checked if a CRL is provided. + kCrlOptional, + + // Revocation is always checked. A missing CRL results in failure. + kCrlRequired, +}; + +enum class DigestAlgorithm { + kSha1, + kSha256, + kSha384, + kSha512, +}; + +struct ConstDataSpan { + const uint8_t* data; + uint32_t length; +}; + +struct DateTime { + uint16_t year; + uint8_t month; + uint8_t day; + uint8_t hour; + uint8_t minute; + uint8_t second; +}; + +struct TrustStore; + +class CastTrustStore { + public: + // Singleton for the Cast trust store for legacy networkingPrivate use. + static CastTrustStore* GetInstance(); + + CastTrustStore(); + ~CastTrustStore(); + + TrustStore* trust_store() const { return trust_store_.get(); } + + private: + std::unique_ptr<TrustStore> trust_store_; + OSP_DISALLOW_COPY_AND_ASSIGN(CastTrustStore); +}; + +// An object of this type is returned by the VerifyDeviceCert function, and can +// be used for additional certificate-related operations, using the verified +// certificate. +class CertVerificationContext { + public: + CertVerificationContext() = default; + virtual ~CertVerificationContext() = default; + + // Use the public key from the verified certificate to verify a + // |digest_algorithm|WithRSAEncryption |signature| over arbitrary |data|. + // Both |signature| and |data| hold raw binary data. Returns true if the + // signature was correct. + virtual bool VerifySignatureOverData( + const ConstDataSpan& signature, + const ConstDataSpan& data, + DigestAlgorithm digest_algorithm) const = 0; + + // Retrieve the Common Name attribute of the subject's distinguished name from + // the verified certificate, if present. Returns an empty string if no Common + // Name is found. + virtual const std::string& GetCommonName() const = 0; + + private: + OSP_DISALLOW_COPY_AND_ASSIGN(CertVerificationContext); +}; + +// Verifies a cast device certficate given a chain of DER-encoded certificates. +// +// Inputs: +// +// * |der_certs| is a chain of DER-encoded certificates: +// * |der_certs[0]| is the target certificate (i.e. the device certificate). +// * |der_certs[1..n-1]| are intermediates certificates to use in path +// building. Their ordering does not matter. +// +// * |time| is the timestamp to use for determining if the certificate is +// expired. +// +// * |crl| is the CRL to check for certificate revocation status. +// If this is a nullptr, then revocation checking is currently disabled. +// +// * |crl_policy| is for choosing how to handle the absence of a CRL. +// If CRL_REQUIRED is passed, then an empty |crl| input would result +// in a failed verification. Otherwise, |crl| is ignored if it is absent. +// +// * |trust_store| is an optional set of trusted certificates that may act as +// root CAs during chain verification. If this is nullptr, the built-in Cast +// root certificates will be used. +// +// Outputs: +// +// Returns openscreen::Error::Code::kNone on success. Otherwise, the +// corresponding openscreen::Error::Code. On success, the output parameters are +// filled with more details: +// +// * |context| is filled with an object that can be used to verify signatures +// using the device certificate's public key, as well as to extract other +// properties from the device certificate (Common Name). +// * |policy| is filled with an indication of the device certificate's policy +// (i.e. is it for audio-only devices or is it unrestricted?) +MAYBE_NODISCARD openscreen::Error VerifyDeviceCert( + const std::vector<std::string>& der_certs, + const DateTime& time, + std::unique_ptr<CertVerificationContext>* context, + CastDeviceCertPolicy* policy, + const CastCRL* crl, + CRLPolicy crl_policy, + TrustStore* trust_store = nullptr); + +} // namespace certificate +} // namespace cast + +#endif // CAST_COMMON_CERTIFICATE_CAST_CERT_VALIDATOR_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.h b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.h new file mode 100644 index 00000000000..91779a85243 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_internal.h @@ -0,0 +1,22 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_COMMON_CERTIFICATE_CAST_CERT_VALIDATOR_INTERNAL_H_ +#define CAST_COMMON_CERTIFICATE_CAST_CERT_VALIDATOR_INTERNAL_H_ + +#include <openssl/x509.h> + +#include <vector> + +namespace cast { +namespace certificate { + +struct TrustStore { + std::vector<bssl::UniquePtr<X509>> certs; +}; + +} // namespace certificate +} // namespace cast + +#endif // CAST_COMMON_CERTIFICATE_CAST_CERT_VALIDATOR_INTERNAL_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_unittest.cc b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_unittest.cc new file mode 100644 index 00000000000..8c23923943a --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_cert_validator_unittest.cc @@ -0,0 +1,643 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/certificate/cast_cert_validator.h" + +#include <stdio.h> + +#include "cast/common/certificate/cast_cert_validator_internal.h" +#include "gtest/gtest.h" +#include "openssl/pem.h" + +namespace cast { +namespace certificate { +namespace { + +using CastCertError = openscreen::Error::Code; + +enum TrustStoreDependency { + // Uses the built-in trust store for Cast. This is how certificates are + // verified in production. + TRUST_STORE_BUILTIN, + + // Instead of using the built-in trust store, use root certificate in the + // provided test chain as the trust anchor. + // + // This trust anchor is initialized with anchor constraints, similar to how + // TrustAnchors in the built-in store are setup. + TRUST_STORE_FROM_TEST_FILE, +}; + +CastTrustStore g_cast_trust_store; + +// Reads a test chain from |certs_file_name|, and asserts that verifying it as +// a Cast device certificate yields |expected_result|. +// +// RunTest() also checks that the resulting CertVerificationContext does not +// incorrectly verify invalid signatures. +// +// * |expected_policy| - The policy that should have been identified for the +// device certificate. +// * |time| - The timestamp to use when verifying the certificate. +// * |trust_store_dependency| - Which trust store to use when verifying (see +// enum's definition). +// * |optional_signed_data_file_name| - optional path to a PEM file containing +// a valid signature generated by the device certificate. +// +void RunTest(CastCertError expected_result, + const std::string& expected_common_name, + CastDeviceCertPolicy expected_policy, + const std::string& certs_file_name, + const DateTime& time, + TrustStoreDependency trust_store_dependency, + const std::string& optional_signed_data_file_name) { + FILE* fp = fopen(certs_file_name.c_str(), "r"); + ASSERT_TRUE(fp); + std::vector<std::string> certs; +#define STRCMP_LITERAL(s, l) strncmp(s, l, sizeof(l)) + for (;;) { + char* name; + char* header; + unsigned char* data; + long length; + if (PEM_read(fp, &name, &header, &data, &length) == 1) { + if (STRCMP_LITERAL(name, "CERTIFICATE") == 0) { + certs.emplace_back((char*)data, length); + } + OPENSSL_free(name); + OPENSSL_free(header); + OPENSSL_free(data); + } else { + break; + } + } + fclose(fp); + + TrustStore* trust_store; + std::unique_ptr<TrustStore> fake_trust_store; + + switch (trust_store_dependency) { + case TRUST_STORE_BUILTIN: + trust_store = g_cast_trust_store.trust_store(); + break; + + case TRUST_STORE_FROM_TEST_FILE: { + ASSERT_FALSE(certs.empty()); + + // Parse the root certificate of the chain. + const uint8_t* data = (const uint8_t*)certs.back().data(); + X509* fake_root = d2i_X509(nullptr, &data, certs.back().size()); + ASSERT_TRUE(fake_root); + certs.pop_back(); + + // Add a trust anchor and enforce constraints on it (regular mode for + // built-in Cast roots). + fake_trust_store = std::make_unique<TrustStore>(); + fake_trust_store->certs.emplace_back(fake_root); + trust_store = fake_trust_store.get(); + } + } + + std::unique_ptr<CertVerificationContext> context; + CastDeviceCertPolicy policy; + + openscreen::Error result = + VerifyDeviceCert(certs, time, &context, &policy, nullptr, + CRLPolicy::kCrlOptional, trust_store); + + ASSERT_EQ(expected_result, result.code()); + if (expected_result != CastCertError::kNone) + return; + + EXPECT_EQ(expected_policy, policy); + ASSERT_TRUE(context); + + // Test that the context is good. + EXPECT_EQ(expected_common_name, context->GetCommonName()); + +#define DATA_SPAN_FROM_LITERAL(s) ConstDataSpan{(uint8_t*)s, sizeof(s) - 1} + // Test verification of some invalid signatures. + EXPECT_FALSE(context->VerifySignatureOverData( + DATA_SPAN_FROM_LITERAL("bogus signature"), + DATA_SPAN_FROM_LITERAL("bogus data"), DigestAlgorithm::kSha256)); + EXPECT_FALSE(context->VerifySignatureOverData( + DATA_SPAN_FROM_LITERAL(""), DATA_SPAN_FROM_LITERAL("bogus data"), + DigestAlgorithm::kSha256)); + EXPECT_FALSE(context->VerifySignatureOverData(DATA_SPAN_FROM_LITERAL(""), + DATA_SPAN_FROM_LITERAL(""), + DigestAlgorithm::kSha256)); + + // If valid signatures are known for this device certificate, test them. + if (!optional_signed_data_file_name.empty()) { + FILE* fp = fopen(optional_signed_data_file_name.c_str(), "r"); + ASSERT_TRUE(fp); + ConstDataSpan message = {}; + ConstDataSpan sha1 = {}; + ConstDataSpan sha256 = {}; + for (;;) { + char* name; + char* header; + unsigned char* data; + long length; + if (PEM_read(fp, &name, &header, &data, &length) == 1) { + if (STRCMP_LITERAL(name, "MESSAGE") == 0) { + ASSERT_FALSE(message.data); + message.data = data; + message.length = length; + } else if (STRCMP_LITERAL(name, "SIGNATURE SHA1") == 0) { + ASSERT_FALSE(sha1.data); + sha1.data = data; + sha1.length = length; + } else if (STRCMP_LITERAL(name, "SIGNATURE SHA256") == 0) { + ASSERT_FALSE(sha256.data); + sha256.data = data; + sha256.length = length; + } else { + OPENSSL_free(data); + } + OPENSSL_free(name); + OPENSSL_free(header); + } else { + break; + } + } + ASSERT_TRUE(message.data); + ASSERT_TRUE(sha1.data); + ASSERT_TRUE(sha256.data); + + // Test verification of a valid SHA1 signature. + EXPECT_TRUE(context->VerifySignatureOverData(sha1, message, + DigestAlgorithm::kSha1)); + + // Test verification of a valid SHA256 signature. + EXPECT_TRUE(context->VerifySignatureOverData(sha256, message, + DigestAlgorithm::kSha256)); + + OPENSSL_free((uint8_t*)message.data); + OPENSSL_free((uint8_t*)sha1.data); + OPENSSL_free((uint8_t*)sha256.data); + } +} + +// Creates a time in UTC at midnight. +DateTime CreateDate(int year, int month, int day) { + DateTime time = {}; + time.year = year; + time.month = month; + time.day = day; + return time; +} + +// Returns 2016-04-01 00:00:00 UTC. +// +// This is a time when most of the test certificate paths are valid. +DateTime AprilFirst2016() { + return CreateDate(2016, 4, 1); +} + +DateTime AprilFirst2020() { + return CreateDate(2020, 4, 1); +} + +// Returns 2015-01-01 00:00:00 UTC. +DateTime JanuaryFirst2015() { + return CreateDate(2015, 1, 1); +} + +// Returns 2037-03-01 00:00:00 UTC. +// +// This is so far in the future that the test chains in this unit-test should +// all be invalid. +DateTime MarchFirst2037() { + return CreateDate(2037, 3, 1); +} + +#define TEST_DATA_PREFIX "test/data/cast/common/certificate/" + +// Tests verifying a valid certificate chain of length 2: +// +// 0: 2ZZBG9 FA8FCA3EF91A +// 1: Eureka Gen1 ICA +// +// Chains to trust anchor: +// Eureka Root CA (built-in trust store) +TEST(VerifyCastDeviceCertTest, ChromecastGen1) { + RunTest(CastCertError::kNone, "2ZZBG9 FA8FCA3EF91A", + CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/chromecast_gen1.pem", AprilFirst2016(), + TRUST_STORE_BUILTIN, + TEST_DATA_PREFIX "signeddata/2ZZBG9_FA8FCA3EF91A.pem"); +} + +// Tests verifying a valid certificate chain of length 2: +// +// 0: 2ZZBG9 FA8FCA3EF91A +// 1: Eureka Gen1 ICA +// +// Chains to trust anchor: +// Cast Root CA (built-in trust store) +TEST(VerifyCastDeviceCertTest, ChromecastGen1Reissue) { + RunTest(CastCertError::kNone, "2ZZBG9 FA8FCA3EF91A", + CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/chromecast_gen1_reissue.pem", + AprilFirst2016(), TRUST_STORE_BUILTIN, + TEST_DATA_PREFIX "signeddata/2ZZBG9_FA8FCA3EF91A.pem"); +} + +// Tests verifying a valid certificate chain of length 2: +// +// 0: 3ZZAK6 FA8FCA3F0D35 +// 1: Chromecast ICA 3 +// +// Chains to trust anchor: +// Cast Root CA (built-in trust store) +TEST(VerifyCastDeviceCertTest, ChromecastGen2) { + RunTest(CastCertError::kNone, "3ZZAK6 FA8FCA3F0D35", + CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/chromecast_gen2.pem", AprilFirst2016(), + TRUST_STORE_BUILTIN, ""); +} + +// Tests verifying a valid certificate chain of length 3: +// +// 0: -6394818897508095075 +// 1: Asus fugu Cast ICA +// 2: Widevine Cast Subroot +// +// Chains to trust anchor: +// Cast Root CA (built-in trust store) +TEST(VerifyCastDeviceCertTest, Fugu) { + RunTest(CastCertError::kNone, "-6394818897508095075", + CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/fugu.pem", AprilFirst2016(), + TRUST_STORE_BUILTIN, ""); +} + +// Tests verifying an invalid certificate chain of length 1: +// +// 0: Cast Test Untrusted Device +// +// Chains to: +// Cast Test Untrusted ICA (Not part of trust store) +// +// This is invalid because it does not chain to a trust anchor. +TEST(VerifyCastDeviceCertTest, Unchained) { + RunTest(CastCertError::kErrCertsVerifyGeneric, "", + CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/unchained.pem", AprilFirst2016(), + TRUST_STORE_BUILTIN, ""); +} + +// Tests verifying one of the self-signed trust anchors (chain of length 1): +// +// 0: Cast Root CA +// +// Chains to trust anchor: +// Cast Root CA (built-in trust store) +// +// Although this is a valid and trusted certificate (it is one of the +// trust anchors after all) it fails the test as it is not a *device +// certificate*. +TEST(VerifyCastDeviceCertTest, CastRootCa) { + RunTest(CastCertError::kErrCertsRestrictions, "", + CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/cast_root_ca.pem", AprilFirst2016(), + TRUST_STORE_BUILTIN, ""); +} + +// Tests verifying a valid certificate chain of length 2: +// +// 0: 4ZZDZJ FA8FCA7EFE3C +// 1: Chromecast ICA 4 (Audio) +// +// Chains to trust anchor: +// Cast Root CA (built-in trust store) +// +// This device certificate has a policy that means it is valid only for audio +// devices. +TEST(VerifyCastDeviceCertTest, ChromecastAudio) { + RunTest(CastCertError::kNone, "4ZZDZJ FA8FCA7EFE3C", + CastDeviceCertPolicy::kAudioOnly, + TEST_DATA_PREFIX "certificates/chromecast_audio.pem", + AprilFirst2016(), TRUST_STORE_BUILTIN, ""); +} + +// Tests verifying a valid certificate chain of length 3: +// +// 0: MediaTek Audio Dev Test +// 1: MediaTek Audio Dev Model +// 2: Cast Audio Dev Root CA +// +// Chains to trust anchor: +// Cast Root CA (built-in trust store) +// +// This device certificate has a policy that means it is valid only for audio +// devices. +TEST(VerifyCastDeviceCertTest, MtkAudioDev) { + RunTest(CastCertError::kNone, "MediaTek Audio Dev Test", + CastDeviceCertPolicy::kAudioOnly, + TEST_DATA_PREFIX "certificates/mtk_audio_dev.pem", JanuaryFirst2015(), + TRUST_STORE_BUILTIN, ""); +} + +// Tests verifying a valid certificate chain of length 2: +// +// 0: 9V0000VB FA8FCA784D01 +// 1: Cast TV ICA (Vizio) +// +// Chains to trust anchor: +// Cast Root CA (built-in trust store) +TEST(VerifyCastDeviceCertTest, Vizio) { + RunTest(CastCertError::kNone, "9V0000VB FA8FCA784D01", + CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/vizio.pem", AprilFirst2016(), + TRUST_STORE_BUILTIN, ""); +} + +// Tests verifying a valid certificate chain of length 2 using expired +// time points. +TEST(VerifyCastDeviceCertTest, ChromecastGen2InvalidTime) { + const char* kCertsFile = TEST_DATA_PREFIX "certificates/chromecast_gen2.pem"; + + // Control test - certificate should be valid at some time otherwise + // this test is pointless. + RunTest(CastCertError::kNone, "3ZZAK6 FA8FCA3F0D35", + CastDeviceCertPolicy::kUnrestricted, kCertsFile, AprilFirst2016(), + TRUST_STORE_BUILTIN, ""); + + // Use a time before notBefore. + RunTest(CastCertError::kErrCertsDateInvalid, "", + CastDeviceCertPolicy::kUnrestricted, kCertsFile, JanuaryFirst2015(), + TRUST_STORE_BUILTIN, ""); + + // Use a time after notAfter. + RunTest(CastCertError::kErrCertsDateInvalid, "", + CastDeviceCertPolicy::kUnrestricted, kCertsFile, MarchFirst2037(), + TRUST_STORE_BUILTIN, ""); +} + +// Tests verifying a valid certificate chain of length 3: +// +// 0: Audio Reference Dev Test +// 1: Audio Reference Dev Model +// 2: Cast Audio Dev Root CA +// +// Chains to trust anchor: +// Cast Root CA (built-in trust store) +// +// This device certificate has a policy that means it is valid only for audio +// devices. +TEST(VerifyCastDeviceCertTest, AudioRefDevTestChain3) { + RunTest(CastCertError::kNone, "Audio Reference Dev Test", + CastDeviceCertPolicy::kAudioOnly, + TEST_DATA_PREFIX "certificates/audio_ref_dev_test_chain_3.pem", + AprilFirst2016(), TRUST_STORE_BUILTIN, + TEST_DATA_PREFIX "signeddata/AudioReferenceDevTest.pem"); +} + +// TODO(btolsch): This won't work by default with boringssl, so do we want to +// find a way to work around this or is it safe to enforce 20-octet length now? +// Previous TODO from eroman@ suggested 2017 or even sooner was safe to remove +// this. +#if 0 +// Tests verifying a valid certificate chain of length 3. Note that the first +// intermediate has a serial number that is 21 octets long, which violates RFC +// 5280. However cast verification accepts this certificate for compatibility +// reasons. +// +// 0: 8C579B806FFC8A9DFFFF F8:8F:CA:6B:E6:DA +// 1: Sony so16vic CA +// 2: Cast Audio Sony CA +// +// Chains to trust anchor: +// Cast Root CA (built-in trust store) +// +// This device certificate has a policy that means it is valid only for audio +// devices. +TEST(VerifyCastDeviceCertTest, IntermediateSerialNumberTooLong) { + RunTest(CastCertError::kNone, "8C579B806FFC8A9DFFFF F8:8F:CA:6B:E6:DA", + CastDeviceCertPolicy::AUDIO_ONLY, + "certificates/intermediate_serialnumber_toolong.pem", + AprilFirst2016(), TRUST_STORE_BUILTIN, ""); +} +#endif + +// Tests verifying a valid certificate chain of length 2 when the trust anchor +// is "expired". This is expected to work since expiration is not an enforced +// anchor constraint, even though it may appear in the root certificate. +// +// 0: CastDevice +// 1: CastIntermediate +// +// Chains to trust anchor: +// Expired CastRoot (provided by test data) +TEST(VerifyCastDeviceCertTest, ExpiredTrustAnchor) { + // The root certificate is only valid in 2015, so validating with a time in + // 2016 means it is expired. + RunTest(CastCertError::kNone, "CastDevice", + CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/expired_root.pem", AprilFirst2016(), + TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Tests verifying a certificate chain where the root certificate has a pathlen +// constraint which is violated by the chain. In this case Root has a pathlen=1 +// constraint, however neither intermediate is constrained. +// +// The expectation is for pathlen constraints on trust anchors to be enforced, +// so this validation must fail. +// +// 0: Target +// 1: Intermediate2 +// 2: Intermediate1 +// +// Chains to trust anchor: +// Root (provided by test data; has pathlen=1 constraint) +TEST(VerifyCastDeviceCertTest, ViolatesPathlenTrustAnchorConstraint) { + // Test that the chain verification fails due to the pathlen constraint. + RunTest(CastCertError::kErrCertsPathlen, "Target", + CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/violates_root_pathlen_constraint.pem", + AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Tests verifying a certificate chain with the policies: +// +// Root: policies={} +// Intermediate: policies={anyPolicy} +// Leaf: policies={anyPolicy} +TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafAnypolicy) { + RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX + "certificates/policies_ica_anypolicy_leaf_anypolicy.pem", + AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Test verifying a certificate chain with the policies: +// +// Root: policies={} +// Intermediate: policies={anyPolicy} +// Leaf: policies={audioOnly} +TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafAudioonly) { + RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, + TEST_DATA_PREFIX + "certificates/policies_ica_anypolicy_leaf_audioonly.pem", + AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Test verifying a certificate chain with the policies: +// +// Root: policies={} +// Intermediate: policies={anyPolicy} +// Leaf: policies={foo} +TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafFoo) { + RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/policies_ica_anypolicy_leaf_foo.pem", + AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Test verifying a certificate chain with the policies: +// +// Root: policies={} +// Intermediate: policies={anyPolicy} +// Leaf: policies={} +TEST(VerifyCastDeviceCertTest, PoliciesIcaAnypolicyLeafNone) { + RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/policies_ica_anypolicy_leaf_none.pem", + AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Test verifying a certificate chain with the policies: +// +// Root: policies={} +// Intermediate: policies={audioOnly} +// Leaf: policies={anyPolicy} +TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafAnypolicy) { + RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, + TEST_DATA_PREFIX + "certificates/policies_ica_audioonly_leaf_anypolicy.pem", + AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Test verifying a certificate chain with the policies: +// +// Root: policies={} +// Intermediate: policies={audioOnly} +// Leaf: policies={audioOnly} +TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafAudioonly) { + RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, + TEST_DATA_PREFIX + "certificates/policies_ica_audioonly_leaf_audioonly.pem", + AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Test verifying a certificate chain with the policies: +// +// Root: policies={} +// Intermediate: policies={audioOnly} +// Leaf: policies={foo} +TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafFoo) { + RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, + TEST_DATA_PREFIX "certificates/policies_ica_audioonly_leaf_foo.pem", + AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Test verifying a certificate chain with the policies: +// +// Root: policies={} +// Intermediate: policies={audioOnly} +// Leaf: policies={} +TEST(VerifyCastDeviceCertTest, PoliciesIcaAudioonlyLeafNone) { + RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, + TEST_DATA_PREFIX "certificates/policies_ica_audioonly_leaf_none.pem", + AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Test verifying a certificate chain with the policies: +// +// Root: policies={} +// Intermediate: policies={} +// Leaf: policies={anyPolicy} +TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafAnypolicy) { + RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/policies_ica_none_leaf_anypolicy.pem", + AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Test verifying a certificate chain with the policies: +// +// Root: policies={} +// Intermediate: policies={} +// Leaf: policies={audioOnly} +TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafAudioonly) { + RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kAudioOnly, + TEST_DATA_PREFIX "certificates/policies_ica_none_leaf_audioonly.pem", + AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Test verifying a certificate chain with the policies: +// +// Root: policies={} +// Intermediate: policies={} +// Leaf: policies={foo} +TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafFoo) { + RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/policies_ica_none_leaf_foo.pem", + AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Test verifying a certificate chain with the policies: +// +// Root: policies={} +// Intermediate: policies={} +// Leaf: policies={} +TEST(VerifyCastDeviceCertTest, PoliciesIcaNoneLeafNone) { + RunTest(CastCertError::kNone, "Leaf", CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/policies_ica_none_leaf_none.pem", + AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Tests verifying a certificate chain where the leaf certificate has a +// 1024-bit RSA key. Verification should fail since the target's key is +// too weak. +TEST(VerifyCastDeviceCertTest, DeviceCertHas1024BitRsaKey) { + RunTest(CastCertError::kErrCertsVerifyGeneric, "RSA 1024 Device Cert", + CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/rsa1024_device_cert.pem", + AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Tests verifying a certificate chain where the leaf certificate has a +// 2048-bit RSA key, and then verifying signed data (both SHA1 and SHA256) +// for it. +TEST(VerifyCastDeviceCertTest, DeviceCertHas2048BitRsaKey) { + RunTest(CastCertError::kNone, "RSA 2048 Device Cert", + CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/rsa2048_device_cert.pem", + AprilFirst2016(), TRUST_STORE_FROM_TEST_FILE, + TEST_DATA_PREFIX "signeddata/rsa2048_device_cert_data.pem"); +} + +// Tests verifying a certificate chain where an intermediate certificate has a +// nameConstraints extension but the leaf certificate is still permitted under +// these constraints. +TEST(VerifyCastDeviceCertTest, NameConstraintsObeyed) { + RunTest(CastCertError::kNone, "Device", CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/nc.pem", AprilFirst2020(), + TRUST_STORE_FROM_TEST_FILE, ""); +} + +// Tests verifying a certificate chain where an intermediate certificate has a +// nameConstraints extension and the leaf certificate is not permitted under +// these constraints. +TEST(VerifyCastDeviceCertTest, NameConstraintsViolated) { + RunTest(CastCertError::kErrCertsVerifyGeneric, "Device", + CastDeviceCertPolicy::kUnrestricted, + TEST_DATA_PREFIX "certificates/nc_fail.pem", AprilFirst2020(), + TRUST_STORE_FROM_TEST_FILE, ""); +} + +} // namespace +} // namespace certificate +} // namespace cast diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/cast_root_ca_cert_der-inc.h b/chromium/third_party/openscreen/src/cast/common/certificate/cast_root_ca_cert_der-inc.h new file mode 100644 index 00000000000..67ed0b43abe --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/certificate/cast_root_ca_cert_der-inc.h @@ -0,0 +1,153 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_COMMON_CERTIFICATE_CAST_ROOT_CA_CERT_DER_INC_H_ +#define CAST_COMMON_CERTIFICATE_CAST_ROOT_CA_CERT_DER_INC_H_ + +// Certificate: +// Data: +// Version: 3 (0x2) +// Serial Number: 2 (0x2) +// Signature Algorithm: sha1WithRSAEncryption +// Issuer: C=US, ST=California, L=Mountain View, O=Google Inc, +// OU=Cast, CN=Cast Root CA +// Validity +// Not Before: Apr 2 17:34:26 2014 GMT +// Not After : Mar 28 17:34:26 2034 GMT +// Subject: C=US, ST=California, L=Mountain View, O=Google Inc, OU=Cast, +// CN=Cast Root CA +// Subject Public Key Info: +// Public Key Algorithm: rsaEncryption +// Public-Key: (2048 bit) +// Modulus: +// 00:ba:d9:65:9d:da:39:d3:c1:77:f6:d4:d0:ae:8f: +// 58:08:68:39:4a:95:ed:70:cf:fd:79:08:a9:aa:e5: +// e9:b8:a7:2d:a0:67:47:8a:9e:c9:cf:70:b3:05:87: +// 69:11:ec:70:98:97:c3:e6:c3:c3:eb:bd:c6:b0:3d: +// fc:4f:c1:5e:38:9f:da:cf:73:30:06:5b:79:37:c1: +// 5e:8c:87:47:94:9a:41:92:2a:d6:95:c4:71:5c:27: +// 5d:08:b1:80:c6:92:bd:1b:e3:41:97:a1:ec:75:9f: +// 55:9e:3e:9f:8f:1c:c7:65:64:07:d3:b3:96:a1:04: +// 9f:91:c4:de:0a:7b:6c:d9:c8:c0:78:31:a0:19:42: +// a9:e8:83:e3:ce:fc:f1:ce:c2:2e:24:46:95:09:19: +// ca:c0:46:b2:e5:01:ba:d7:4f:f3:bf:f6:69:ad:99: +// 04:fa:a0:07:39:0e:e6:df:51:47:07:c0:e4:a9:5c: +// 4b:94:c5:2f:b3:a0:30:7f:e7:95:6b:b2:af:32:0d: +// f1:8c:d5:6d:cb:7b:47:a7:08:ab:cb:27:a3:4d:cf: +// 4a:5a:f1:05:d1:f8:62:c5:10:2a:74:69:aa:e6:4b: +// 96:fb:9b:d8:63:e4:58:66:d3:ad:8a:6e:ff:7b:5e: +// f9:a5:56:1e:2d:82:31:5b:f0:e2:24:e6:41:4a:1f: +// ae:13 +// Exponent: 65537 (0x10001) +// X509v3 extensions: +// X509v3 Basic Constraints: +// CA:TRUE, pathlen:2 +// X509v3 Subject Key Identifier: +// 7C:9A:1E:7D:DF:79:54:BC:D7:CC:5E:CA:99:86:45:79:65:74:28:19 +// X509v3 Authority Key Identifier: +// keyid:7C:9A:1E:7D:DF:79:54:BC:D7:CC:5E:CA:99:86:45:79:65:74 +// :28:19 +// +// X509v3 Key Usage: +// Certificate Sign, CRL Sign +// Signature Algorithm: sha1WithRSAEncryption +// 80:f4:5a:fb:3d:28:19:51:20:d7:d4:fb:12:97:4a:65:f2:58: +// 35:92:77:30:6a:f1:d7:b6:51:1a:7f:9a:cd:c7:7b:03:42:ad: +// 55:6a:00:af:f0:e1:06:c2:bd:6b:78:75:db:fe:41:11:53:4a: +// 39:bb:9a:3a:c6:59:34:2f:2c:33:e3:b2:d6:5c:7f:dd:78:eb: +// 71:5b:39:da:83:90:c5:31:e2:3f:23:ef:da:eb:2b:2d:77:5e: +// de:c3:43:d2:c9:6b:59:82:ca:d5:ed:fa:a1:64:5b:cb:f1:0d: +// 1a:62:e1:9c:e8:a7:18:70:f0:5f:17:96:f8:ed:86:db:ae:1d: +// e0:cf:3e:5d:2e:ee:16:6d:95:2b:3c:fd:97:f3:05:5a:24:68: +// 4d:39:b6:f8:e4:58:ba:f5:e0:26:78:51:c5:5b:5d:4e:09:e5: +// 6c:47:8b:7a:5a:2e:89:53:e6:cc:36:5b:26:3c:f8:72:43:02: +// 82:d2:2b:cd:f0:d3:a3:ec:13:3e:52:d5:83:3d:07:dc:1d:43: +// 65:7a:33:02:01:a3:ce:b7:d6:60:51:3b:09:c2:23:8a:32:fe: +// 98:19:60:62:93:85:cd:34:46:db:d5:23:0f:79:da:77:00:2a: +// 02:6d:83:58:ce:03:77:35:e1:a3:20:93:c2:4a:a2:a4:46:1c: +// 75:2c:1f:4d +const unsigned char kCastRootCaDer[] = { + 0x30, 0x82, 0x03, 0xc5, 0x30, 0x82, 0x02, 0xad, 0xa0, 0x03, 0x02, 0x01, + 0x02, 0x02, 0x01, 0x02, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, + 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x30, 0x75, 0x31, 0x0b, 0x30, + 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x55, 0x53, 0x31, 0x13, + 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x0c, 0x0a, 0x43, 0x61, 0x6c, + 0x69, 0x66, 0x6f, 0x72, 0x6e, 0x69, 0x61, 0x31, 0x16, 0x30, 0x14, 0x06, + 0x03, 0x55, 0x04, 0x07, 0x0c, 0x0d, 0x4d, 0x6f, 0x75, 0x6e, 0x74, 0x61, + 0x69, 0x6e, 0x20, 0x56, 0x69, 0x65, 0x77, 0x31, 0x13, 0x30, 0x11, 0x06, + 0x03, 0x55, 0x04, 0x0a, 0x0c, 0x0a, 0x47, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x20, 0x49, 0x6e, 0x63, 0x31, 0x0d, 0x30, 0x0b, 0x06, 0x03, 0x55, 0x04, + 0x0b, 0x0c, 0x04, 0x43, 0x61, 0x73, 0x74, 0x31, 0x15, 0x30, 0x13, 0x06, + 0x03, 0x55, 0x04, 0x03, 0x0c, 0x0c, 0x43, 0x61, 0x73, 0x74, 0x20, 0x52, + 0x6f, 0x6f, 0x74, 0x20, 0x43, 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x34, + 0x30, 0x34, 0x30, 0x32, 0x31, 0x37, 0x33, 0x34, 0x32, 0x36, 0x5a, 0x17, + 0x0d, 0x33, 0x34, 0x30, 0x33, 0x32, 0x38, 0x31, 0x37, 0x33, 0x34, 0x32, + 0x36, 0x5a, 0x30, 0x75, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, + 0x06, 0x13, 0x02, 0x55, 0x53, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, + 0x04, 0x08, 0x0c, 0x0a, 0x43, 0x61, 0x6c, 0x69, 0x66, 0x6f, 0x72, 0x6e, + 0x69, 0x61, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, 0x55, 0x04, 0x07, 0x0c, + 0x0d, 0x4d, 0x6f, 0x75, 0x6e, 0x74, 0x61, 0x69, 0x6e, 0x20, 0x56, 0x69, + 0x65, 0x77, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x0c, + 0x0a, 0x47, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x20, 0x49, 0x6e, 0x63, 0x31, + 0x0d, 0x30, 0x0b, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x0c, 0x04, 0x43, 0x61, + 0x73, 0x74, 0x31, 0x15, 0x30, 0x13, 0x06, 0x03, 0x55, 0x04, 0x03, 0x0c, + 0x0c, 0x43, 0x61, 0x73, 0x74, 0x20, 0x52, 0x6f, 0x6f, 0x74, 0x20, 0x43, + 0x41, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, + 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01, 0x0f, + 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xba, 0xd9, + 0x65, 0x9d, 0xda, 0x39, 0xd3, 0xc1, 0x77, 0xf6, 0xd4, 0xd0, 0xae, 0x8f, + 0x58, 0x08, 0x68, 0x39, 0x4a, 0x95, 0xed, 0x70, 0xcf, 0xfd, 0x79, 0x08, + 0xa9, 0xaa, 0xe5, 0xe9, 0xb8, 0xa7, 0x2d, 0xa0, 0x67, 0x47, 0x8a, 0x9e, + 0xc9, 0xcf, 0x70, 0xb3, 0x05, 0x87, 0x69, 0x11, 0xec, 0x70, 0x98, 0x97, + 0xc3, 0xe6, 0xc3, 0xc3, 0xeb, 0xbd, 0xc6, 0xb0, 0x3d, 0xfc, 0x4f, 0xc1, + 0x5e, 0x38, 0x9f, 0xda, 0xcf, 0x73, 0x30, 0x06, 0x5b, 0x79, 0x37, 0xc1, + 0x5e, 0x8c, 0x87, 0x47, 0x94, 0x9a, 0x41, 0x92, 0x2a, 0xd6, 0x95, 0xc4, + 0x71, 0x5c, 0x27, 0x5d, 0x08, 0xb1, 0x80, 0xc6, 0x92, 0xbd, 0x1b, 0xe3, + 0x41, 0x97, 0xa1, 0xec, 0x75, 0x9f, 0x55, 0x9e, 0x3e, 0x9f, 0x8f, 0x1c, + 0xc7, 0x65, 0x64, 0x07, 0xd3, 0xb3, 0x96, 0xa1, 0x04, 0x9f, 0x91, 0xc4, + 0xde, 0x0a, 0x7b, 0x6c, 0xd9, 0xc8, 0xc0, 0x78, 0x31, 0xa0, 0x19, 0x42, + 0xa9, 0xe8, 0x83, 0xe3, 0xce, 0xfc, 0xf1, 0xce, 0xc2, 0x2e, 0x24, 0x46, + 0x95, 0x09, 0x19, 0xca, 0xc0, 0x46, 0xb2, 0xe5, 0x01, 0xba, 0xd7, 0x4f, + 0xf3, 0xbf, 0xf6, 0x69, 0xad, 0x99, 0x04, 0xfa, 0xa0, 0x07, 0x39, 0x0e, + 0xe6, 0xdf, 0x51, 0x47, 0x07, 0xc0, 0xe4, 0xa9, 0x5c, 0x4b, 0x94, 0xc5, + 0x2f, 0xb3, 0xa0, 0x30, 0x7f, 0xe7, 0x95, 0x6b, 0xb2, 0xaf, 0x32, 0x0d, + 0xf1, 0x8c, 0xd5, 0x6d, 0xcb, 0x7b, 0x47, 0xa7, 0x08, 0xab, 0xcb, 0x27, + 0xa3, 0x4d, 0xcf, 0x4a, 0x5a, 0xf1, 0x05, 0xd1, 0xf8, 0x62, 0xc5, 0x10, + 0x2a, 0x74, 0x69, 0xaa, 0xe6, 0x4b, 0x96, 0xfb, 0x9b, 0xd8, 0x63, 0xe4, + 0x58, 0x66, 0xd3, 0xad, 0x8a, 0x6e, 0xff, 0x7b, 0x5e, 0xf9, 0xa5, 0x56, + 0x1e, 0x2d, 0x82, 0x31, 0x5b, 0xf0, 0xe2, 0x24, 0xe6, 0x41, 0x4a, 0x1f, + 0xae, 0x13, 0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x60, 0x30, 0x5e, 0x30, + 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x04, 0x08, 0x30, 0x06, 0x01, 0x01, + 0xff, 0x02, 0x01, 0x02, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, + 0x16, 0x04, 0x14, 0x7c, 0x9a, 0x1e, 0x7d, 0xdf, 0x79, 0x54, 0xbc, 0xd7, + 0xcc, 0x5e, 0xca, 0x99, 0x86, 0x45, 0x79, 0x65, 0x74, 0x28, 0x19, 0x30, + 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16, 0x80, 0x14, + 0x7c, 0x9a, 0x1e, 0x7d, 0xdf, 0x79, 0x54, 0xbc, 0xd7, 0xcc, 0x5e, 0xca, + 0x99, 0x86, 0x45, 0x79, 0x65, 0x74, 0x28, 0x19, 0x30, 0x0b, 0x06, 0x03, + 0x55, 0x1d, 0x0f, 0x04, 0x04, 0x03, 0x02, 0x01, 0x06, 0x30, 0x0d, 0x06, + 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, + 0x03, 0x82, 0x01, 0x01, 0x00, 0x80, 0xf4, 0x5a, 0xfb, 0x3d, 0x28, 0x19, + 0x51, 0x20, 0xd7, 0xd4, 0xfb, 0x12, 0x97, 0x4a, 0x65, 0xf2, 0x58, 0x35, + 0x92, 0x77, 0x30, 0x6a, 0xf1, 0xd7, 0xb6, 0x51, 0x1a, 0x7f, 0x9a, 0xcd, + 0xc7, 0x7b, 0x03, 0x42, 0xad, 0x55, 0x6a, 0x00, 0xaf, 0xf0, 0xe1, 0x06, + 0xc2, 0xbd, 0x6b, 0x78, 0x75, 0xdb, 0xfe, 0x41, 0x11, 0x53, 0x4a, 0x39, + 0xbb, 0x9a, 0x3a, 0xc6, 0x59, 0x34, 0x2f, 0x2c, 0x33, 0xe3, 0xb2, 0xd6, + 0x5c, 0x7f, 0xdd, 0x78, 0xeb, 0x71, 0x5b, 0x39, 0xda, 0x83, 0x90, 0xc5, + 0x31, 0xe2, 0x3f, 0x23, 0xef, 0xda, 0xeb, 0x2b, 0x2d, 0x77, 0x5e, 0xde, + 0xc3, 0x43, 0xd2, 0xc9, 0x6b, 0x59, 0x82, 0xca, 0xd5, 0xed, 0xfa, 0xa1, + 0x64, 0x5b, 0xcb, 0xf1, 0x0d, 0x1a, 0x62, 0xe1, 0x9c, 0xe8, 0xa7, 0x18, + 0x70, 0xf0, 0x5f, 0x17, 0x96, 0xf8, 0xed, 0x86, 0xdb, 0xae, 0x1d, 0xe0, + 0xcf, 0x3e, 0x5d, 0x2e, 0xee, 0x16, 0x6d, 0x95, 0x2b, 0x3c, 0xfd, 0x97, + 0xf3, 0x05, 0x5a, 0x24, 0x68, 0x4d, 0x39, 0xb6, 0xf8, 0xe4, 0x58, 0xba, + 0xf5, 0xe0, 0x26, 0x78, 0x51, 0xc5, 0x5b, 0x5d, 0x4e, 0x09, 0xe5, 0x6c, + 0x47, 0x8b, 0x7a, 0x5a, 0x2e, 0x89, 0x53, 0xe6, 0xcc, 0x36, 0x5b, 0x26, + 0x3c, 0xf8, 0x72, 0x43, 0x02, 0x82, 0xd2, 0x2b, 0xcd, 0xf0, 0xd3, 0xa3, + 0xec, 0x13, 0x3e, 0x52, 0xd5, 0x83, 0x3d, 0x07, 0xdc, 0x1d, 0x43, 0x65, + 0x7a, 0x33, 0x02, 0x01, 0xa3, 0xce, 0xb7, 0xd6, 0x60, 0x51, 0x3b, 0x09, + 0xc2, 0x23, 0x8a, 0x32, 0xfe, 0x98, 0x19, 0x60, 0x62, 0x93, 0x85, 0xcd, + 0x34, 0x46, 0xdb, 0xd5, 0x23, 0x0f, 0x79, 0xda, 0x77, 0x00, 0x2a, 0x02, + 0x6d, 0x83, 0x58, 0xce, 0x03, 0x77, 0x35, 0xe1, 0xa3, 0x20, 0x93, 0xc2, + 0x4a, 0xa2, 0xa4, 0x46, 0x1c, 0x75, 0x2c, 0x1f, 0x4d}; + +#endif // CAST_COMMON_CERTIFICATE_CAST_ROOT_CA_CERT_DER_INC_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/certificate/eureka_root_ca_der-inc.h b/chromium/third_party/openscreen/src/cast/common/certificate/eureka_root_ca_der-inc.h new file mode 100644 index 00000000000..2e88a093ba6 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/certificate/eureka_root_ca_der-inc.h @@ -0,0 +1,151 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_COMMON_CERTIFICATE_EUREKA_ROOT_CA_DER_INC_H_ +#define CAST_COMMON_CERTIFICATE_EUREKA_ROOT_CA_DER_INC_H_ + +// Certificate: +// Data: +// Version: 3 (0x2) +// Serial Number: 1 (0x1) +// Signature Algorithm: sha1WithRSAEncryption +// Issuer: C=US, ST=California, L=Mountain View, O=Google Inc, +// OU=Google TV, CN=Eureka Root CA +// Validity +// Not Before: Dec 17 22:39:33 2012 GMT +// Not After : Dec 12 22:39:33 2032 GMT +// Subject: C=US, ST=California, L=Mountain View, O=Google Inc, +// OU=Google TV, CN=Eureka Root CA +// Subject Public Key Info: +// Public Key Algorithm: rsaEncryption +// Public-Key: (2048 bit) +// Modulus: +// 00:b9:11:d0:ea:12:dc:32:e1:df:5c:33:6b:19:73: +// 1d:9d:9e:d0:39:76:bf:a5:84:09:a6:fd:6e:6d:e9: +// dc:8f:36:4e:e9:88:02:bd:9f:f4:e8:44:fd:4c:f5: +// 9a:02:56:6a:47:2a:63:6c:58:45:cc:7c:66:24:dc: +// 79:79:c3:2a:a4:b2:8b:a0:f7:a2:b5:cd:06:7e:db: +// be:ec:0c:86:f2:0d:24:60:74:84:ca:29:23:84:02: +// d8:a7:ed:3b:f1:ec:26:47:54:e3:b1:2d:e6:64:0f: +// f6:72:c5:e9:98:52:17:c0:fc:f2:2c:20:c8:40:f8: +// 47:c9:32:9e:3b:97:b1:8b:f5:98:24:70:63:66:19: +// c1:52:e8:04:05:3d:5f:8d:bc:d8:4b:af:77:98:6f: +// 1f:78:d1:b6:50:27:4d:e4:ec:14:69:67:1f:58:af: +// a9:a0:11:26:3c:94:32:07:7f:d7:e9:69:1f:ae:3f: +// 4f:63:8a:8f:89:d6:f2:19:78:5c:21:8e:b1:b6:57: +// d8:c0:e1:ee:7d:6e:dd:f1:3a:0a:6a:f1:ba:ff:f9: +// 83:2f:dc:b5:a4:20:17:63:36:ef:c8:62:19:cc:56: +// ce:b2:ea:31:89:4b:78:58:c1:bf:03:13:99:e0:12: +// f2:88:aa:9b:94:da:dd:76:79:17:1e:34:d1:0a:c4: +// 07:45 +// Exponent: 65537 (0x10001) +// X509v3 extensions: +// X509v3 Subject Key Identifier: +// 44:4E:2A:47:58:D8:B9:48:91:F6:4F:CE:74:A9:1D:32:9A:8D:8D:E9 +// X509v3 Authority Key Identifier: +// keyid:44:4E:2A:47:58:D8:B9:48:91:F6:4F:CE:74:A9:1D:32:9A:8D +// :8D:E9 +// +// X509v3 Basic Constraints: +// CA:TRUE +// Signature Algorithm: sha1WithRSAEncryption +// 3f:c8:26:a0:6e:5c:05:40:79:a1:98:a9:33:de:68:74:85:ee: +// ae:b7:1c:33:59:b0:11:de:9f:f4:4f:d3:eb:51:09:7d:47:7e: +// 6e:51:85:f4:54:cd:83:98:25:b1:ba:b0:57:ec:93:db:12:e2: +// ec:51:49:7a:96:73:9b:c8:96:6d:85:8c:d3:e1:3c:fa:32:e2: +// 58:0c:77:6d:87:0c:34:01:aa:30:a9:76:e0:c0:e7:db:5e:1b: +// e9:10:30:a4:e0:09:49:26:b9:58:cd:5a:07:e5:50:75:de:9a: +// 3b:f6:53:7e:b1:53:5e:45:27:4f:17:e3:08:33:b2:50:0a:bb: +// f4:fc:25:97:29:de:41:75:30:fa:77:38:aa:65:8a:73:4f:ea: +// 11:7b:eb:7c:17:60:27:0e:bc:3e:76:52:d8:8b:ed:1a:f8:eb: +// 37:bb:11:fd:ae:70:17:0a:fe:e0:ad:06:b3:1f:69:8a:72:04: +// c2:c0:33:0b:d6:2f:63:4c:33:11:14:b8:62:36:88:c5:03:65: +// 01:19:a3:ef:00:bb:6f:0e:92:ff:34:1c:a1:d6:31:d0:5c:5e: +// 9f:99:7d:c7:ca:bd:7c:72:0b:f4:5c:a5:7e:6e:04:a8:d2:99: +// 2c:51:01:14:fe:a2:48:f0:7e:be:84:0d:b4:d3:e2:f3:0e:7d: +// de:8b:f5:33 +const unsigned char kEurekaRootCaDer[] = { + 0x30, 0x82, 0x03, 0xc3, 0x30, 0x82, 0x02, 0xab, 0xa0, 0x03, 0x02, 0x01, + 0x02, 0x02, 0x01, 0x01, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, + 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x30, 0x7c, 0x31, 0x0b, 0x30, + 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x55, 0x53, 0x31, 0x13, + 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x0c, 0x0a, 0x43, 0x61, 0x6c, + 0x69, 0x66, 0x6f, 0x72, 0x6e, 0x69, 0x61, 0x31, 0x16, 0x30, 0x14, 0x06, + 0x03, 0x55, 0x04, 0x07, 0x0c, 0x0d, 0x4d, 0x6f, 0x75, 0x6e, 0x74, 0x61, + 0x69, 0x6e, 0x20, 0x56, 0x69, 0x65, 0x77, 0x31, 0x13, 0x30, 0x11, 0x06, + 0x03, 0x55, 0x04, 0x0a, 0x0c, 0x0a, 0x47, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x20, 0x49, 0x6e, 0x63, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04, + 0x0b, 0x0c, 0x09, 0x47, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x20, 0x54, 0x56, + 0x31, 0x17, 0x30, 0x15, 0x06, 0x03, 0x55, 0x04, 0x03, 0x0c, 0x0e, 0x45, + 0x75, 0x72, 0x65, 0x6b, 0x61, 0x20, 0x52, 0x6f, 0x6f, 0x74, 0x20, 0x43, + 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x32, 0x31, 0x32, 0x31, 0x37, 0x32, + 0x32, 0x33, 0x39, 0x33, 0x33, 0x5a, 0x17, 0x0d, 0x33, 0x32, 0x31, 0x32, + 0x31, 0x32, 0x32, 0x32, 0x33, 0x39, 0x33, 0x33, 0x5a, 0x30, 0x7c, 0x31, + 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x55, 0x53, + 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x0c, 0x0a, 0x43, + 0x61, 0x6c, 0x69, 0x66, 0x6f, 0x72, 0x6e, 0x69, 0x61, 0x31, 0x16, 0x30, + 0x14, 0x06, 0x03, 0x55, 0x04, 0x07, 0x0c, 0x0d, 0x4d, 0x6f, 0x75, 0x6e, + 0x74, 0x61, 0x69, 0x6e, 0x20, 0x56, 0x69, 0x65, 0x77, 0x31, 0x13, 0x30, + 0x11, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x0c, 0x0a, 0x47, 0x6f, 0x6f, 0x67, + 0x6c, 0x65, 0x20, 0x49, 0x6e, 0x63, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, + 0x55, 0x04, 0x0b, 0x0c, 0x09, 0x47, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x20, + 0x54, 0x56, 0x31, 0x17, 0x30, 0x15, 0x06, 0x03, 0x55, 0x04, 0x03, 0x0c, + 0x0e, 0x45, 0x75, 0x72, 0x65, 0x6b, 0x61, 0x20, 0x52, 0x6f, 0x6f, 0x74, + 0x20, 0x43, 0x41, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, + 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, + 0x01, 0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, + 0xb9, 0x11, 0xd0, 0xea, 0x12, 0xdc, 0x32, 0xe1, 0xdf, 0x5c, 0x33, 0x6b, + 0x19, 0x73, 0x1d, 0x9d, 0x9e, 0xd0, 0x39, 0x76, 0xbf, 0xa5, 0x84, 0x09, + 0xa6, 0xfd, 0x6e, 0x6d, 0xe9, 0xdc, 0x8f, 0x36, 0x4e, 0xe9, 0x88, 0x02, + 0xbd, 0x9f, 0xf4, 0xe8, 0x44, 0xfd, 0x4c, 0xf5, 0x9a, 0x02, 0x56, 0x6a, + 0x47, 0x2a, 0x63, 0x6c, 0x58, 0x45, 0xcc, 0x7c, 0x66, 0x24, 0xdc, 0x79, + 0x79, 0xc3, 0x2a, 0xa4, 0xb2, 0x8b, 0xa0, 0xf7, 0xa2, 0xb5, 0xcd, 0x06, + 0x7e, 0xdb, 0xbe, 0xec, 0x0c, 0x86, 0xf2, 0x0d, 0x24, 0x60, 0x74, 0x84, + 0xca, 0x29, 0x23, 0x84, 0x02, 0xd8, 0xa7, 0xed, 0x3b, 0xf1, 0xec, 0x26, + 0x47, 0x54, 0xe3, 0xb1, 0x2d, 0xe6, 0x64, 0x0f, 0xf6, 0x72, 0xc5, 0xe9, + 0x98, 0x52, 0x17, 0xc0, 0xfc, 0xf2, 0x2c, 0x20, 0xc8, 0x40, 0xf8, 0x47, + 0xc9, 0x32, 0x9e, 0x3b, 0x97, 0xb1, 0x8b, 0xf5, 0x98, 0x24, 0x70, 0x63, + 0x66, 0x19, 0xc1, 0x52, 0xe8, 0x04, 0x05, 0x3d, 0x5f, 0x8d, 0xbc, 0xd8, + 0x4b, 0xaf, 0x77, 0x98, 0x6f, 0x1f, 0x78, 0xd1, 0xb6, 0x50, 0x27, 0x4d, + 0xe4, 0xec, 0x14, 0x69, 0x67, 0x1f, 0x58, 0xaf, 0xa9, 0xa0, 0x11, 0x26, + 0x3c, 0x94, 0x32, 0x07, 0x7f, 0xd7, 0xe9, 0x69, 0x1f, 0xae, 0x3f, 0x4f, + 0x63, 0x8a, 0x8f, 0x89, 0xd6, 0xf2, 0x19, 0x78, 0x5c, 0x21, 0x8e, 0xb1, + 0xb6, 0x57, 0xd8, 0xc0, 0xe1, 0xee, 0x7d, 0x6e, 0xdd, 0xf1, 0x3a, 0x0a, + 0x6a, 0xf1, 0xba, 0xff, 0xf9, 0x83, 0x2f, 0xdc, 0xb5, 0xa4, 0x20, 0x17, + 0x63, 0x36, 0xef, 0xc8, 0x62, 0x19, 0xcc, 0x56, 0xce, 0xb2, 0xea, 0x31, + 0x89, 0x4b, 0x78, 0x58, 0xc1, 0xbf, 0x03, 0x13, 0x99, 0xe0, 0x12, 0xf2, + 0x88, 0xaa, 0x9b, 0x94, 0xda, 0xdd, 0x76, 0x79, 0x17, 0x1e, 0x34, 0xd1, + 0x0a, 0xc4, 0x07, 0x45, 0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x50, 0x30, + 0x4e, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16, 0x04, 0x14, + 0x44, 0x4e, 0x2a, 0x47, 0x58, 0xd8, 0xb9, 0x48, 0x91, 0xf6, 0x4f, 0xce, + 0x74, 0xa9, 0x1d, 0x32, 0x9a, 0x8d, 0x8d, 0xe9, 0x30, 0x1f, 0x06, 0x03, + 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16, 0x80, 0x14, 0x44, 0x4e, 0x2a, + 0x47, 0x58, 0xd8, 0xb9, 0x48, 0x91, 0xf6, 0x4f, 0xce, 0x74, 0xa9, 0x1d, + 0x32, 0x9a, 0x8d, 0x8d, 0xe9, 0x30, 0x0c, 0x06, 0x03, 0x55, 0x1d, 0x13, + 0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0xff, 0x30, 0x0d, 0x06, 0x09, 0x2a, + 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x03, 0x82, + 0x01, 0x01, 0x00, 0x3f, 0xc8, 0x26, 0xa0, 0x6e, 0x5c, 0x05, 0x40, 0x79, + 0xa1, 0x98, 0xa9, 0x33, 0xde, 0x68, 0x74, 0x85, 0xee, 0xae, 0xb7, 0x1c, + 0x33, 0x59, 0xb0, 0x11, 0xde, 0x9f, 0xf4, 0x4f, 0xd3, 0xeb, 0x51, 0x09, + 0x7d, 0x47, 0x7e, 0x6e, 0x51, 0x85, 0xf4, 0x54, 0xcd, 0x83, 0x98, 0x25, + 0xb1, 0xba, 0xb0, 0x57, 0xec, 0x93, 0xdb, 0x12, 0xe2, 0xec, 0x51, 0x49, + 0x7a, 0x96, 0x73, 0x9b, 0xc8, 0x96, 0x6d, 0x85, 0x8c, 0xd3, 0xe1, 0x3c, + 0xfa, 0x32, 0xe2, 0x58, 0x0c, 0x77, 0x6d, 0x87, 0x0c, 0x34, 0x01, 0xaa, + 0x30, 0xa9, 0x76, 0xe0, 0xc0, 0xe7, 0xdb, 0x5e, 0x1b, 0xe9, 0x10, 0x30, + 0xa4, 0xe0, 0x09, 0x49, 0x26, 0xb9, 0x58, 0xcd, 0x5a, 0x07, 0xe5, 0x50, + 0x75, 0xde, 0x9a, 0x3b, 0xf6, 0x53, 0x7e, 0xb1, 0x53, 0x5e, 0x45, 0x27, + 0x4f, 0x17, 0xe3, 0x08, 0x33, 0xb2, 0x50, 0x0a, 0xbb, 0xf4, 0xfc, 0x25, + 0x97, 0x29, 0xde, 0x41, 0x75, 0x30, 0xfa, 0x77, 0x38, 0xaa, 0x65, 0x8a, + 0x73, 0x4f, 0xea, 0x11, 0x7b, 0xeb, 0x7c, 0x17, 0x60, 0x27, 0x0e, 0xbc, + 0x3e, 0x76, 0x52, 0xd8, 0x8b, 0xed, 0x1a, 0xf8, 0xeb, 0x37, 0xbb, 0x11, + 0xfd, 0xae, 0x70, 0x17, 0x0a, 0xfe, 0xe0, 0xad, 0x06, 0xb3, 0x1f, 0x69, + 0x8a, 0x72, 0x04, 0xc2, 0xc0, 0x33, 0x0b, 0xd6, 0x2f, 0x63, 0x4c, 0x33, + 0x11, 0x14, 0xb8, 0x62, 0x36, 0x88, 0xc5, 0x03, 0x65, 0x01, 0x19, 0xa3, + 0xef, 0x00, 0xbb, 0x6f, 0x0e, 0x92, 0xff, 0x34, 0x1c, 0xa1, 0xd6, 0x31, + 0xd0, 0x5c, 0x5e, 0x9f, 0x99, 0x7d, 0xc7, 0xca, 0xbd, 0x7c, 0x72, 0x0b, + 0xf4, 0x5c, 0xa5, 0x7e, 0x6e, 0x04, 0xa8, 0xd2, 0x99, 0x2c, 0x51, 0x01, + 0x14, 0xfe, 0xa2, 0x48, 0xf0, 0x7e, 0xbe, 0x84, 0x0d, 0xb4, 0xd3, 0xe2, + 0xf3, 0x0e, 0x7d, 0xde, 0x8b, 0xf5, 0x33}; + +#endif // CAST_COMMON_CERTIFICATE_EUREKA_ROOT_CA_DER_INC_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_constants.h b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_constants.h index d0ff6c63ea2..192dbf1fa33 100644 --- a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_constants.h +++ b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_constants.h @@ -131,41 +131,33 @@ struct Header { uint16_t additional_record_count; }; -// TODO(mayaki): Here and below consider converting constants to members of -// enum classes. +static_assert(sizeof(Header) == 12, "Size of mDNS header must be 12 bytes."); + +enum class MessageType { + Query = 0, + Response = 1, +}; -// DNS Header flags. All flags are formatted to mask directly onto FLAG header -// field in network-byte order. constexpr uint16_t kFlagResponse = 0x8000; constexpr uint16_t kFlagAA = 0x0400; constexpr uint16_t kFlagTC = 0x0200; -constexpr uint16_t kFlagRD = 0x0100; -constexpr uint16_t kFlagRA = 0x0080; -constexpr uint16_t kFlagZ = 0x0040; // Unused field -constexpr uint16_t kFlagAD = 0x0020; -constexpr uint16_t kFlagCD = 0x0010; - -// DNS Header OPCODE mask and values. The mask is formatted to mask directly -// onto FLAG header field in network-byte order. The values are formatted after -// shifting into correct position. constexpr uint16_t kOpcodeMask = 0x7800; -constexpr uint8_t kOpcodeQUERY = 0; -constexpr uint8_t kOpcodeIQUERY = 1; -constexpr uint8_t kOpcodeSTATUS = 2; -constexpr uint8_t kOpcodeUNASSIGNED = 3; // Unused for now -constexpr uint8_t kOpcodeNOTIFY = 4; -constexpr uint8_t kOpcodeUPDATE = 5; - -// DNS Header RCODE mask and values. The mask is formatted to mask directly onto -// FLAG header field in network-byte order. The values are formatted after -// shifting into correct position. constexpr uint16_t kRcodeMask = 0x000F; -constexpr uint8_t kRcodeNOERROR = 0; -constexpr uint8_t kRcodeFORMERR = 1; -constexpr uint8_t kRcodeSERVFAIL = 2; -constexpr uint8_t kRcodeNXDOMAIN = 3; -constexpr uint8_t kRcodeNOTIMP = 4; -constexpr uint8_t kRcodeREFUSED = 5; + +constexpr MessageType GetMessageType(uint16_t flags) { + // RFC 6762 Section 18.2 + return (flags & kFlagResponse) ? MessageType::Response : MessageType::Query; +} + +constexpr uint16_t MakeFlags(MessageType type) { + // RFC 6762 Section 18.2 and Section 18.4 + return (type == MessageType::Response) ? (kFlagResponse | kFlagAA) : 0; +} + +constexpr bool IsValidFlagsSection(uint16_t flags) { + // RFC 6762 Section 18.3 and Section 18.11 + return (flags & (kOpcodeMask | kRcodeMask)) == 0; +} // ============================================================================ // Domain Name @@ -293,44 +285,58 @@ enum class DnsType : uint16_t { kANY = 255, // Only allowed for QTYPE }; -// DNS CLASS masks and values. -constexpr uint16_t kClassMask = 0x7FFF; +enum class DnsClass : uint16_t { + kIN = 1, + kANY = 255, // Only allowed for QCLASS +}; +// Unique and shared records are described in +// https://tools.ietf.org/html/rfc6762#section-2 and +// https://tools.ietf.org/html/rfc6762#section-10.2 +enum class RecordType { + kShared = 0, + kUnique = 1, +}; + +// Unicast and multicast preferred response types are described in +// https://tools.ietf.org/html/rfc6762#section-5.4 +enum class ResponseType { + kMulticast = 0, + kUnicast = 1, +}; + +// DNS CLASS masks and values. // In mDNS the most significant bit of the RRCLASS for response records is // designated as the "cache-flush bit", as described in // https://tools.ietf.org/html/rfc6762#section-10.2 -constexpr uint16_t kCacheFlushBit = 0x8000; // In mDNS the most significant bit of the RRCLASS for query records is // designated as the "unicast-response bit", as described in // https://tools.ietf.org/html/rfc6762#section-5.4 -constexpr uint16_t kUnicastResponseBit = 0x8000; - -enum class DnsClass : uint16_t { - kIN = 1, - kANY = 255, // Only allowed for QCLASS -}; +constexpr uint16_t kClassMask = 0x7FFF; +constexpr uint16_t kClassMsbMask = 0x8000; +constexpr uint16_t kClassMsbShift = 0xF; constexpr DnsClass GetDnsClass(uint16_t rrclass) { return static_cast<DnsClass>(rrclass & kClassMask); } -constexpr bool GetCacheFlush(uint16_t rrclass) { - return rrclass & kCacheFlushBit; +constexpr RecordType GetRecordType(uint16_t rrclass) { + return static_cast<RecordType>((rrclass & kClassMsbMask) >> kClassMsbShift); } -constexpr bool GetUnicastResponse(uint16_t rrclass) { - return rrclass & kUnicastResponseBit; +constexpr ResponseType GetResponseType(uint16_t rrclass) { + return static_cast<ResponseType>((rrclass & kClassMsbMask) >> kClassMsbShift); } -constexpr uint16_t MakeRecordClass(DnsClass dns_class, bool cache_flush) { +constexpr uint16_t MakeRecordClass(DnsClass dns_class, RecordType record_type) { return static_cast<uint16_t>(dns_class) | - (static_cast<uint16_t>(cache_flush) << 15); + (static_cast<uint16_t>(record_type) << kClassMsbShift); } constexpr uint16_t MakeQuestionClass(DnsClass dns_class, - bool unicast_response) { + ResponseType response_type) { return static_cast<uint16_t>(dns_class) | - (static_cast<uint16_t>(unicast_response) << 15); + (static_cast<uint16_t>(response_type) << kClassMsbShift); } // See RFC 6762, section 11: https://tools.ietf.org/html/rfc6762#section-11 diff --git a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_reader.cc b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_reader.cc index ae4443d894d..c91983b861e 100644 --- a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_reader.cc +++ b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_reader.cc @@ -202,7 +202,7 @@ bool MdnsReader::Read(MdnsRecord* out) { if (Read(&name) && Read(&type) && Read(&rrclass) && Read(&ttl) && Read(static_cast<DnsType>(type), &rdata)) { *out = MdnsRecord(std::move(name), static_cast<DnsType>(type), - GetDnsClass(rrclass), GetCacheFlush(rrclass), ttl, + GetDnsClass(rrclass), GetRecordType(rrclass), ttl, std::move(rdata)); cursor.Commit(); return true; @@ -218,7 +218,7 @@ bool MdnsReader::Read(MdnsQuestion* out) { uint16_t rrclass; if (Read(&name) && Read(&type) && Read(&rrclass)) { *out = MdnsQuestion(std::move(name), static_cast<DnsType>(type), - GetDnsClass(rrclass), GetUnicastResponse(rrclass)); + GetDnsClass(rrclass), GetResponseType(rrclass)); cursor.Commit(); return true; } @@ -237,8 +237,12 @@ bool MdnsReader::Read(MdnsMessage* out) { Read(header.answer_count, &answers) && Read(header.authority_record_count, &authority_records) && Read(header.additional_record_count, &additional_records)) { - *out = MdnsMessage(header.id, header.flags, questions, answers, - authority_records, additional_records); + // TODO(yakimakha): Skip messages with non-zero opcode and rcode. + // One way to do this is to change the method signature to return + // ErrorOr<MdnsMessage> and return different error codes for failure to read + // and for messages that were read successfully but are non-conforming. + *out = MdnsMessage(header.id, GetMessageType(header.flags), questions, + answers, authority_records, additional_records); cursor.Commit(); return true; } diff --git a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_reader_unittest.cc b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_reader_unittest.cc index 5153806107b..426d75e9f71 100644 --- a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_reader_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_reader_unittest.cc @@ -375,10 +375,10 @@ TEST(MdnsReaderTest, ReadMdnsRecord_ARecordRdata) { 0x08, 0x08, 0x08, 0x08, // RDATA = 8.8.8.8 }; // clang-format on - TestReadEntrySucceeds( - kTestRecord, sizeof(kTestRecord), - MdnsRecord(DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN, - true, 120, ARecordRdata(IPAddress{8, 8, 8, 8}))); + TestReadEntrySucceeds(kTestRecord, sizeof(kTestRecord), + MdnsRecord(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, RecordType::kUnique, 120, + ARecordRdata(IPAddress{8, 8, 8, 8}))); } TEST(MdnsReaderTest, ReadMdnsRecord_UnknownRecordType) { @@ -400,8 +400,9 @@ TEST(MdnsReaderTest, ReadMdnsRecord_UnknownRecordType) { TestReadEntrySucceeds( kTestRecord, sizeof(kTestRecord), MdnsRecord(DomainName{"testing", "local"}, - static_cast<DnsType>(5) /*CNAME class*/, DnsClass::kIN, true, - 120, RawRecordRdata(kCnameRdata, sizeof(kCnameRdata)))); + static_cast<DnsType>(5) /*CNAME class*/, DnsClass::kIN, + RecordType::kUnique, 120, + RawRecordRdata(kCnameRdata, sizeof(kCnameRdata)))); } TEST(MdnsReaderTest, ReadMdnsRecord_CompressedNames) { @@ -434,12 +435,12 @@ TEST(MdnsReaderTest, ReadMdnsRecord_CompressedNames) { EXPECT_TRUE(reader.Read(&record)); EXPECT_EQ(record, MdnsRecord(DomainName{"testing", "local"}, DnsType::kPTR, - DnsClass::kIN, false, 120, + DnsClass::kIN, RecordType::kShared, 120, PtrRecordRdata(DomainName{"ptr", "testing", "local"}))); EXPECT_TRUE(reader.Read(&record)); EXPECT_EQ(record, MdnsRecord(DomainName{"one", "two", "testing", "local"}, - DnsType::kA, DnsClass::kIN, true, 120, - ARecordRdata(IPAddress{8, 8, 8, 8}))); + DnsType::kA, DnsClass::kIN, RecordType::kUnique, + 120, ARecordRdata(IPAddress{8, 8, 8, 8}))); } TEST(MdnsReaderTest, ReadMdnsRecord_MissingRdata) { @@ -483,9 +484,10 @@ TEST(MdnsReaderTest, ReadMdnsQuestion) { 0x80, 0x01, // CLASS = IN (1) | UNICAST_BIT }; // clang-format on - TestReadEntrySucceeds(kTestQuestion, sizeof(kTestQuestion), - MdnsQuestion(DomainName{"testing", "local"}, - DnsType::kA, DnsClass::kIN, true)); + TestReadEntrySucceeds( + kTestQuestion, sizeof(kTestQuestion), + MdnsQuestion(DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN, + ResponseType::kUnicast)); } TEST(MdnsReaderTest, ReadMdnsQuestion_CompressedNames) { @@ -508,10 +510,10 @@ TEST(MdnsReaderTest, ReadMdnsQuestion_CompressedNames) { MdnsQuestion question; EXPECT_TRUE(reader.Read(&question)); EXPECT_EQ(question, MdnsQuestion(DomainName{"first", "local"}, DnsType::kA, - DnsClass::kIN, true)); + DnsClass::kIN, ResponseType::kUnicast)); EXPECT_TRUE(reader.Read(&question)); EXPECT_EQ(question, MdnsQuestion(DomainName{"second", "local"}, DnsType::kPTR, - DnsClass::kIN, false)); + DnsClass::kIN, ResponseType::kMulticast)); EXPECT_EQ(reader.remaining(), UINT64_C(0)); } @@ -559,13 +561,16 @@ TEST(MdnsReaderTest, ReadMdnsMessage) { }; // clang-format on - MdnsRecord record1(DomainName{"record1"}, DnsType::kPTR, DnsClass::kIN, false, - 120, PtrRecordRdata(DomainName{"testing", "local"})); - MdnsRecord record2(DomainName{"record2"}, DnsType::kA, DnsClass::kIN, false, - 120, ARecordRdata(IPAddress{172, 0, 0, 1})); - MdnsMessage message( - 1, 0x8400, std::vector<MdnsQuestion>{}, std::vector<MdnsRecord>{record1}, - std::vector<MdnsRecord>{}, std::vector<MdnsRecord>{record2}); + MdnsRecord record1(DomainName{"record1"}, DnsType::kPTR, DnsClass::kIN, + RecordType::kShared, 120, + PtrRecordRdata(DomainName{"testing", "local"})); + MdnsRecord record2(DomainName{"record2"}, DnsType::kA, DnsClass::kIN, + RecordType::kShared, 120, + ARecordRdata(IPAddress{172, 0, 0, 1})); + MdnsMessage message(1, MessageType::Response, std::vector<MdnsQuestion>{}, + std::vector<MdnsRecord>{record1}, + std::vector<MdnsRecord>{}, + std::vector<MdnsRecord>{record2}); TestReadEntrySucceeds(kTestMessage, sizeof(kTestMessage), message); } diff --git a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_receiver.cc b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_receiver.cc new file mode 100644 index 00000000000..a097838b1f9 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_receiver.cc @@ -0,0 +1,65 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/mdns/mdns_receiver.h" + +#include "cast/common/mdns/mdns_reader.h" +#include "platform/api/trace_logging.h" + +namespace cast { +namespace mdns { + +MdnsReceiver::MdnsReceiver(UdpSocket* socket, + NetworkRunner* network_runner, + Delegate* delegate) + : socket_(socket), network_runner_(network_runner), delegate_(delegate) { + OSP_DCHECK(socket_); + OSP_DCHECK(network_runner_); + OSP_DCHECK(delegate_); +} + +MdnsReceiver::~MdnsReceiver() { + if (state_ == State::kRunning) { + Stop(); + } +} + +Error MdnsReceiver::Start() { + if (state_ == State::kRunning) { + return Error::Code::kNone; + } + Error result = network_runner_->ReadRepeatedly(socket_, this); + if (result.ok()) { + state_ = State::kRunning; + } + return result; +} + +Error MdnsReceiver::Stop() { + if (state_ == State::kStopped) { + return Error::Code::kNone; + } + Error result = network_runner_->CancelRead(socket_); + if (result.ok()) { + state_ = State::kStopped; + } + return result; +} + +void MdnsReceiver::OnRead(UdpPacket packet, NetworkRunner* network_runner) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsReceiver::OnRead"); + MdnsReader reader(packet.data(), packet.size()); + MdnsMessage message; + if (!reader.Read(&message)) { + return; + } + if (message.type() == MessageType::Response) { + delegate_->OnResponseReceived(message, packet.source()); + } else { + delegate_->OnQueryReceived(message, packet.source()); + } +} + +} // namespace mdns +} // namespace cast diff --git a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_receiver.h b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_receiver.h new file mode 100644 index 00000000000..62aff212c59 --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_receiver.h @@ -0,0 +1,77 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_COMMON_MDNS_MDNS_RECEIVER_H_ +#define CAST_COMMON_MDNS_MDNS_RECEIVER_H_ + +#include "cast/common/mdns/mdns_records.h" +#include "platform/api/network_runner.h" +#include "platform/api/udp_packet.h" +#include "platform/api/udp_read_callback.h" +#include "platform/api/udp_socket.h" +#include "platform/base/error.h" +#include "platform/base/ip_address.h" + +namespace cast { +namespace mdns { + +using Error = openscreen::Error; +using NetworkRunner = openscreen::platform::NetworkRunner; +using UdpSocket = openscreen::platform::UdpSocket; +using UdpReadCallback = openscreen::platform::UdpReadCallback; +using UdpPacket = openscreen::platform::UdpPacket; +using IPEndpoint = openscreen::IPEndpoint; + +class MdnsReceiver : UdpReadCallback { + public: + class Delegate { + public: + virtual ~Delegate() = default; + virtual void OnQueryReceived(const MdnsMessage& message, + const IPEndpoint& sender) = 0; + virtual void OnResponseReceived(const MdnsMessage& message, + const IPEndpoint& sender) = 0; + }; + + // MdnsReceiver does not own |socket|, |network_runner| and |delegate| + // and expects that the lifetime of these objects exceeds the lifetime of + // MdnsReceiver. + MdnsReceiver(UdpSocket* socket, + NetworkRunner* network_runner, + Delegate* delegate); + MdnsReceiver(const MdnsReceiver& other) = delete; + MdnsReceiver(MdnsReceiver&& other) noexcept = delete; + ~MdnsReceiver() override; + + MdnsReceiver& operator=(const MdnsReceiver& other) = delete; + MdnsReceiver& operator=(MdnsReceiver&& other) noexcept = delete; + + // The receiver can be started and stopped multiple times. + // Start and Stop return Error::Code::kNone on success and return an error on + // failure. Start returns Error::Code::kNone when called on a receiver that + // has already been started. Stop returns Error::Code::kNone when called on a + // receiver that has already been stopped or not yet started. Start and Stop + // are both synchronous calls. After MdnsReceiver has been started it will + // receive OnRead callbacks from the network runner. + Error Start(); + Error Stop(); + + void OnRead(UdpPacket packet, NetworkRunner* network_runner) override; + + private: + enum class State { + kStopped, + kRunning, + }; + + UdpSocket* const socket_; + NetworkRunner* const network_runner_; + Delegate* const delegate_; + State state_ = State::kStopped; +}; + +} // namespace mdns +} // namespace cast + +#endif // CAST_COMMON_MDNS_MDNS_RECEIVER_H_ diff --git a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_receiver_unittest.cc b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_receiver_unittest.cc new file mode 100644 index 00000000000..ce796ddaedf --- /dev/null +++ b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_receiver_unittest.cc @@ -0,0 +1,148 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/common/mdns/mdns_receiver.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "platform/api/time.h" +#include "platform/test/mock_udp_socket.h" + +namespace cast { +namespace mdns { + +using ::testing::_; +using ::testing::Return; +using MockUdpSocket = openscreen::platform::MockUdpSocket; + +// TODO(yakimakha): Update tests to use a fake NetworkRunner when implemented +class MockNetworkRunner : public NetworkRunner { + public: + MOCK_METHOD2(ReadRepeatedly, Error(UdpSocket*, UdpReadCallback*)); + MOCK_METHOD1(CancelRead, Error(UdpSocket*)); + + void PostPackagedTask(Task task) override {} + void PostPackagedTaskWithDelay( + Task task, + openscreen::platform::Clock::duration delay) override {} +}; + +class MockMdnsReceiverDelegate : public MdnsReceiver::Delegate { + public: + MOCK_METHOD2(OnQueryReceived, void(const MdnsMessage&, const IPEndpoint&)); + MOCK_METHOD2(OnResponseReceived, void(const MdnsMessage&, const IPEndpoint&)); +}; + +TEST(MdnsReceiverTest, ReceiveQuery) { + // clang-format off + const std::vector<uint8_t> kQueryBytes = { + 0x00, 0x01, // ID = 1 + 0x00, 0x00, // FLAGS = None + 0x00, 0x01, // Question count + 0x00, 0x00, // Answer count + 0x00, 0x00, // Authority count + 0x00, 0x00, // Additional count + // Question + 0x07, 't', 'e', 's', 't', 'i', 'n', 'g', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x01, // TYPE = A (1) + 0x00, 0x01, // CLASS = IN (1) + }; + // clang-format on + + std::unique_ptr<openscreen::platform::MockUdpSocket> socket_info = + MockUdpSocket::CreateDefault(openscreen::IPAddress::Version::kV4); + MockNetworkRunner runner; + MockMdnsReceiverDelegate delegate; + MdnsReceiver receiver(socket_info.get(), &runner, &delegate); + + EXPECT_CALL(runner, ReadRepeatedly(socket_info.get(), _)) + .WillOnce(Return(Error::Code::kNone)); + receiver.Start(); + + MdnsQuestion question(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, ResponseType::kMulticast); + MdnsMessage message(1, MessageType::Query); + message.AddQuestion(question); + + UdpPacket packet(kQueryBytes.size()); + packet.assign(kQueryBytes.data(), kQueryBytes.data() + kQueryBytes.size()); + packet.set_source( + IPEndpoint{.address = IPAddress(192, 168, 1, 1), .port = 31337}); + packet.set_destination( + IPEndpoint{.address = IPAddress(kDefaultMulticastGroupIPv4), + .port = kDefaultMulticastPort}); + + // Imitate a call to OnRead from NetworkRunner by calling it manually here + EXPECT_CALL(delegate, OnQueryReceived(message, packet.source())).Times(1); + receiver.OnRead(std::move(packet), &runner); + + EXPECT_CALL(runner, CancelRead(socket_info.get())) + .WillOnce(Return(Error::Code::kNone)); + receiver.Stop(); +} + +TEST(MdnsReceiverTest, ReceiveResponse) { + // clang-format off + const std::vector<uint8_t> kResponseBytes = { + 0x00, 0x01, // ID = 1 + 0x84, 0x00, // FLAGS = AA | RESPONSE + 0x00, 0x00, // Question count + 0x00, 0x01, // Answer count + 0x00, 0x00, // Authority count + 0x00, 0x00, // Additional count + // Answer + 0x07, 't', 'e', 's', 't', 'i', 'n', 'g', + 0x05, 'l', 'o', 'c', 'a', 'l', + 0x00, + 0x00, 0x01, // TYPE = A (1) + 0x00, 0x01, // CLASS = IN (1) + 0x00, 0x00, 0x00, 0x78, // TTL = 120 seconds + 0x00, 0x04, // RDLENGTH = 4 bytes + 0xac, 0x00, 0x00, 0x01, // 172.0.0.1 + }; + + constexpr uint8_t kIPv6AddressBytes[] = { + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x02, 0x02, 0xb3, 0xff, 0xfe, 0x1e, 0x83, 0x29, + }; + // clang-format on + + std::unique_ptr<openscreen::platform::MockUdpSocket> socket_info = + MockUdpSocket::CreateDefault(openscreen::IPAddress::Version::kV6); + MockNetworkRunner runner; + MockMdnsReceiverDelegate delegate; + MdnsReceiver receiver(socket_info.get(), &runner, &delegate); + + EXPECT_CALL(runner, ReadRepeatedly(socket_info.get(), _)) + .WillOnce(Return(Error::Code::kNone)); + receiver.Start(); + + MdnsRecord record(DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN, + RecordType::kShared, 120, + ARecordRdata(IPAddress{172, 0, 0, 1})); + MdnsMessage message(1, MessageType::Response); + message.AddAnswer(record); + + UdpPacket packet(kResponseBytes.size()); + packet.assign(kResponseBytes.data(), + kResponseBytes.data() + kResponseBytes.size()); + packet.set_source( + IPEndpoint{.address = IPAddress(kIPv6AddressBytes), .port = 31337}); + packet.set_destination( + IPEndpoint{.address = IPAddress(kDefaultMulticastGroupIPv6), + .port = kDefaultMulticastPort}); + + // Imitate a call to OnRead from NetworkRunner by calling it manually here + EXPECT_CALL(delegate, OnResponseReceived(message, packet.source())).Times(1); + receiver.OnRead(std::move(packet), &runner); + + EXPECT_CALL(runner, CancelRead(socket_info.get())) + .WillOnce(Return(Error::Code::kNone)); + receiver.Stop(); +} + +} // namespace mdns +} // namespace cast diff --git a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_records.cc b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_records.cc index 3c26df172c4..cb0dfb5850d 100644 --- a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_records.cc +++ b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_records.cc @@ -21,21 +21,13 @@ bool IsValidDomainLabel(absl::string_view label) { return label_size > 0 && label_size <= kMaxLabelLength; } -const std::string& DomainName::Label(size_t label_index) const { - OSP_DCHECK(label_index < labels_.size()); - return labels_[label_index]; -} - std::string DomainName::ToString() const { return absl::StrJoin(labels_, "."); } bool DomainName::operator==(const DomainName& rhs) const { - auto predicate = [](const std::string& left, const std::string& right) { - return absl::EqualsIgnoreCase(left, right); - }; return std::equal(labels_.begin(), labels_.end(), rhs.labels_.begin(), - rhs.labels_.end(), predicate); + rhs.labels_.end(), absl::EqualsIgnoreCase); } bool DomainName::operator!=(const DomainName& rhs) const { @@ -167,34 +159,34 @@ size_t TxtRecordRdata::MaxWireSize() const { } MdnsRecord::MdnsRecord(DomainName name, - DnsType type, + DnsType dns_type, DnsClass record_class, - bool cache_flush, + RecordType record_type, uint32_t ttl, Rdata rdata) : name_(std::move(name)), - type_(type), + dns_type_(dns_type), record_class_(record_class), - cache_flush_(cache_flush), + record_type_(record_type), ttl_(ttl), rdata_(std::move(rdata)) { OSP_DCHECK(!name_.empty()); - OSP_DCHECK( - (type == DnsType::kSRV && - absl::holds_alternative<SrvRecordRdata>(rdata_)) || - (type == DnsType::kA && absl::holds_alternative<ARecordRdata>(rdata_)) || - (type == DnsType::kAAAA && - absl::holds_alternative<AAAARecordRdata>(rdata_)) || - (type == DnsType::kPTR && - absl::holds_alternative<PtrRecordRdata>(rdata_)) || - (type == DnsType::kTXT && - absl::holds_alternative<TxtRecordRdata>(rdata_)) || - absl::holds_alternative<RawRecordRdata>(rdata_)); + OSP_DCHECK((dns_type == DnsType::kSRV && + absl::holds_alternative<SrvRecordRdata>(rdata_)) || + (dns_type == DnsType::kA && + absl::holds_alternative<ARecordRdata>(rdata_)) || + (dns_type == DnsType::kAAAA && + absl::holds_alternative<AAAARecordRdata>(rdata_)) || + (dns_type == DnsType::kPTR && + absl::holds_alternative<PtrRecordRdata>(rdata_)) || + (dns_type == DnsType::kTXT && + absl::holds_alternative<TxtRecordRdata>(rdata_)) || + absl::holds_alternative<RawRecordRdata>(rdata_)); } bool MdnsRecord::operator==(const MdnsRecord& rhs) const { - return type_ == rhs.type_ && record_class_ == rhs.record_class_ && - cache_flush_ == rhs.cache_flush_ && ttl_ == rhs.ttl_ && + return dns_type_ == rhs.dns_type_ && record_class_ == rhs.record_class_ && + record_type_ == rhs.record_type_ && ttl_ == rhs.ttl_ && name_ == rhs.name_ && rdata_ == rhs.rdata_; } @@ -204,24 +196,24 @@ bool MdnsRecord::operator!=(const MdnsRecord& rhs) const { size_t MdnsRecord::MaxWireSize() const { auto wire_size_visitor = [](auto&& arg) { return arg.MaxWireSize(); }; - return name_.MaxWireSize() + sizeof(type_) + sizeof(record_class_) + + return name_.MaxWireSize() + sizeof(dns_type_) + sizeof(record_class_) + sizeof(ttl_) + absl::visit(wire_size_visitor, rdata_); } MdnsQuestion::MdnsQuestion(DomainName name, - DnsType type, + DnsType dns_type, DnsClass record_class, - bool unicast_response) + ResponseType response_type) : name_(std::move(name)), - type_(type), + dns_type_(dns_type), record_class_(record_class), - unicast_response_(unicast_response) { + response_type_(response_type) { OSP_CHECK(!name_.empty()); } bool MdnsQuestion::operator==(const MdnsQuestion& rhs) const { - return type_ == rhs.type_ && record_class_ == rhs.record_class_ && - unicast_response_ == rhs.unicast_response_ && name_ == rhs.name_; + return dns_type_ == rhs.dns_type_ && record_class_ == rhs.record_class_ && + response_type_ == rhs.response_type_ && name_ == rhs.name_; } bool MdnsQuestion::operator!=(const MdnsQuestion& rhs) const { @@ -229,20 +221,20 @@ bool MdnsQuestion::operator!=(const MdnsQuestion& rhs) const { } size_t MdnsQuestion::MaxWireSize() const { - return name_.MaxWireSize() + sizeof(type_) + sizeof(record_class_); + return name_.MaxWireSize() + sizeof(dns_type_) + sizeof(record_class_); } -MdnsMessage::MdnsMessage(uint16_t id, uint16_t flags) - : id_(id), flags_(flags) {} +MdnsMessage::MdnsMessage(uint16_t id, MessageType type) + : id_(id), type_(type) {} MdnsMessage::MdnsMessage(uint16_t id, - uint16_t flags, + MessageType type, std::vector<MdnsQuestion> questions, std::vector<MdnsRecord> answers, std::vector<MdnsRecord> authority_records, std::vector<MdnsRecord> additional_records) : id_(id), - flags_(flags), + type_(type), questions_(std::move(questions)), answers_(std::move(answers)), authority_records_(std::move(authority_records)), @@ -267,8 +259,8 @@ MdnsMessage::MdnsMessage(uint16_t id, } bool MdnsMessage::operator==(const MdnsMessage& rhs) const { - return id_ == rhs.id_ && flags_ == rhs.flags_ && - questions_ == rhs.questions_ && answers_ == rhs.answers_ && + return id_ == rhs.id_ && type_ == rhs.type_ && questions_ == rhs.questions_ && + answers_ == rhs.answers_ && authority_records_ == rhs.authority_records_ && additional_records_ == rhs.additional_records_; } diff --git a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_records.h b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_records.h index adbece2b4eb..f5d4568e5d0 100644 --- a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_records.h +++ b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_records.h @@ -52,9 +52,6 @@ class DomainName { bool operator==(const DomainName& rhs) const; bool operator!=(const DomainName& rhs) const; - // Returns a reference to the label at specified label_index. No bounds - // checking is performed. - const std::string& Label(size_t label_index) const; std::string ToString() const; // Returns the maximum space that the domain name could take up in its @@ -63,7 +60,12 @@ class DomainName { // compression the actual space taken in on-the-wire format is smaller. size_t MaxWireSize() const; bool empty() const { return labels_.empty(); } - size_t label_count() const { return labels_.size(); } + const std::vector<std::string>& labels() const { return labels_; } + + template <typename H> + friend H AbslHashValue(H h, const DomainName& domain_name) { + return H::combine(std::move(h), domain_name.labels_); + } private: // max_wire_size_ starts at 1 for the terminating character length. @@ -95,6 +97,11 @@ class RawRecordRdata { uint16_t size() const { return rdata_.size(); } const uint8_t* data() const { return rdata_.data(); } + template <typename H> + friend H AbslHashValue(H h, const RawRecordRdata& rdata) { + return H::combine(std::move(h), rdata.rdata_); + } + private: std::vector<uint8_t> rdata_; }; @@ -127,6 +134,12 @@ class SrvRecordRdata { uint16_t port() const { return port_; } const DomainName& target() const { return target_; } + template <typename H> + friend H AbslHashValue(H h, const SrvRecordRdata& rdata) { + return H::combine(std::move(h), rdata.priority_, rdata.weight_, rdata.port_, + rdata.target_); + } + private: uint16_t priority_ = 0; uint16_t weight_ = 0; @@ -153,6 +166,11 @@ class ARecordRdata { size_t MaxWireSize() const; const IPAddress& ipv4_address() const { return ipv4_address_; } + template <typename H> + friend H AbslHashValue(H h, const ARecordRdata& rdata) { + return H::combine(std::move(h), rdata.ipv4_address_.bytes()); + } + private: IPAddress ipv4_address_{0, 0, 0, 0}; }; @@ -176,6 +194,11 @@ class AAAARecordRdata { size_t MaxWireSize() const; const IPAddress& ipv6_address() const { return ipv6_address_; } + template <typename H> + friend H AbslHashValue(H h, const AAAARecordRdata& rdata) { + return H::combine(std::move(h), rdata.ipv6_address_.bytes()); + } + private: IPAddress ipv6_address_{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; }; @@ -199,6 +222,11 @@ class PtrRecordRdata { size_t MaxWireSize() const; const DomainName& ptr_domain() const { return ptr_domain_; } + template <typename H> + friend H AbslHashValue(H h, const PtrRecordRdata& rdata) { + return H::combine(std::move(h), rdata.ptr_domain_); + } + private: DomainName ptr_domain_; }; @@ -241,6 +269,11 @@ class TxtRecordRdata { size_t MaxWireSize() const; const std::vector<std::string>& texts() const { return texts_; } + template <typename H> + friend H AbslHashValue(H h, const TxtRecordRdata& rdata) { + return H::combine(std::move(h), rdata.texts_); + } + private: // max_wire_size_ is at least 3, uint16_t record length and at the // minimum a NULL byte character string is present. @@ -266,9 +299,9 @@ class MdnsRecord { public: MdnsRecord() = default; MdnsRecord(DomainName name, - DnsType type, + DnsType dns_type, DnsClass record_class, - bool cache_flush, + RecordType record_type, uint32_t ttl, Rdata rdata); MdnsRecord(const MdnsRecord& other) = default; @@ -283,17 +316,24 @@ class MdnsRecord { size_t MaxWireSize() const; const DomainName& name() const { return name_; } - DnsType type() const { return type_; } + DnsType dns_type() const { return dns_type_; } DnsClass record_class() const { return record_class_; } - bool cache_flush() const { return cache_flush_; } + RecordType record_type() const { return record_type_; } uint32_t ttl() const { return ttl_; } const Rdata& rdata() const { return rdata_; } + template <typename H> + friend H AbslHashValue(H h, const MdnsRecord& record) { + return H::combine(std::move(h), record.name_, record.dns_type_, + record.record_class_, record.record_type_, record.ttl_, + record.rdata_); + } + private: DomainName name_; - DnsType type_ = static_cast<DnsType>(0); + DnsType dns_type_ = static_cast<DnsType>(0); DnsClass record_class_ = static_cast<DnsClass>(0); - bool cache_flush_ = false; + RecordType record_type_ = RecordType::kShared; uint32_t ttl_ = kDefaultRecordTTL; // Default-constructed Rdata contains default-constructed RawRecordRdata // as it is the first alternative type and it is default-constructible. @@ -308,9 +348,9 @@ class MdnsQuestion { public: MdnsQuestion() = default; MdnsQuestion(DomainName name, - DnsType type, + DnsType dns_type, DnsClass record_class, - bool unicast_response); + ResponseType response_type); MdnsQuestion(const MdnsQuestion& other) = default; MdnsQuestion(MdnsQuestion&& other) noexcept = default; ~MdnsQuestion() = default; @@ -323,17 +363,23 @@ class MdnsQuestion { size_t MaxWireSize() const; const DomainName& name() const { return name_; } - DnsType type() const { return type_; } + DnsType dns_type() const { return dns_type_; } DnsClass record_class() const { return record_class_; } - bool unicast_response() const { return unicast_response_; } + ResponseType response_type() const { return response_type_; } + + template <typename H> + friend H AbslHashValue(H h, const MdnsQuestion& record) { + return H::combine(std::move(h), record.name_, record.dns_type_, + record.record_class_, record.response_type_); + } private: void CopyFrom(const MdnsQuestion& other); DomainName name_; - DnsType type_ = static_cast<DnsType>(0); + DnsType dns_type_ = static_cast<DnsType>(0); DnsClass record_class_ = static_cast<DnsClass>(0); - bool unicast_response_ = false; + ResponseType response_type_ = ResponseType::kMulticast; }; // Message top level format (http://www.ietf.org/rfc/rfc1035.txt): @@ -351,9 +397,9 @@ class MdnsMessage { MdnsMessage() = default; // Constructs a message with ID, flags and empty question, answer, authority // and additional record collections. - MdnsMessage(uint16_t id, uint16_t flags); + MdnsMessage(uint16_t id, MessageType type); MdnsMessage(uint16_t id, - uint16_t flags, + MessageType type, std::vector<MdnsQuestion> questions, std::vector<MdnsRecord> answers, std::vector<MdnsRecord> authority_records, @@ -375,7 +421,7 @@ class MdnsMessage { size_t MaxWireSize() const; uint16_t id() const { return id_; } - uint16_t flags() const { return flags_; } + MessageType type() const { return type_; } const std::vector<MdnsQuestion>& questions() const { return questions_; } const std::vector<MdnsRecord>& answers() const { return answers_; } const std::vector<MdnsRecord>& authority_records() const { @@ -385,11 +431,18 @@ class MdnsMessage { return additional_records_; } + template <typename H> + friend H AbslHashValue(H h, const MdnsMessage& message) { + return H::combine(std::move(h), message.id_, message.type_, + message.questions_, message.answers_, + message.authority_records_, message.additional_records_); + } + private: // The mDNS header is 12 bytes long size_t max_wire_size_ = sizeof(Header); uint16_t id_ = 0; - uint16_t flags_ = 0; + MessageType type_ = MessageType::Query; std::vector<MdnsQuestion> questions_; std::vector<MdnsRecord> answers_; std::vector<MdnsRecord> authority_records_; diff --git a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_records_unittest.cc b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_records_unittest.cc index c82a81e8649..cd4f25d6259 100644 --- a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_records_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_records_unittest.cc @@ -33,25 +33,25 @@ TEST(MdnsDomainNameTest, Construct) { DomainName name1; EXPECT_TRUE(name1.empty()); EXPECT_EQ(name1.MaxWireSize(), UINT64_C(1)); - EXPECT_EQ(name1.label_count(), UINT64_C(0)); + EXPECT_EQ(name1.labels().size(), UINT64_C(0)); DomainName name2{"MyDevice", "_mYSERvice", "local"}; EXPECT_FALSE(name2.empty()); EXPECT_EQ(name2.MaxWireSize(), UINT64_C(27)); - ASSERT_EQ(name2.label_count(), UINT64_C(3)); - EXPECT_EQ(name2.Label(0), "MyDevice"); - EXPECT_EQ(name2.Label(1), "_mYSERvice"); - EXPECT_EQ(name2.Label(2), "local"); + ASSERT_EQ(name2.labels().size(), UINT64_C(3)); + EXPECT_EQ(name2.labels()[0], "MyDevice"); + EXPECT_EQ(name2.labels()[1], "_mYSERvice"); + EXPECT_EQ(name2.labels()[2], "local"); EXPECT_EQ(name2.ToString(), "MyDevice._mYSERvice.local"); std::vector<absl::string_view> labels{"OtherDevice", "_MYservice", "LOcal"}; DomainName name3(labels); EXPECT_FALSE(name3.empty()); EXPECT_EQ(name3.MaxWireSize(), UINT64_C(30)); - ASSERT_EQ(name3.label_count(), UINT64_C(3)); - EXPECT_EQ(name3.Label(0), "OtherDevice"); - EXPECT_EQ(name3.Label(1), "_MYservice"); - EXPECT_EQ(name3.Label(2), "LOcal"); + ASSERT_EQ(name3.labels().size(), UINT64_C(3)); + EXPECT_EQ(name3.labels()[0], "OtherDevice"); + EXPECT_EQ(name3.labels()[1], "_MYservice"); + EXPECT_EQ(name3.labels()[2], "LOcal"); EXPECT_EQ(name3.ToString(), "OtherDevice._MYservice.LOcal"); } @@ -283,48 +283,48 @@ TEST(MdnsTxtRecordRdataTest, CopyAndMove) { TEST(MdnsRecordTest, Construct) { MdnsRecord record1; - EXPECT_EQ(record1.MaxWireSize(), 11); + EXPECT_EQ(record1.MaxWireSize(), UINT64_C(11)); EXPECT_EQ(record1.name(), DomainName()); - EXPECT_EQ(record1.type(), static_cast<DnsType>(0)); + EXPECT_EQ(record1.dns_type(), static_cast<DnsType>(0)); EXPECT_EQ(record1.record_class(), static_cast<DnsClass>(0)); - EXPECT_EQ(record1.cache_flush(), false); + EXPECT_EQ(record1.record_type(), RecordType::kShared); EXPECT_EQ(record1.ttl(), UINT32_C(255)); // 255 is kDefaultRecordTTL EXPECT_EQ(record1.rdata(), Rdata(RawRecordRdata())); MdnsRecord record2(DomainName{"hostname", "local"}, DnsType::kPTR, - DnsClass::kIN, true, 120, + DnsClass::kIN, RecordType::kUnique, 120, PtrRecordRdata(DomainName{"testing", "local"})); EXPECT_EQ(record2.MaxWireSize(), UINT64_C(41)); EXPECT_EQ(record2.name(), (DomainName{"hostname", "local"})); - EXPECT_EQ(record2.type(), DnsType::kPTR); + EXPECT_EQ(record2.dns_type(), DnsType::kPTR); EXPECT_EQ(record2.record_class(), DnsClass::kIN); - EXPECT_EQ(record2.cache_flush(), true); - EXPECT_EQ(record2.ttl(), 120); + EXPECT_EQ(record2.record_type(), RecordType::kUnique); + EXPECT_EQ(record2.ttl(), UINT32_C(120)); EXPECT_EQ(record2.rdata(), Rdata(PtrRecordRdata(DomainName{"testing", "local"}))); } TEST(MdnsRecordTest, Compare) { MdnsRecord record1(DomainName{"hostname", "local"}, DnsType::kPTR, - DnsClass::kIN, false, 120, + DnsClass::kIN, RecordType::kShared, 120, PtrRecordRdata(DomainName{"testing", "local"})); MdnsRecord record2(DomainName{"hostname", "local"}, DnsType::kPTR, - DnsClass::kIN, false, 120, + DnsClass::kIN, RecordType::kShared, 120, PtrRecordRdata(DomainName{"testing", "local"})); MdnsRecord record3(DomainName{"othername", "local"}, DnsType::kPTR, - DnsClass::kIN, false, 120, + DnsClass::kIN, RecordType::kShared, 120, PtrRecordRdata(DomainName{"testing", "local"})); MdnsRecord record4(DomainName{"hostname", "local"}, DnsType::kA, - DnsClass::kIN, false, 120, + DnsClass::kIN, RecordType::kShared, 120, ARecordRdata(IPAddress{8, 8, 8, 8})); MdnsRecord record5(DomainName{"hostname", "local"}, DnsType::kPTR, - DnsClass::kIN, true, 120, + DnsClass::kIN, RecordType::kUnique, 120, PtrRecordRdata(DomainName{"testing", "local"})); MdnsRecord record6(DomainName{"hostname", "local"}, DnsType::kPTR, - DnsClass::kIN, false, 200, + DnsClass::kIN, RecordType::kShared, 200, PtrRecordRdata(DomainName{"testing", "local"})); MdnsRecord record7(DomainName{"hostname", "local"}, DnsType::kPTR, - DnsClass::kIN, false, 120, + DnsClass::kIN, RecordType::kShared, 120, PtrRecordRdata(DomainName{"device", "local"})); EXPECT_EQ(record1, record2); @@ -337,39 +337,39 @@ TEST(MdnsRecordTest, Compare) { TEST(MdnsRecordTest, CopyAndMove) { MdnsRecord record(DomainName{"hostname", "local"}, DnsType::kPTR, - DnsClass::kIN, true, 120, + DnsClass::kIN, RecordType::kUnique, 120, PtrRecordRdata(DomainName{"testing", "local"})); TestCopyAndMove(record); } TEST(MdnsQuestionTest, Construct) { MdnsQuestion question1; - EXPECT_EQ(question1.MaxWireSize(), 5); + EXPECT_EQ(question1.MaxWireSize(), UINT64_C(5)); EXPECT_EQ(question1.name(), DomainName()); - EXPECT_EQ(question1.type(), static_cast<DnsType>(0)); + EXPECT_EQ(question1.dns_type(), static_cast<DnsType>(0)); EXPECT_EQ(question1.record_class(), static_cast<DnsClass>(0)); - EXPECT_EQ(question1.unicast_response(), false); + EXPECT_EQ(question1.response_type(), ResponseType::kMulticast); MdnsQuestion question2(DomainName{"testing", "local"}, DnsType::kPTR, - DnsClass::kIN, true); + DnsClass::kIN, ResponseType::kUnicast); EXPECT_EQ(question2.MaxWireSize(), UINT64_C(19)); EXPECT_EQ(question2.name(), (DomainName{"testing", "local"})); - EXPECT_EQ(question2.type(), DnsType::kPTR); + EXPECT_EQ(question2.dns_type(), DnsType::kPTR); EXPECT_EQ(question2.record_class(), DnsClass::kIN); - EXPECT_EQ(question2.unicast_response(), true); + EXPECT_EQ(question2.response_type(), ResponseType::kUnicast); } TEST(MdnsQuestionTest, Compare) { MdnsQuestion question1(DomainName{"testing", "local"}, DnsType::kPTR, - DnsClass::kIN, false); + DnsClass::kIN, ResponseType::kMulticast); MdnsQuestion question2(DomainName{"testing", "local"}, DnsType::kPTR, - DnsClass::kIN, false); + DnsClass::kIN, ResponseType::kMulticast); MdnsQuestion question3(DomainName{"hostname", "local"}, DnsType::kPTR, - DnsClass::kIN, false); + DnsClass::kIN, ResponseType::kMulticast); MdnsQuestion question4(DomainName{"testing", "local"}, DnsType::kA, - DnsClass::kIN, false); + DnsClass::kIN, ResponseType::kMulticast); MdnsQuestion question5(DomainName{"hostname", "local"}, DnsType::kPTR, - DnsClass::kIN, true); + DnsClass::kIN, ResponseType::kUnicast); EXPECT_EQ(question1, question2); EXPECT_NE(question1, question3); @@ -379,7 +379,7 @@ TEST(MdnsQuestionTest, Compare) { TEST(MdnsQuestionTest, CopyAndMove) { MdnsQuestion question(DomainName{"testing", "local"}, DnsType::kPTR, - DnsClass::kIN, true); + DnsClass::kIN, ResponseType::kUnicast); TestCopyAndMove(question); } @@ -387,25 +387,28 @@ TEST(MdnsMessageTest, Construct) { MdnsMessage message1; EXPECT_EQ(message1.MaxWireSize(), UINT64_C(12)); EXPECT_EQ(message1.id(), UINT16_C(0)); - EXPECT_EQ(message1.flags(), UINT16_C(0)); + EXPECT_EQ(message1.type(), MessageType::Query); EXPECT_EQ(message1.questions().size(), UINT64_C(0)); EXPECT_EQ(message1.answers().size(), UINT64_C(0)); EXPECT_EQ(message1.authority_records().size(), UINT64_C(0)); EXPECT_EQ(message1.additional_records().size(), UINT64_C(0)); MdnsQuestion question(DomainName{"testing", "local"}, DnsType::kPTR, - DnsClass::kIN, true); - MdnsRecord record1(DomainName{"record1"}, DnsType::kA, DnsClass::kIN, false, - 120, ARecordRdata(IPAddress{172, 0, 0, 1})); - MdnsRecord record2(DomainName{"record2"}, DnsType::kTXT, DnsClass::kIN, false, - 120, TxtRecordRdata{"foo=1", "bar=2"}); - MdnsRecord record3(DomainName{"record3"}, DnsType::kPTR, DnsClass::kIN, false, - 120, PtrRecordRdata(DomainName{"device", "local"})); - - MdnsMessage message2(123, 0x8400); + DnsClass::kIN, ResponseType::kUnicast); + MdnsRecord record1(DomainName{"record1"}, DnsType::kA, DnsClass::kIN, + RecordType::kShared, 120, + ARecordRdata(IPAddress{172, 0, 0, 1})); + MdnsRecord record2(DomainName{"record2"}, DnsType::kTXT, DnsClass::kIN, + RecordType::kShared, 120, + TxtRecordRdata{"foo=1", "bar=2"}); + MdnsRecord record3(DomainName{"record3"}, DnsType::kPTR, DnsClass::kIN, + RecordType::kShared, 120, + PtrRecordRdata(DomainName{"device", "local"})); + + MdnsMessage message2(123, MessageType::Response); EXPECT_EQ(message2.MaxWireSize(), UINT64_C(12)); EXPECT_EQ(message2.id(), UINT16_C(123)); - EXPECT_EQ(message2.flags(), UINT16_C(0x8400)); + EXPECT_EQ(message2.type(), MessageType::Response); EXPECT_EQ(message2.questions().size(), UINT64_C(0)); EXPECT_EQ(message2.answers().size(), UINT64_C(0)); EXPECT_EQ(message2.authority_records().size(), UINT64_C(0)); @@ -427,10 +430,10 @@ TEST(MdnsMessageTest, Construct) { EXPECT_EQ(message2.authority_records()[0], record2); EXPECT_EQ(message2.additional_records()[0], record3); - MdnsMessage message3(123, 0x8400, std::vector<MdnsQuestion>{question}, - std::vector<MdnsRecord>{record1}, - std::vector<MdnsRecord>{record2}, - std::vector<MdnsRecord>{record3}); + MdnsMessage message3( + 123, MessageType::Response, std::vector<MdnsQuestion>{question}, + std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2}, + std::vector<MdnsRecord>{record3}); EXPECT_EQ(message3.MaxWireSize(), UINT64_C(118)); ASSERT_EQ(message3.questions().size(), UINT64_C(1)); @@ -446,46 +449,49 @@ TEST(MdnsMessageTest, Construct) { TEST(MdnsMessageTest, Compare) { MdnsQuestion question(DomainName{"testing", "local"}, DnsType::kPTR, - DnsClass::kIN, true); - MdnsRecord record1(DomainName{"record1"}, DnsType::kA, DnsClass::kIN, false, - 120, ARecordRdata(IPAddress{172, 0, 0, 1})); - MdnsRecord record2(DomainName{"record2"}, DnsType::kTXT, DnsClass::kIN, false, - 120, TxtRecordRdata{"foo=1", "bar=2"}); - MdnsRecord record3(DomainName{"record3"}, DnsType::kPTR, DnsClass::kIN, false, - 120, PtrRecordRdata(DomainName{"device", "local"})); - - MdnsMessage message1(123, 0x8400, std::vector<MdnsQuestion>{question}, - std::vector<MdnsRecord>{record1}, - std::vector<MdnsRecord>{record2}, - std::vector<MdnsRecord>{record3}); - MdnsMessage message2(123, 0x8400, std::vector<MdnsQuestion>{question}, - std::vector<MdnsRecord>{record1}, - std::vector<MdnsRecord>{record2}, - std::vector<MdnsRecord>{record3}); - MdnsMessage message3(456, 0x8400, std::vector<MdnsQuestion>{question}, - std::vector<MdnsRecord>{record1}, - std::vector<MdnsRecord>{record2}, - std::vector<MdnsRecord>{record3}); - MdnsMessage message4(123, 0x400, std::vector<MdnsQuestion>{question}, - std::vector<MdnsRecord>{record1}, - std::vector<MdnsRecord>{record2}, - std::vector<MdnsRecord>{record3}); - MdnsMessage message5(123, 0x8400, std::vector<MdnsQuestion>{}, + DnsClass::kIN, ResponseType::kUnicast); + MdnsRecord record1(DomainName{"record1"}, DnsType::kA, DnsClass::kIN, + RecordType::kShared, 120, + ARecordRdata(IPAddress{172, 0, 0, 1})); + MdnsRecord record2(DomainName{"record2"}, DnsType::kTXT, DnsClass::kIN, + RecordType::kShared, 120, + TxtRecordRdata{"foo=1", "bar=2"}); + MdnsRecord record3(DomainName{"record3"}, DnsType::kPTR, DnsClass::kIN, + RecordType::kShared, 120, + PtrRecordRdata(DomainName{"device", "local"})); + + MdnsMessage message1( + 123, MessageType::Response, std::vector<MdnsQuestion>{question}, + std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2}, + std::vector<MdnsRecord>{record3}); + MdnsMessage message2( + 123, MessageType::Response, std::vector<MdnsQuestion>{question}, + std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2}, + std::vector<MdnsRecord>{record3}); + MdnsMessage message3( + 456, MessageType::Response, std::vector<MdnsQuestion>{question}, + std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2}, + std::vector<MdnsRecord>{record3}); + MdnsMessage message4( + 123, MessageType::Query, std::vector<MdnsQuestion>{question}, + std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2}, + std::vector<MdnsRecord>{record3}); + MdnsMessage message5(123, MessageType::Response, std::vector<MdnsQuestion>{}, std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2}, std::vector<MdnsRecord>{record3}); - MdnsMessage message6(123, 0x8400, std::vector<MdnsQuestion>{question}, - std::vector<MdnsRecord>{}, - std::vector<MdnsRecord>{record2}, - std::vector<MdnsRecord>{record3}); - MdnsMessage message7(123, 0x8400, std::vector<MdnsQuestion>{question}, - std::vector<MdnsRecord>{record1}, - std::vector<MdnsRecord>{}, - std::vector<MdnsRecord>{record3}); - MdnsMessage message8(123, 0x8400, std::vector<MdnsQuestion>{question}, - std::vector<MdnsRecord>{record1}, - std::vector<MdnsRecord>{record2}, - std::vector<MdnsRecord>{}); + MdnsMessage message6( + 123, MessageType::Response, std::vector<MdnsQuestion>{question}, + std::vector<MdnsRecord>{}, std::vector<MdnsRecord>{record2}, + std::vector<MdnsRecord>{record3}); + MdnsMessage message7( + 123, MessageType::Response, std::vector<MdnsQuestion>{question}, + std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{}, + std::vector<MdnsRecord>{record3}); + MdnsMessage message8( + 123, MessageType::Response, std::vector<MdnsQuestion>{question}, + std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2}, + std::vector<MdnsRecord>{}); EXPECT_EQ(message1, message2); EXPECT_NE(message1, message3); @@ -498,17 +504,20 @@ TEST(MdnsMessageTest, Compare) { TEST(MdnsMessageTest, CopyAndMove) { MdnsQuestion question(DomainName{"testing", "local"}, DnsType::kPTR, - DnsClass::kIN, true); - MdnsRecord record1(DomainName{"record1"}, DnsType::kA, DnsClass::kIN, false, - 120, ARecordRdata(IPAddress{172, 0, 0, 1})); - MdnsRecord record2(DomainName{"record2"}, DnsType::kTXT, DnsClass::kIN, false, - 120, TxtRecordRdata{"foo=1", "bar=2"}); - MdnsRecord record3(DomainName{"record3"}, DnsType::kPTR, DnsClass::kIN, false, - 120, PtrRecordRdata(DomainName{"device", "local"})); - MdnsMessage message(123, 0x8400, std::vector<MdnsQuestion>{question}, - std::vector<MdnsRecord>{record1}, - std::vector<MdnsRecord>{record2}, - std::vector<MdnsRecord>{record3}); + DnsClass::kIN, ResponseType::kUnicast); + MdnsRecord record1(DomainName{"record1"}, DnsType::kA, DnsClass::kIN, + RecordType::kShared, 120, + ARecordRdata(IPAddress{172, 0, 0, 1})); + MdnsRecord record2(DomainName{"record2"}, DnsType::kTXT, DnsClass::kIN, + RecordType::kShared, 120, + TxtRecordRdata{"foo=1", "bar=2"}); + MdnsRecord record3(DomainName{"record3"}, DnsType::kPTR, DnsClass::kIN, + RecordType::kShared, 120, + PtrRecordRdata(DomainName{"device", "local"})); + MdnsMessage message( + 123, MessageType::Response, std::vector<MdnsQuestion>{question}, + std::vector<MdnsRecord>{record1}, std::vector<MdnsRecord>{record2}, + std::vector<MdnsRecord>{record3}); TestCopyAndMove(message); } diff --git a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_sender.h b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_sender.h index 44f3e2f859a..143e596ff13 100644 --- a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_sender.h +++ b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_sender.h @@ -19,19 +19,21 @@ using IPEndpoint = openscreen::IPEndpoint; class MdnsSender { public: + // MdnsSender does not own |socket| and expects that its lifetime exceeds the + // lifetime of MdnsSender. explicit MdnsSender(UdpSocket* socket); MdnsSender(const MdnsSender& other) = delete; - MdnsSender(MdnsSender&& other) noexcept = default; + MdnsSender(MdnsSender&& other) noexcept = delete; ~MdnsSender() = default; MdnsSender& operator=(const MdnsSender& other) = delete; - MdnsSender& operator=(MdnsSender&& other) noexcept = default; + MdnsSender& operator=(MdnsSender&& other) noexcept = delete; Error SendMulticast(const MdnsMessage& message); Error SendUnicast(const MdnsMessage& message, const IPEndpoint& endpoint); private: - UdpSocket* socket_; + UdpSocket* const socket_; }; } // namespace mdns diff --git a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_sender_unittest.cc b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_sender_unittest.cc index bdb85627336..821ac1f1f49 100644 --- a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_sender_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_sender_unittest.cc @@ -39,30 +39,30 @@ class MdnsSenderTest : public ::testing::Test { : a_question_(DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN, - false), + ResponseType::kMulticast), a_record_(DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN, - false, + RecordType::kShared, 120, ARecordRdata(IPAddress{172, 0, 0, 1})), - question_message_(1, 0x0400), - answer_message_(1, 0x0400), + query_message_(1, MessageType::Query), + response_message_(1, MessageType::Response), ipv4_multicast_endpoint_{ .address = IPAddress(kDefaultMulticastGroupIPv4), .port = kDefaultMulticastPort}, ipv6_multicast_endpoint_{ .address = IPAddress(kDefaultMulticastGroupIPv6), .port = kDefaultMulticastPort} { - question_message_.AddQuestion(a_question_); - answer_message_.AddAnswer(a_record_); + query_message_.AddQuestion(a_question_); + response_message_.AddAnswer(a_record_); } protected: // clang-format off - const std::vector<uint8_t> kQuestionBytes = { + const std::vector<uint8_t> kQueryBytes = { 0x00, 0x01, // ID = 1 - 0x04, 0x00, // FLAGS = AA + 0x00, 0x00, // FLAGS = None 0x00, 0x01, // Question count 0x00, 0x00, // Answer count 0x00, 0x00, // Authority count @@ -75,9 +75,9 @@ class MdnsSenderTest : public ::testing::Test { 0x00, 0x01, // CLASS = IN (1) }; - const std::vector<uint8_t> kAnswerBytes = { + const std::vector<uint8_t> kResponseBytes = { 0x00, 0x01, // ID = 1 - 0x04, 0x00, // FLAGS = AA + 0x84, 0x00, // FLAGS = AA | RESPONSE 0x00, 0x00, // Question count 0x00, 0x01, // Answer count 0x00, 0x00, // Authority count @@ -96,41 +96,46 @@ class MdnsSenderTest : public ::testing::Test { MdnsQuestion a_question_; MdnsRecord a_record_; - MdnsMessage question_message_; - MdnsMessage answer_message_; + MdnsMessage query_message_; + MdnsMessage response_message_; IPEndpoint ipv4_multicast_endpoint_; IPEndpoint ipv6_multicast_endpoint_; }; TEST_F(MdnsSenderTest, SendMulticastIPv4) { - MockUdpSocket socket(openscreen::IPAddress::Version::kV4); - MdnsSender sender(&socket); - EXPECT_CALL(socket, - SendMessage(VoidPointerMatchesBytes(kQuestionBytes), - kQuestionBytes.size(), ipv4_multicast_endpoint_)) + std::unique_ptr<openscreen::platform::MockUdpSocket> socket_info = + MockUdpSocket::CreateDefault(openscreen::IPAddress::Version::kV4); + MdnsSender sender(socket_info.get()); + EXPECT_CALL(*socket_info.get(), + SendMessage(VoidPointerMatchesBytes(kQueryBytes), + kQueryBytes.size(), ipv4_multicast_endpoint_)) .Times(1); - EXPECT_EQ(sender.SendMulticast(question_message_), Error::Code::kNone); + EXPECT_EQ(sender.SendMulticast(query_message_), Error::Code::kNone); } TEST_F(MdnsSenderTest, SendMulticastIPv6) { - MockUdpSocket socket(openscreen::IPAddress::Version::kV6); - MdnsSender sender(&socket); - EXPECT_CALL(socket, - SendMessage(VoidPointerMatchesBytes(kQuestionBytes), - kQuestionBytes.size(), ipv6_multicast_endpoint_)) + std::unique_ptr<openscreen::platform::MockUdpSocket> socket_info = + MockUdpSocket::CreateDefault(openscreen::IPAddress::Version::kV6); + MdnsSender sender(socket_info.get()); + EXPECT_CALL(*socket_info.get(), + SendMessage(VoidPointerMatchesBytes(kQueryBytes), + kQueryBytes.size(), ipv6_multicast_endpoint_)) .Times(1); - EXPECT_EQ(sender.SendMulticast(question_message_), Error::Code::kNone); + EXPECT_EQ(sender.SendMulticast(query_message_), Error::Code::kNone); } TEST_F(MdnsSenderTest, SendUnicastIPv4) { IPEndpoint endpoint{.address = IPAddress{192, 168, 1, 1}, .port = 31337}; - MockUdpSocket socket(openscreen::IPAddress::Version::kV4); - MdnsSender sender(&socket); - EXPECT_CALL(socket, SendMessage(VoidPointerMatchesBytes(kAnswerBytes), - kAnswerBytes.size(), endpoint)) + std::unique_ptr<openscreen::platform::MockUdpSocket> socket_info = + MockUdpSocket::CreateDefault(openscreen::IPAddress::Version::kV4); + MdnsSender sender(socket_info.get()); + EXPECT_CALL(*socket_info.get(), + SendMessage(VoidPointerMatchesBytes(kResponseBytes), + kResponseBytes.size(), endpoint)) .Times(1); - EXPECT_EQ(sender.SendUnicast(answer_message_, endpoint), Error::Code::kNone); + EXPECT_EQ(sender.SendUnicast(response_message_, endpoint), + Error::Code::kNone); } TEST_F(MdnsSenderTest, SendUnicastIPv6) { @@ -140,33 +145,38 @@ TEST_F(MdnsSenderTest, SendUnicastIPv6) { }; IPEndpoint endpoint{.address = IPAddress(kIPv6AddressBytes), .port = 31337}; - MockUdpSocket socket(openscreen::IPAddress::Version::kV6); - MdnsSender sender(&socket); - EXPECT_CALL(socket, SendMessage(VoidPointerMatchesBytes(kAnswerBytes), - kAnswerBytes.size(), endpoint)) + std::unique_ptr<openscreen::platform::MockUdpSocket> socket_info = + MockUdpSocket::CreateDefault(openscreen::IPAddress::Version::kV6); + MdnsSender sender(socket_info.get()); + EXPECT_CALL(*socket_info.get(), + SendMessage(VoidPointerMatchesBytes(kResponseBytes), + kResponseBytes.size(), endpoint)) .Times(1); - EXPECT_EQ(sender.SendUnicast(answer_message_, endpoint), Error::Code::kNone); + EXPECT_EQ(sender.SendUnicast(response_message_, endpoint), + Error::Code::kNone); } TEST_F(MdnsSenderTest, MessageTooBig) { - MdnsMessage big_message_(1, 0x0400); + MdnsMessage big_message_(1, MessageType::Query); for (size_t i = 0; i < 100; ++i) { big_message_.AddQuestion(a_question_); big_message_.AddAnswer(a_record_); } - MockUdpSocket socket(openscreen::IPAddress::Version::kV4); - MdnsSender sender(&socket); - EXPECT_CALL(socket, SendMessage(_, _, _)).Times(0); + std::unique_ptr<openscreen::platform::MockUdpSocket> socket_info = + MockUdpSocket::CreateDefault(openscreen::IPAddress::Version::kV4); + MdnsSender sender(socket_info.get()); + EXPECT_CALL(*socket_info.get(), SendMessage(_, _, _)).Times(0); EXPECT_EQ(sender.SendMulticast(big_message_), Error::Code::kInsufficientBuffer); } TEST_F(MdnsSenderTest, ReturnsErrorOnSocketFailure) { - MockUdpSocket socket(openscreen::IPAddress::Version::kV4); - MdnsSender sender(&socket); - EXPECT_CALL(socket, SendMessage(_, _, _)) + std::unique_ptr<openscreen::platform::MockUdpSocket> socket_info = + MockUdpSocket::CreateDefault(openscreen::IPAddress::Version::kV4); + MdnsSender sender(socket_info.get()); + EXPECT_CALL(*socket_info.get(), SendMessage(_, _, _)) .WillOnce(Return(Error::Code::kConnectionFailed)); - EXPECT_EQ(sender.SendMulticast(question_message_), + EXPECT_EQ(sender.SendMulticast(query_message_), Error::Code::kConnectionFailed); } diff --git a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_writer.cc b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_writer.cc index 79d101736f5..94a190bd1f4 100644 --- a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_writer.cc +++ b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_writer.cc @@ -27,13 +27,13 @@ std::vector<uint64_t> ComputeDomainNameSubhashes(const DomainName& name) { return b; }; + const std::vector<std::string>& labels = name.labels(); // Use a large prime between 2^63 and 2^64 as a starting value. // This is taken from absl::Hash implementation. uint64_t hash_value = UINT64_C(0xc3a5c85c97cb3127); - std::vector<uint64_t> subhashes(name.label_count()); - for (size_t i = name.label_count(); i-- > 0;) { - hash_value = - hash_combiner(hash_value, absl::AsciiStrToLower(name.Label(i))); + std::vector<uint64_t> subhashes(labels.size()); + for (size_t i = labels.size(); i-- > 0;) { + hash_value = hash_combiner(hash_value, absl::AsciiStrToLower(labels[i])); subhashes[i] = hash_value; } return subhashes; @@ -84,8 +84,9 @@ bool MdnsWriter::Write(const DomainName& name) { // Tentative dictionary contains label pointer entries to be added to the // compression dictionary after successfully writing the domain name. std::unordered_map<uint64_t, uint16_t> tentative_dictionary; - for (size_t i = 0; i < name.label_count(); ++i) { - OSP_DCHECK(IsValidDomainLabel(name.Label(i))); + const std::vector<std::string>& labels = name.labels(); + for (size_t i = 0; i < labels.size(); ++i) { + OSP_DCHECK(IsValidDomainLabel(labels[i])); // We only need to do a look up in the compression dictionary and not in the // tentative dictionary. The tentative dictionary cannot possibly contain a // valid label pointer as all the entries previously added to it are for @@ -106,9 +107,8 @@ bool MdnsWriter::Write(const DomainName& name) { tentative_dictionary.insert( std::make_pair(subhashes[i], MakePointerLabel(current() - begin()))); } - const std::string& label = name.Label(i); - if (!Write(MakeDirectLabel(label.size())) || - !Write(label.data(), label.size())) { + if (!Write(MakeDirectLabel(labels[i].size())) || + !Write(labels[i].data(), labels[i].size())) { return false; } } @@ -204,8 +204,8 @@ bool MdnsWriter::Write(const TxtRecordRdata& rdata) { bool MdnsWriter::Write(const MdnsRecord& record) { Cursor cursor(this); - if (Write(record.name()) && Write(static_cast<uint16_t>(record.type())) && - Write(MakeRecordClass(record.record_class(), record.cache_flush())) && + if (Write(record.name()) && Write(static_cast<uint16_t>(record.dns_type())) && + Write(MakeRecordClass(record.record_class(), record.record_type())) && Write(record.ttl()) && Write(record.rdata())) { cursor.Commit(); return true; @@ -215,9 +215,10 @@ bool MdnsWriter::Write(const MdnsRecord& record) { bool MdnsWriter::Write(const MdnsQuestion& question) { Cursor cursor(this); - if (Write(question.name()) && Write(static_cast<uint16_t>(question.type())) && + if (Write(question.name()) && + Write(static_cast<uint16_t>(question.dns_type())) && Write(MakeQuestionClass(question.record_class(), - question.unicast_response()))) { + question.response_type()))) { cursor.Commit(); return true; } @@ -228,7 +229,7 @@ bool MdnsWriter::Write(const MdnsMessage& message) { Cursor cursor(this); Header header; header.id = message.id(); - header.flags = message.flags(); + header.flags = MakeFlags(message.type()); header.question_count = message.questions().size(); header.answer_count = message.answers().size(); header.authority_record_count = message.authority_records().size(); diff --git a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_writer_unittest.cc b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_writer_unittest.cc index bf8b2553f8c..c1f12b4c199 100644 --- a/chromium/third_party/openscreen/src/cast/common/mdns/mdns_writer_unittest.cc +++ b/chromium/third_party/openscreen/src/cast/common/mdns/mdns_writer_unittest.cc @@ -306,10 +306,10 @@ TEST(MdnsWriterTest, WriteMdnsRecord_ARecordRdata) { 0xac, 0x00, 0x00, 0x01, // 172.0.0.1 }; // clang-format on - TestWriteEntrySucceeds( - MdnsRecord(DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN, - true, 120, ARecordRdata(IPAddress{172, 0, 0, 1})), - kExpectedResult, sizeof(kExpectedResult)); + TestWriteEntrySucceeds(MdnsRecord(DomainName{"testing", "local"}, DnsType::kA, + DnsClass::kIN, RecordType::kUnique, 120, + ARecordRdata(IPAddress{172, 0, 0, 1})), + kExpectedResult, sizeof(kExpectedResult)); } TEST(MdnsWriterTest, WriteMdnsRecord_PtrRecordRdata) { @@ -328,15 +328,15 @@ TEST(MdnsWriterTest, WriteMdnsRecord_PtrRecordRdata) { // clang-format on TestWriteEntrySucceeds( MdnsRecord(DomainName{"_service", "testing", "local"}, DnsType::kPTR, - DnsClass::kIN, false, 120, + DnsClass::kIN, RecordType::kShared, 120, PtrRecordRdata(DomainName{"testing", "local"})), kExpectedResult, sizeof(kExpectedResult)); } TEST(MdnsWriterTest, WriteMdnsRecord_InsufficientBuffer) { - TestWriteEntryInsufficientBuffer( - MdnsRecord(DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN, - true, 120, ARecordRdata(IPAddress{172, 0, 0, 1}))); + TestWriteEntryInsufficientBuffer(MdnsRecord( + DomainName{"testing", "local"}, DnsType::kA, DnsClass::kIN, + RecordType::kUnique, 120, ARecordRdata(IPAddress{172, 0, 0, 1}))); } TEST(MdnsWriterTest, WriteMdnsQuestion) { @@ -350,22 +350,23 @@ TEST(MdnsWriterTest, WriteMdnsQuestion) { 0x80, 0x01, // CLASS = IN (1) | UNICAST_BIT }; // clang-format on - TestWriteEntrySucceeds(MdnsQuestion(DomainName{"wire", "format", "local"}, - DnsType::kPTR, DnsClass::kIN, true), - kExpectedResult, sizeof(kExpectedResult)); + TestWriteEntrySucceeds( + MdnsQuestion(DomainName{"wire", "format", "local"}, DnsType::kPTR, + DnsClass::kIN, ResponseType::kUnicast), + kExpectedResult, sizeof(kExpectedResult)); } TEST(MdnsWriterTest, WriteMdnsQuestion_InsufficientBuffer) { TestWriteEntryInsufficientBuffer( MdnsQuestion(DomainName{"wire", "format", "local"}, DnsType::kPTR, - DnsClass::kIN, true)); + DnsClass::kIN, ResponseType::kUnicast)); } TEST(MdnsWriterTest, WriteMdnsMessage) { // clang-format off constexpr uint8_t kExpectedMessage[] = { 0x00, 0x01, // ID = 1 - 0x04, 0x00, // FLAGS = AA + 0x00, 0x00, // FLAGS = None 0x00, 0x01, // Question count 0x00, 0x00, // Answer count 0x00, 0x01, // Authority count @@ -387,12 +388,13 @@ TEST(MdnsWriterTest, WriteMdnsMessage) { }; // clang-format on MdnsQuestion question(DomainName{"question"}, DnsType::kPTR, DnsClass::kIN, - false); + ResponseType::kMulticast); MdnsRecord auth_record(DomainName{"auth"}, DnsType::kTXT, DnsClass::kIN, - false, 120, TxtRecordRdata{"foo=1", "bar=2"}); + RecordType::kShared, 120, + TxtRecordRdata{"foo=1", "bar=2"}); - MdnsMessage message(1, 0x0400); + MdnsMessage message(1, MessageType::Query); message.AddQuestion(question); message.AddAuthorityRecord(auth_record); @@ -405,12 +407,13 @@ TEST(MdnsWriterTest, WriteMdnsMessage) { TEST(MdnsWriterTest, WriteMdnsMessage_InsufficientBuffer) { MdnsQuestion question(DomainName{"question"}, DnsType::kPTR, DnsClass::kIN, - false); + ResponseType::kMulticast); - MdnsRecord auth_record(DomainName{"auth"}, DnsType::kTXT, DnsClass::kIN, 120, - false, TxtRecordRdata{"foo=1", "bar=2"}); + MdnsRecord auth_record(DomainName{"auth"}, DnsType::kTXT, DnsClass::kIN, + RecordType::kShared, 120, + TxtRecordRdata{"foo=1", "bar=2"}); - MdnsMessage message(1, 0x0400); + MdnsMessage message(1, MessageType::Query); message.AddQuestion(question); message.AddAuthorityRecord(auth_record); TestWriteEntryInsufficientBuffer(message); diff --git a/chromium/third_party/openscreen/src/docs/trace_logging.md b/chromium/third_party/openscreen/src/docs/trace_logging.md index 5866778f14c..523b5cb7a8f 100644 --- a/chromium/third_party/openscreen/src/docs/trace_logging.md +++ b/chromium/third_party/openscreen/src/docs/trace_logging.md @@ -6,6 +6,12 @@ for, whether the function was successful, and connect all of this information to any existing informational logs. The below provides information about how this can be achieved with OSP's TraceLogging Infrastructure. +## Compilation + +By default, TraceLogging is enabled as part of the build. +To disable TraceLogging, include flag `--args="enable_trace_logging=false"` +when calling `gn gen` as part of building this library. + ## Imports To use TraceLogging, import the following header file: diff --git a/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg b/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg index 7e58538a96c..367629a76f9 100644 --- a/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg +++ b/chromium/third_party/openscreen/src/infra/config/global/commit-queue.cfg @@ -29,6 +29,9 @@ config_groups { name: "openscreen/try/linux64_debug_gcc" } builders { + name: "openscreen/try/linux64_tsan" + } + builders { name: "openscreen/try/mac_debug" } builders { diff --git a/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg b/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg index 7e8fe9cbeb7..7967288f5cc 100644 --- a/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg +++ b/chromium/third_party/openscreen/src/infra/config/global/cr-buildbucket.cfg @@ -49,6 +49,13 @@ builder_mixins { } builder_mixins { + name: "tsan" + recipe { + properties_j: "is_tsan:true" + } +} + +builder_mixins { name: "linux" dimensions: "os:Ubuntu-16.04" } @@ -118,6 +125,13 @@ buckets { } builders { + name: "linux64_tsan" + mixins: "linux" + mixins: "x64" + mixins: "tsan" + } + + builders { name: "mac_debug" mixins: "mac" mixins: "debug" @@ -175,6 +189,13 @@ buckets: { } builders { + name: "linux64_tsan" + mixins: "linux" + mixins: "x64" + mixins: "tsan" + } + + builders { name: "mac_debug" mixins: "mac" mixins: "debug" diff --git a/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg b/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg index fd1e1a7560f..d3007a807aa 100644 --- a/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg +++ b/chromium/third_party/openscreen/src/infra/config/global/luci-milo.cfg @@ -14,6 +14,12 @@ consoles { } builders { + name: "buildbucket/luci.openscreen.ci/linux64_tsan" + category: "linux|x64" + short_name: "tsan" + } + + builders { name: "buildbucket/luci.openscreen.ci/linux64_debug_gcc" category: "linux|x64|gcc" short_name: "dbg" @@ -52,6 +58,12 @@ consoles { } builders { + name: "buildbucket/luci.openscreen.try/linux64_tsan" + category: "linux|x64" + short_name: "tsan" + } + + builders { name: "buildbucket/luci.openscreen.try/linux64_debug_gcc" category: "linux|x64|gcc" short_name: "dbg" diff --git a/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg b/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg index da09652d628..796ca0f1efe 100644 --- a/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg +++ b/chromium/third_party/openscreen/src/infra/config/global/luci-scheduler.cfg @@ -23,6 +23,7 @@ trigger { refs: "refs/heads/master" } triggers: "linux64_debug" + triggers: "linux64_tsan" triggers: "linux64_debug_gcc" triggers: "mac_debug" } @@ -50,6 +51,16 @@ job { } job { + id: "linux64_tsan" + acl_sets: "default" + buildbucket: { + server: "cr-buildbucket.appspot.com" + bucket: "luci.openscreen.ci" + builder: "linux64_tsan" + } +} + +job { id: "linux64_debug_gcc" acl_sets: "default" buildbucket: { diff --git a/chromium/third_party/openscreen/src/osp/demo/README.md b/chromium/third_party/openscreen/src/osp/demo/README.md index ca5e67bf238..96acd759265 100644 --- a/chromium/third_party/openscreen/src/osp/demo/README.md +++ b/chromium/third_party/openscreen/src/osp/demo/README.md @@ -34,7 +34,7 @@ output while the demo is running is to make these named pipes like so: Then `cat` them in separate terminals while the demo is running. -## Controller commands +## Listener commands - `avail <url>`: Begin listening for receivers that support the presentation of `url`. @@ -49,7 +49,7 @@ Then `cat` them in separate terminals while the demo is running. This allows using the `msg` command again. - `term`: Terminate the previously started presentation. -## Receiver commands +## Publisher commands - `avail`: Toggle whether the receiver is publishing itself as an available screen. The receiver starts in the publishing state. diff --git a/chromium/third_party/openscreen/src/osp/demo/demo.cc b/chromium/third_party/openscreen/src/osp/demo/demo.cc index 83c410d936e..6babe25c00d 100644 --- a/chromium/third_party/openscreen/src/osp/demo/demo.cc +++ b/chromium/third_party/openscreen/src/osp/demo/demo.cc @@ -28,8 +28,12 @@ #include "osp/public/service_publisher.h" #include "platform/api/logging.h" #include "platform/api/network_interface.h" +#include "platform/api/network_runner.h" +#include "platform/api/network_runner_lifetime_manager.h" #include "platform/api/time.h" #include "platform/api/trace_logging.h" +#include "platform/impl/network_runner.h" +#include "platform/impl/task_runner.h" #include "platform/impl/text_trace_logging_platform.h" #include "third_party/tinycbor/src/src/cbor.h" @@ -345,12 +349,10 @@ struct CommandWaitResult { }; CommandWaitResult WaitForCommand(pollfd* pollfd) { - NetworkServiceManager* network_service = NetworkServiceManager::Get(); while (poll(pollfd, 1, 10) >= 0) { if (g_done) { return {true}; } - network_service->RunEventLoopOnce(); if (pollfd->revents == 0) { continue; @@ -372,7 +374,6 @@ CommandWaitResult WaitForCommand(pollfd* pollfd) { } void RunControllerPollLoop(presentation::Controller* controller) { - TRACE_SCOPED(TraceCategory::CastFlinging, "RunControllerPollLoop"); ReceiverObserver receiver_observer; RequestDelegate request_delegate; ConnectionDelegate connection_delegate; @@ -380,9 +381,7 @@ void RunControllerPollLoop(presentation::Controller* controller) { presentation::Controller::ConnectRequest connect_request; pollfd stdin_pollfd{STDIN_FILENO, POLLIN}; - uint64_t it = 1; while (true) { - TRACE_SCOPED(TraceCategory::CastFlinging, "ControllerPollIteration", it); write(STDOUT_FILENO, "$ ", 2); CommandWaitResult command_result = WaitForCommand(&stdin_pollfd); @@ -417,8 +416,6 @@ void RunControllerPollLoop(presentation::Controller* controller) { request_delegate.connection->Terminate( presentation::TerminationReason::kControllerTerminateCalled); } - - it += 2; // +2 to keep the ids for reciever and controller distinct. }; watch = presentation::Controller::ReceiverWatch(); @@ -427,16 +424,21 @@ void RunControllerPollLoop(presentation::Controller* controller) { void ListenerDemo() { SignalThings(); + std::unique_ptr<platform::NetworkRunnerLifetimeManager> + network_runner_manager = platform::NetworkRunnerLifetimeManager::Create(); + network_runner_manager->CreateNetworkRunner(); + platform::NetworkRunner* network_runner = network_runner_manager->Get(); + ListenerObserver listener_observer; MdnsServiceListenerConfig listener_config; - auto mdns_listener = - MdnsServiceListenerFactory::Create(listener_config, &listener_observer); + auto mdns_listener = MdnsServiceListenerFactory::Create( + listener_config, &listener_observer, network_runner); MessageDemuxer demuxer(platform::Clock::now, MessageDemuxer::kDefaultBufferLimit); ConnectionClientObserver client_observer; - auto connection_client = - ProtocolConnectionClientFactory::Create(&demuxer, &client_observer); + auto connection_client = ProtocolConnectionClientFactory::Create( + &demuxer, &client_observer, network_runner); auto* network_service = NetworkServiceManager::Create( std::move(mdns_listener), nullptr, std::move(connection_client), nullptr); @@ -484,11 +486,8 @@ void HandleReceiverCommand(absl::string_view command, void RunReceiverPollLoop(pollfd& file_descriptor, NetworkServiceManager* manager, ReceiverDelegate& delegate) { - TRACE_SCOPED(TraceCategory::CastFlinging, "RunReceiverPollLoop"); pollfd stdin_pollfd{STDIN_FILENO, POLLIN}; - uint64_t it = 2; while (true) { - TRACE_SCOPED(TraceCategory::CastFlinging, "ReceiverPollIteration", it); write(STDOUT_FILENO, "$ ", 2); CommandWaitResult command_result = WaitForCommand(&stdin_pollfd); @@ -499,8 +498,6 @@ void RunReceiverPollLoop(pollfd& file_descriptor, HandleReceiverCommand(command_result.command_line.command, command_result.command_line.argument_tail, delegate, manager); - - it += 2; // +2 to keep the ids for reciever and controller distinct. } } @@ -518,6 +515,11 @@ void PublisherDemo(absl::string_view friendly_name) { constexpr uint16_t server_port = 6667; + std::unique_ptr<platform::NetworkRunnerLifetimeManager> + network_runner_manager = platform::NetworkRunnerLifetimeManager::Create(); + network_runner_manager->CreateNetworkRunner(); + platform::NetworkRunner* network_runner = network_runner_manager->Get(); + PublisherObserver publisher_observer; // TODO(btolsch): aggregate initialization probably better? ServicePublisher::Config publisher_config; @@ -527,7 +529,7 @@ void PublisherDemo(absl::string_view friendly_name) { publisher_config.connection_server_port = server_port; auto mdns_publisher = MdnsServicePublisherFactory::Create( - publisher_config, &publisher_observer); + publisher_config, &publisher_observer, network_runner); ServerConfig server_config; std::vector<platform::InterfaceAddresses> interfaces = @@ -541,7 +543,8 @@ void PublisherDemo(absl::string_view friendly_name) { MessageDemuxer::kDefaultBufferLimit); ConnectionServerObserver server_observer; auto connection_server = ProtocolConnectionServerFactory::Create( - server_config, &demuxer, &server_observer); + server_config, &demuxer, &server_observer, network_runner); + auto* network_service = NetworkServiceManager::Create(nullptr, std::move(mdns_publisher), nullptr, std::move(connection_server)); diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/domain_name.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/domain_name.cc index 7e584d60038..ece5485175f 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/domain_name.cc +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/domain_name.cc @@ -57,7 +57,7 @@ bool DomainName::operator==(const DomainName& other) const { if (domain_name_.size() != other.domain_name_.size()) { return false; } - for (int i = 0; i < domain_name_.size(); ++i) { + for (size_t i = 0; i < domain_name_.size(); ++i) { if (tolower(domain_name_[i]) != tolower(other.domain_name_[i])) { return false; } diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/domain_name_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/domain_name_unittest.cc index 48ea272b687..8cf61fd798d 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/domain_name_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/domain_name_unittest.cc @@ -162,7 +162,7 @@ TEST(DomainNameTest, GetLabels) { DomainName domain_name = UnpackErrorOr(FromLabels(labels)); const auto actual_labels = domain_name.GetLabels(); - for (auto i = 0; i < labels.size(); ++i) { + for (size_t i = 0; i < labels.size(); ++i) { EXPECT_EQ(labels[i], actual_labels[i]); } } diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/embedder_demo.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/embedder_demo.cc index d52098df659..2cbbf9e2111 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/embedder_demo.cc +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/embedder_demo.cc @@ -12,7 +12,10 @@ #include "osp/impl/discovery/mdns/mdns_responder_adapter_impl.h" #include "platform/api/logging.h" +#include "platform/api/time.h" #include "platform/base/error.h" +#include "platform/impl/network_runner.h" +#include "platform/impl/task_runner.h" // This file contains a demo of our mDNSResponder wrapper code. It can both // listen for mDNS services and advertise an mDNS service. The command-line @@ -48,6 +51,21 @@ struct Service { std::vector<std::string> txt; }; +class DemoSocketClient : public platform::UdpSocket::Client { + void OnError(platform::UdpSocket* socket, Error error) override { + OSP_UNIMPLEMENTED(); + } + + void OnSendError(platform::UdpSocket* socket, Error error) override { + OSP_UNIMPLEMENTED(); + } + + void OnRead(platform::UdpSocket* socket, + ErrorOr<platform::UdpPacket> packet) override { + OSP_UNIMPLEMENTED(); + } +}; + using ServiceMap = std::map<mdns::DomainName, Service, mdns::DomainNameComparator>; ServiceMap* g_services = nullptr; @@ -96,11 +114,13 @@ void SignalThings() { } std::vector<platform::UdpSocketUniquePtr> SetUpMulticastSockets( - const std::vector<platform::NetworkInterfaceIndex>& index_list) { + platform::TaskRunner* task_runner, + const std::vector<platform::NetworkInterfaceIndex>& index_list, + platform::UdpSocket::Client* client) { std::vector<platform::UdpSocketUniquePtr> sockets; for (const auto ifindex : index_list) { auto create_result = - platform::UdpSocket::Create(platform::UdpSocket::Version::kV4); + platform::UdpSocket::Create(task_runner, client, IPEndpoint{{}, 5353}); if (!create_result) { OSP_LOG_ERROR << "failed to create IPv4 socket for interface " << ifindex << ": " << create_result.error().message(); @@ -123,7 +143,7 @@ std::vector<platform::UdpSocketUniquePtr> SetUpMulticastSockets( continue; } - result = socket->Bind({{}, 5353}); + result = socket->Bind(); if (!result.ok()) { OSP_LOG_ERROR << "bind failed for interface " << ifindex << ": " << result.message(); @@ -234,7 +254,8 @@ void HandleEvents(mdns::MdnsResponderAdapterImpl* mdns_adapter) { } } -void BrowseDemo(const std::string& service_name, +void BrowseDemo(platform::NetworkRunner* network_runner, + const std::string& service_name, const std::string& service_protocol, const std::string& service_instance) { SignalThings(); @@ -250,7 +271,6 @@ void BrowseDemo(const std::string& service_name, } auto mdns_adapter = std::make_unique<mdns::MdnsResponderAdapterImpl>(); - platform::EventWaiterPtr waiter = platform::CreateEventWaiter(); mdns_adapter->Init(); mdns_adapter->SetHostLabel("gigliorononomicon"); auto interface_addresses = platform::GetInterfaceAddresses(); @@ -264,7 +284,8 @@ void BrowseDemo(const std::string& service_name, index_list.push_back(interface.info.index); } - auto sockets = SetUpMulticastSockets(index_list); + DemoSocketClient client; + auto sockets = SetUpMulticastSockets(network_runner, index_list, &client); // The code below assumes the elements in |sockets| is in exact 1:1 // correspondence with the elements in |index_list|. Crash the demo if any // sockets are missing (i.e., failed to be set up). @@ -291,7 +312,7 @@ void BrowseDemo(const std::string& service_name, } for (const platform::UdpSocketUniquePtr& socket : sockets) { - platform::WatchUdpSocketReadable(waiter, socket.get()); + network_runner->ReadRepeatedly(socket.get(), mdns_adapter.get()); mdns_adapter->StartPtrQuery(socket.get(), service_type.value()); } @@ -310,23 +331,15 @@ void BrowseDemo(const std::string& service_name, g_dump_services = false; } mdns_adapter->RunTasks(); - auto data = platform::OnePlatformLoopIteration(waiter); - for (auto& packet : data) { - mdns_adapter->OnDataReceived(packet.source(), packet.destination(), - packet.data(), packet.size(), - packet.socket()); - } } OSP_LOG << "num services: " << g_services->size(); for (const auto& s : *g_services) { LogService(s.second); } - platform::StopWatchingNetworkChange(waiter); for (const platform::UdpSocketUniquePtr& socket : sockets) { - platform::StopWatchingUdpSocketReadable(waiter, socket.get()); + network_runner->CancelRead(socket.get()); mdns_adapter->DeregisterInterface(socket.get()); } - platform::DestroyEventWaiter(waiter); mdns_adapter->Close(); } @@ -354,7 +367,23 @@ int main(int argc, char** argv) { openscreen::ServiceMap services; openscreen::g_services = &services; - openscreen::BrowseDemo(labels[0], labels[1], service_instance); + auto task_runner = std::make_unique<openscreen::platform::TaskRunnerImpl>( + openscreen::platform::Clock::now); + std::thread task_runner_thread( + [&task_runner]() { task_runner->RunUntilStopped(); }); + auto network_runner = + std::make_unique<openscreen::platform::NetworkRunnerImpl>( + std::move(task_runner)); + std::thread network_runner_thread( + [&network_runner]() { network_runner->RunUntilStopped(); }); + + openscreen::BrowseDemo(network_runner.get(), labels[0], labels[1], + service_instance); + + network_runner->RequestStopSoon(); + task_runner->RequestStopSoon(); + network_runner_thread.join(); + task_runner_thread.join(); openscreen::g_services = nullptr; return 0; } diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.h b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.h index 0bf732be623..ee1a2f55e9e 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.h +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter.h @@ -10,13 +10,14 @@ #include <string> #include <vector> +#include "absl/types/optional.h" #include "osp/impl/discovery/mdns/domain_name.h" #include "osp/impl/discovery/mdns/mdns_responder_platform.h" #include "platform/api/network_interface.h" +#include "platform/api/time.h" #include "platform/api/udp_socket.h" #include "platform/base/error.h" #include "platform/base/ip_address.h" -#include "platform/impl/event_loop.h" namespace openscreen { namespace mdns { @@ -164,7 +165,7 @@ enum class MdnsResponderErrorCode { // called after any sequence of calls to mDNSResponder. It also returns a // timeout value, after which it must be called again (e.g. for maintaining its // cache). -class MdnsResponderAdapter { +class MdnsResponderAdapter : public platform::UdpReadCallback { public: MdnsResponderAdapter(); virtual ~MdnsResponderAdapter() = 0; @@ -193,14 +194,9 @@ class MdnsResponderAdapter { platform::UdpSocket* socket) = 0; virtual Error DeregisterInterface(platform::UdpSocket* socket) = 0; - virtual void OnDataReceived(const IPEndpoint& source, - const IPEndpoint& original_destination, - const uint8_t* data, - size_t length, - platform::UdpSocket* receiving_socket) = 0; - - // Returns the number of seconds after which this method must be called again. - virtual int RunTasks() = 0; + // Returns the time period after which this method must be called again, if + // any. + virtual absl::optional<platform::Clock::duration> RunTasks() = 0; virtual std::vector<PtrEvent> TakePtrResponses() = 0; virtual std::vector<SrvEvent> TakeSrvResponses() = 0; diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc index 16eea535050..63033f4c61f 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc @@ -11,6 +11,7 @@ #include <memory> #include "platform/api/logging.h" +#include "platform/api/trace_logging.h" namespace openscreen { namespace mdns { @@ -203,6 +204,7 @@ Error MdnsResponderAdapterImpl::Init() { } void MdnsResponderAdapterImpl::Close() { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::Close"); mDNS_StartExit(&mdns_); // Let all services send goodbyes. while (!service_records_.empty()) { @@ -224,6 +226,7 @@ void MdnsResponderAdapterImpl::Close() { } Error MdnsResponderAdapterImpl::SetHostLabel(const std::string& host_label) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::SetHostLabel"); if (host_label.size() > DomainName::kDomainNameMaxLabelLength) return Error::Code::kDomainNameTooLong; @@ -240,6 +243,8 @@ Error MdnsResponderAdapterImpl::RegisterInterface( const platform::InterfaceInfo& interface_info, const platform::IPSubnet& interface_address, platform::UdpSocket* socket) { + TRACE_SCOPED(TraceCategory::mDNS, + "MdnsResponderAdapterImpl::RegisterInterface"); OSP_DCHECK(socket); const auto info_it = responder_interface_info_.find(socket); @@ -277,9 +282,11 @@ Error MdnsResponderAdapterImpl::RegisterInterface( Error MdnsResponderAdapterImpl::DeregisterInterface( platform::UdpSocket* socket) { + TRACE_SCOPED(TraceCategory::mDNS, + "MdnsResponderAdapterImpl::DeregisterInterface"); const auto info_it = responder_interface_info_.find(socket); if (info_it == responder_interface_info_.end()) - return Error::Code::kNoItemFound; + return Error::Code::kItemNotFound; const auto it = std::find(platform_storage_.sockets.begin(), platform_storage_.sockets.end(), socket); @@ -293,45 +300,70 @@ Error MdnsResponderAdapterImpl::DeregisterInterface( responder_interface_info_.erase(info_it); return Error::None(); } - -void MdnsResponderAdapterImpl::OnDataReceived( - const IPEndpoint& source, - const IPEndpoint& original_destination, - const uint8_t* data, - size_t length, - platform::UdpSocket* receiving_socket) { +void MdnsResponderAdapterImpl::OnRead(platform::UdpPacket packet, + platform::NetworkRunner* network_runner) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::OnRead"); mDNSAddr src; - if (source.address.IsV4()) { + if (packet.source().address.IsV4()) { src.type = mDNSAddrType_IPv4; - source.address.CopyToV4(src.ip.v4.b); + packet.source().address.CopyToV4(src.ip.v4.b); } else { src.type = mDNSAddrType_IPv6; - source.address.CopyToV6(src.ip.v6.b); + packet.source().address.CopyToV6(src.ip.v6.b); } mDNSIPPort srcport; - AssignMdnsPort(&srcport, source.port); + AssignMdnsPort(&srcport, packet.source().port); mDNSAddr dst; - if (source.address.IsV4()) { + if (packet.source().address.IsV4()) { dst.type = mDNSAddrType_IPv4; - original_destination.address.CopyToV4(dst.ip.v4.b); + packet.destination().address.CopyToV4(dst.ip.v4.b); } else { dst.type = mDNSAddrType_IPv6; - original_destination.address.CopyToV6(dst.ip.v6.b); + packet.destination().address.CopyToV6(dst.ip.v6.b); } mDNSIPPort dstport; - AssignMdnsPort(&dstport, original_destination.port); - - mDNSCoreReceive(&mdns_, const_cast<uint8_t*>(data), data + length, &src, - srcport, &dst, dstport, - reinterpret_cast<mDNSInterfaceID>(receiving_socket)); -} - -int MdnsResponderAdapterImpl::RunTasks() { - const auto t = mDNS_Execute(&mdns_); - const auto now = mDNSPlatformRawTime(); - const auto next = t - now; - return next; + AssignMdnsPort(&dstport, packet.destination().port); + + auto* packet_data = packet.data(); + mDNSCoreReceive(&mdns_, const_cast<uint8_t*>(packet_data), + packet_data + packet.size(), &src, srcport, &dst, dstport, + reinterpret_cast<mDNSInterfaceID>(packet.socket())); +} + +absl::optional<platform::Clock::duration> MdnsResponderAdapterImpl::RunTasks() { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::RunTasks"); + + mDNS_Execute(&mdns_); + + // Using mDNS_Execute's response to determine the correct timespan before + // re-running this method doesn't work as expected. In the demo, under some + // cases (about 25% of demo runs), the response is set to an unreasonably + // large number (in the order of multiple days). + // + // From the mDNS documentation: "it is the responsibility [...] to set the + // timer according to the m->NextScheduledEvent value, and then when the timer + // fires, the timer callback function should call mDNS_Execute()" - for more + // details see third_party/mDNSResponder/src/mDNSCore/mDNS.c : 3390 + // + // Together, I understand these to mean that the mdns library code doesn't + // expect we need mDNS_Execute called again by the task runner, only in the + // other special cases it calls out in documentation (which we currently do + // correctly). In our code, when we call mDNS_Execute again outside of the + // task runner, the result is currently discarded. What we would need to do is + // reach into the Task Runner's task and update how long before the task runs + // again. That would require some large refactoring and changes. + // + // Additionally, beyond this, the mDNS code documents that there are cases + // where the return value for mDNS_Execute should be ignored because it may be + // stale. + // + // TODO(rwkeane): More accurately determine when the next run of this method + // should be. + constexpr auto seconds_before_next_run = 1; + + // Return as a duration. + return std::chrono::seconds(seconds_before_next_run); } std::vector<PtrEvent> MdnsResponderAdapterImpl::TakePtrResponses() { @@ -357,6 +389,7 @@ std::vector<AaaaEvent> MdnsResponderAdapterImpl::TakeAaaaResponses() { MdnsResponderErrorCode MdnsResponderAdapterImpl::StartPtrQuery( platform::UdpSocket* socket, const DomainName& service_type) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StartPtrQuery"); auto& ptr_questions = socket_to_questions_[socket].ptr; if (ptr_questions.find(service_type) != ptr_questions.end()) return MdnsResponderErrorCode::kNoError; @@ -403,6 +436,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StartPtrQuery( MdnsResponderErrorCode MdnsResponderAdapterImpl::StartSrvQuery( platform::UdpSocket* socket, const DomainName& service_instance) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StartSrvQuery"); if (!service_instance.EndsWithLocalDomain()) return MdnsResponderErrorCode::kInvalidParameters; @@ -440,6 +474,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StartSrvQuery( MdnsResponderErrorCode MdnsResponderAdapterImpl::StartTxtQuery( platform::UdpSocket* socket, const DomainName& service_instance) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StartTxtQuery"); if (!service_instance.EndsWithLocalDomain()) return MdnsResponderErrorCode::kInvalidParameters; @@ -477,6 +512,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StartTxtQuery( MdnsResponderErrorCode MdnsResponderAdapterImpl::StartAQuery( platform::UdpSocket* socket, const DomainName& domain_name) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StartAQuery"); if (!domain_name.EndsWithLocalDomain()) return MdnsResponderErrorCode::kInvalidParameters; @@ -514,6 +550,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StartAQuery( MdnsResponderErrorCode MdnsResponderAdapterImpl::StartAaaaQuery( platform::UdpSocket* socket, const DomainName& domain_name) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StartAaaaQuery"); if (!domain_name.EndsWithLocalDomain()) return MdnsResponderErrorCode::kInvalidParameters; @@ -551,6 +588,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StartAaaaQuery( MdnsResponderErrorCode MdnsResponderAdapterImpl::StopPtrQuery( platform::UdpSocket* socket, const DomainName& service_type) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StopPtrQuery"); auto interface_entry = socket_to_questions_.find(socket); if (interface_entry == socket_to_questions_.end()) return MdnsResponderErrorCode::kNoError; @@ -568,6 +606,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StopPtrQuery( MdnsResponderErrorCode MdnsResponderAdapterImpl::StopSrvQuery( platform::UdpSocket* socket, const DomainName& service_instance) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StopSrvQuery"); auto interface_entry = socket_to_questions_.find(socket); if (interface_entry == socket_to_questions_.end()) return MdnsResponderErrorCode::kNoError; @@ -585,6 +624,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StopSrvQuery( MdnsResponderErrorCode MdnsResponderAdapterImpl::StopTxtQuery( platform::UdpSocket* socket, const DomainName& service_instance) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StopTxtQuery"); auto interface_entry = socket_to_questions_.find(socket); if (interface_entry == socket_to_questions_.end()) return MdnsResponderErrorCode::kNoError; @@ -602,6 +642,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StopTxtQuery( MdnsResponderErrorCode MdnsResponderAdapterImpl::StopAQuery( platform::UdpSocket* socket, const DomainName& domain_name) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StopAQuery"); auto interface_entry = socket_to_questions_.find(socket); if (interface_entry == socket_to_questions_.end()) return MdnsResponderErrorCode::kNoError; @@ -619,6 +660,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::StopAQuery( MdnsResponderErrorCode MdnsResponderAdapterImpl::StopAaaaQuery( platform::UdpSocket* socket, const DomainName& domain_name) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::StopAaaaQuery"); auto interface_entry = socket_to_questions_.find(socket); if (interface_entry == socket_to_questions_.end()) return MdnsResponderErrorCode::kNoError; @@ -640,6 +682,8 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::RegisterService( const DomainName& target_host, uint16_t target_port, const std::map<std::string, std::string>& txt_data) { + TRACE_SCOPED(TraceCategory::mDNS, + "MdnsResponderAdapterImpl::RegisterService"); OSP_DCHECK(IsValidServiceName(service_name)); OSP_DCHECK(IsValidServiceProtocol(service_protocol)); service_records_.push_back(std::make_unique<ServiceRecordSet>()); @@ -683,6 +727,8 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::DeregisterService( const std::string& service_instance, const std::string& service_name, const std::string& service_protocol) { + TRACE_SCOPED(TraceCategory::mDNS, + "MdnsResponderAdapterImpl::DeregisterService"); domainlabel instance; domainlabel name; domainlabel protocol; @@ -711,6 +757,7 @@ MdnsResponderErrorCode MdnsResponderAdapterImpl::UpdateTxtData( const std::string& service_name, const std::string& service_protocol, const std::map<std::string, std::string>& txt_data) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::UpdateTxtData"); domainlabel instance; domainlabel name; domainlabel protocol; @@ -744,6 +791,7 @@ void MdnsResponderAdapterImpl::AQueryCallback(mDNS* m, DNSQuestion* question, const ResourceRecord* answer, QC_result added) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderAdapterImpl::AQueryCallback"); OSP_DCHECK(question); OSP_DCHECK(answer); OSP_DCHECK_EQ(answer->rrtype, kDNSType_A); @@ -774,6 +822,8 @@ void MdnsResponderAdapterImpl::AaaaQueryCallback(mDNS* m, DNSQuestion* question, const ResourceRecord* answer, QC_result added) { + TRACE_SCOPED(TraceCategory::mDNS, + "MdnsResponderAdapterImpl::AaaaQueryCallback"); OSP_DCHECK(question); OSP_DCHECK(answer); OSP_DCHECK_EQ(answer->rrtype, kDNSType_A); @@ -804,6 +854,8 @@ void MdnsResponderAdapterImpl::PtrQueryCallback(mDNS* m, DNSQuestion* question, const ResourceRecord* answer, QC_result added) { + TRACE_SCOPED(TraceCategory::mDNS, + "MdnsResponderAdapterImpl::PtrQueryCallback"); OSP_DCHECK(question); OSP_DCHECK(answer); OSP_DCHECK_EQ(answer->rrtype, kDNSType_PTR); @@ -833,6 +885,8 @@ void MdnsResponderAdapterImpl::SrvQueryCallback(mDNS* m, DNSQuestion* question, const ResourceRecord* answer, QC_result added) { + TRACE_SCOPED(TraceCategory::mDNS, + "MdnsResponderAdapterImpl::SrvQueryCallback"); OSP_DCHECK(question); OSP_DCHECK(answer); OSP_DCHECK_EQ(answer->rrtype, kDNSType_SRV); @@ -916,6 +970,8 @@ void MdnsResponderAdapterImpl::ServiceCallback(mDNS* m, } void MdnsResponderAdapterImpl::AdvertiseInterfaces() { + TRACE_SCOPED(TraceCategory::mDNS, + "MdnsResponderAdapterImpl::AdvertiseInterfaces"); for (auto& info : responder_interface_info_) { platform::UdpSocket* socket = info.first; NetworkInterfaceInfo& interface_info = info.second; diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h index ea874ee220d..cddb137fc5b 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h @@ -34,12 +34,9 @@ class MdnsResponderAdapterImpl final : public MdnsResponderAdapter { platform::UdpSocket* socket) override; Error DeregisterInterface(platform::UdpSocket* socket) override; - void OnDataReceived(const IPEndpoint& source, - const IPEndpoint& original_destination, - const uint8_t* data, - size_t length, - platform::UdpSocket* receiving_socket) override; - int RunTasks() override; + void OnRead(platform::UdpPacket packet, + platform::NetworkRunner* network_runner) override; + absl::optional<platform::Clock::duration> RunTasks() override; std::vector<PtrEvent> TakePtrResponses() override; std::vector<SrvEvent> TakeSrvResponses() override; diff --git a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc index e200c204315..0f3508b7cee 100644 --- a/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc @@ -48,12 +48,18 @@ TEST(MdnsResponderAdapterImplTest, ExampleData) { 4, '_', 'u', 'd', 'p', 5, 'l', 'o', 'c', 'a', 'l', 0}}; const IPEndpoint mdns_endpoint{{224, 0, 0, 251}, 5353}; + platform::UdpPacket packet; + packet.set_source({{192, 168, 0, 2}, 6556}); + packet.set_destination(mdns_endpoint); + packet.reserve(sizeof(data)); + std::copy(std::begin(data), std::end(data), back_inserter(packet)); + packet.set_socket(nullptr); + auto mdns_adapter = std::unique_ptr<mdns::MdnsResponderAdapter>( new mdns::MdnsResponderAdapterImpl); mdns_adapter->Init(); mdns_adapter->StartPtrQuery(0, openscreen_service); - mdns_adapter->OnDataReceived({{192, 168, 0, 2}, 6556}, mdns_endpoint, data, - sizeof(data), 0); + mdns_adapter->OnRead(std::move(packet), nullptr); mdns_adapter->RunTasks(); auto ptr = mdns_adapter->TakePtrResponses(); diff --git a/chromium/third_party/openscreen/src/osp/impl/internal_services.cc b/chromium/third_party/openscreen/src/osp/impl/internal_services.cc index c948c403904..61bbffa9960 100644 --- a/chromium/third_party/openscreen/src/osp/impl/internal_services.cc +++ b/chromium/third_party/openscreen/src/osp/impl/internal_services.cc @@ -55,7 +55,7 @@ Error SetUpMulticastSocket(platform::UdpSocket* socket, return result; } - result = socket->Bind({{}, kMulticastListeningPort}); + result = socket->Bind(); if (!result.ok()) { OSP_LOG_ERROR << "bind failed for interface " << ifindex << ": " << result.message(); @@ -73,17 +73,11 @@ int g_instance_ref_count = 0; } // namespace // static -void InternalServices::RunEventLoopOnce() { - OSP_CHECK(g_instance) << "No listener or publisher is alive."; - g_instance->mdns_service_.HandleNewEvents( - platform::OnePlatformLoopIteration(g_instance->mdns_waiter_)); -} - -// static std::unique_ptr<ServiceListener> InternalServices::CreateListener( const MdnsServiceListenerConfig& config, - ServiceListener::Observer* observer) { - auto* services = ReferenceSingleton(); + ServiceListener::Observer* observer, + platform::NetworkRunner* network_runner) { + auto* services = ReferenceSingleton(network_runner); auto listener = std::make_unique<ServiceListenerImpl>(&services->mdns_service_); listener->AddObserver(observer); @@ -95,8 +89,9 @@ std::unique_ptr<ServiceListener> InternalServices::CreateListener( // static std::unique_ptr<ServicePublisher> InternalServices::CreatePublisher( const ServicePublisher::Config& config, - ServicePublisher::Observer* observer) { - auto* services = ReferenceSingleton(); + ServicePublisher::Observer* observer, + platform::NetworkRunner* network_runner) { + auto* services = ReferenceSingleton(network_runner); services->mdns_service_.SetServiceConfig( config.hostname, config.service_instance_name, config.connection_server_port, config.network_interface_indices, @@ -153,7 +148,8 @@ InternalServices::InternalPlatformLinkage::RegisterInterfaces( const platform::IPSubnet& primary_subnet = addr.addresses.front(); auto create_result = - platform::UdpSocket::Create(primary_subnet.address.version()); + platform::UdpSocket::Create(parent_->network_runner_, parent_, + IPEndpoint{{}, kMulticastListeningPort}); if (!create_result) { OSP_LOG_ERROR << "failed to create socket for interface " << index << ": " << create_result.error().message(); @@ -165,6 +161,7 @@ InternalServices::InternalPlatformLinkage::RegisterInterfaces( } result.emplace_back(addr.info, primary_subnet, socket.get()); parent_->RegisterMdnsSocket(socket.get()); + open_sockets_.emplace_back(std::move(socket)); } @@ -187,32 +184,31 @@ void InternalServices::InternalPlatformLinkage::DeregisterInterfaces( } } -InternalServices::InternalServices() - : mdns_service_(kServiceName, +InternalServices::InternalServices(platform::NetworkRunner* network_runner) + : mdns_service_(network_runner, + kServiceName, kServiceProtocol, std::make_unique<MdnsResponderAdapterImplFactory>(), std::make_unique<InternalPlatformLinkage>(this)), - mdns_waiter_(platform::CreateEventWaiter()) { - OSP_DCHECK(mdns_waiter_); -} + network_runner_(network_runner) {} -InternalServices::~InternalServices() { - DestroyEventWaiter(mdns_waiter_); -} +InternalServices::~InternalServices() = default; void InternalServices::RegisterMdnsSocket(platform::UdpSocket* socket) { - platform::WatchUdpSocketReadable(mdns_waiter_, socket); + OSP_CHECK(g_instance) << "No listener or publisher is alive."; + network_runner_->ReadRepeatedly(socket, &g_instance->mdns_service_); } void InternalServices::DeregisterMdnsSocket(platform::UdpSocket* socket) { - platform::StopWatchingUdpSocketReadable(mdns_waiter_, socket); + network_runner_->CancelRead(socket); } // static -InternalServices* InternalServices::ReferenceSingleton() { +InternalServices* InternalServices::ReferenceSingleton( + platform::NetworkRunner* network_runner) { if (!g_instance) { OSP_CHECK_EQ(g_instance_ref_count, 0); - g_instance = new InternalServices(); + g_instance = new InternalServices(network_runner); } ++g_instance_ref_count; return g_instance; @@ -229,4 +225,17 @@ void InternalServices::DereferenceSingleton(void* instance) { } } +void InternalServices::OnError(platform::UdpSocket* socket, Error error) { + OSP_UNIMPLEMENTED(); +} + +void InternalServices::OnSendError(platform::UdpSocket* socket, Error error) { + OSP_UNIMPLEMENTED(); +} + +void InternalServices::OnRead(platform::UdpSocket* socket, + ErrorOr<platform::UdpPacket> packet) { + OSP_UNIMPLEMENTED(); +} + } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/impl/internal_services.h b/chromium/third_party/openscreen/src/osp/impl/internal_services.h index 82df8404e7f..0c50c2fd557 100644 --- a/chromium/third_party/openscreen/src/osp/impl/internal_services.h +++ b/chromium/third_party/openscreen/src/osp/impl/internal_services.h @@ -17,30 +17,37 @@ #include "osp/public/mdns_service_publisher_factory.h" #include "osp/public/protocol_connection_client.h" #include "osp/public/protocol_connection_server.h" -#include "platform/api/event_waiter.h" #include "platform/api/network_interface.h" #include "platform/api/udp_socket.h" #include "platform/base/ip_address.h" #include "platform/base/macros.h" -#include "platform/impl/event_loop.h" namespace openscreen { +namespace platform { +class NetworkRunner; +} // namespace platform // Factory for ServiceListener and ServicePublisher instances; owns internal // objects needed to instantiate them such as MdnsResponderService and runs an // event loop. // TODO(btolsch): This may be renamed and/or split up once QUIC code lands and // this use case is more concrete. -class InternalServices { +class InternalServices : platform::UdpSocket::Client { public: - static void RunEventLoopOnce(); - static std::unique_ptr<ServiceListener> CreateListener( const MdnsServiceListenerConfig& config, - ServiceListener::Observer* observer); + ServiceListener::Observer* observer, + platform::NetworkRunner* network_runner); static std::unique_ptr<ServicePublisher> CreatePublisher( const ServicePublisher::Config& config, - ServicePublisher::Observer* observer); + ServicePublisher::Observer* observer, + platform::NetworkRunner* network_runner); + + // UdpSocket::Client overrides. + void OnError(platform::UdpSocket* socket, Error error) override; + void OnSendError(platform::UdpSocket* socket, Error error) override; + void OnRead(platform::UdpSocket* socket, + ErrorOr<platform::UdpPacket> packet) override; private: class InternalPlatformLinkage final : public MdnsPlatformService { @@ -58,24 +65,22 @@ class InternalServices { std::vector<platform::UdpSocketUniquePtr> open_sockets_; }; - InternalServices(); - ~InternalServices(); + // Creates a new InsternalServices instance using the provided NetworkRunner. + // The NetworkRunner should live for the duration of this InternalService + // object's lifetime. + explicit InternalServices(platform::NetworkRunner* network_runner); + ~InternalServices() override; void RegisterMdnsSocket(platform::UdpSocket* socket); void DeregisterMdnsSocket(platform::UdpSocket* socket); - static InternalServices* ReferenceSingleton(); + static InternalServices* ReferenceSingleton( + platform::NetworkRunner* network_runner); static void DereferenceSingleton(void* instance); MdnsResponderService mdns_service_; - // TODO(btolsch): To support e.g. both QUIC and mDNS listening for separate - // sockets, we need to either: - // - give them their own individual waiter objects - // - remember who registered for what in a wrapper here - // - something else... - // Currently, RegisterMdnsSocket is our hook to do 1 or 2. - platform::EventWaiterPtr mdns_waiter_; + platform::NetworkRunner* const network_runner_; OSP_DISALLOW_COPY_AND_ASSIGN(InternalServices); }; diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.h b/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.h index 6014c19b54a..be5efcee30b 100644 --- a/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.h +++ b/chromium/third_party/openscreen/src/osp/impl/mdns_platform_service.h @@ -7,10 +7,8 @@ #include <vector> -#include "platform/api/event_waiter.h" #include "platform/api/network_interface.h" #include "platform/api/udp_socket.h" -#include "platform/impl/event_loop.h" namespace openscreen { diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.cc b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.cc index 3599c508df3..260c1e08679 100644 --- a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.cc +++ b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.cc @@ -10,6 +10,7 @@ #include "osp/impl/internal_services.h" #include "platform/api/logging.h" +#include "platform/api/trace_logging.h" #include "platform/base/error.h" namespace openscreen { @@ -29,13 +30,15 @@ std::string ServiceIdFromServiceInstanceName( } // namespace MdnsResponderService::MdnsResponderService( + platform::NetworkRunner* network_runner, const std::string& service_name, const std::string& service_protocol, std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory, std::unique_ptr<MdnsPlatformService> platform) : service_type_{{service_name, service_protocol}}, mdns_responder_factory_(std::move(mdns_responder_factory)), - platform_(std::move(platform)) {} + platform_(std::move(platform)), + network_runner_(network_runner) {} MdnsResponderService::~MdnsResponderService() = default; @@ -55,34 +58,84 @@ void MdnsResponderService::SetServiceConfig( service_txt_data_ = txt_data; } -void MdnsResponderService::HandleNewEvents( - const std::vector<platform::UdpPacket>& packets) { - if (!mdns_responder_) +void MdnsResponderService::OnRead(platform::UdpPacket packet, + platform::NetworkRunner* network_runner) { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderService::OnRead"); + if (!mdns_responder_) { return; - for (auto& packet : packets) { - mdns_responder_->OnDataReceived(packet.source(), packet.destination(), - packet.data(), packet.size(), - packet.socket()); } - mdns_responder_->RunTasks(); + mdns_responder_->OnRead(std::move(packet), network_runner); HandleMdnsEvents(); } void MdnsResponderService::StartListener() { - if (!mdns_responder_) + network_runner_->PostTask([this]() { this->StartListenerInternal(); }); +} + +void MdnsResponderService::StartAndSuspendListener() { + network_runner_->PostTask( + [this]() { this->StartAndSuspendListenerInternal(); }); +} + +void MdnsResponderService::StopListener() { + network_runner_->PostTask([this]() { this->StopListenerInternal(); }); +} + +void MdnsResponderService::SuspendListener() { + network_runner_->PostTask([this]() { this->SuspendListenerInternal(); }); +} + +void MdnsResponderService::ResumeListener() { + network_runner_->PostTask([this]() { this->ResumeListenerInternal(); }); +} + +void MdnsResponderService::SearchNow(ServiceListener::State from) { + network_runner_->PostTask([this, from]() { this->SearchNowInternal(from); }); +} + +void MdnsResponderService::StartPublisher() { + network_runner_->PostTask([this]() { this->StartPublisherInternal(); }); +} + +void MdnsResponderService::StartAndSuspendPublisher() { + network_runner_->PostTask( + [this]() { this->StartAndSuspendPublisherInternal(); }); +} + +void MdnsResponderService::StopPublisher() { + network_runner_->PostTask([this]() { this->StopPublisherInternal(); }); +} + +void MdnsResponderService::SuspendPublisher() { + network_runner_->PostTask([this]() { this->SuspendPublisherInternal(); }); +} + +void MdnsResponderService::ResumePublisher() { + network_runner_->PostTask([this]() { this->ResumePublisherInternal(); }); +} + +void MdnsResponderService::StartListenerInternal() { + if (!mdns_responder_) { mdns_responder_ = mdns_responder_factory_->Create(); + } StartListening(); ServiceListenerImpl::Delegate::SetState(ServiceListener::State::kRunning); + // TODO(rwkeane): Use new Alarm class instead once owning CL is merged in. + // Then it can be more effectively cancelled when the state changes away from + // 'running'. + platform::RepeatingFunction::Post( + network_runner_, + std::bind(&mdns::MdnsResponderAdapter::RunTasks, mdns_responder_.get())); } -void MdnsResponderService::StartAndSuspendListener() { +void MdnsResponderService::StartAndSuspendListenerInternal() { mdns_responder_ = mdns_responder_factory_->Create(); ServiceListenerImpl::Delegate::SetState(ServiceListener::State::kSuspended); } -void MdnsResponderService::StopListener() { +void MdnsResponderService::StopListenerInternal() { StopListening(); if (!publisher_ || publisher_->state() == ServicePublisher::State::kStopped || publisher_->state() == ServicePublisher::State::kSuspended) { @@ -93,38 +146,37 @@ void MdnsResponderService::StopListener() { ServiceListenerImpl::Delegate::SetState(ServiceListener::State::kStopped); } -void MdnsResponderService::SuspendListener() { +void MdnsResponderService::SuspendListenerInternal() { StopMdnsResponder(); ServiceListenerImpl::Delegate::SetState(ServiceListener::State::kSuspended); } -void MdnsResponderService::ResumeListener() { +void MdnsResponderService::ResumeListenerInternal() { StartListening(); ServiceListenerImpl::Delegate::SetState(ServiceListener::State::kRunning); } -void MdnsResponderService::SearchNow(ServiceListener::State from) { +void MdnsResponderService::SearchNowInternal(ServiceListener::State from) { ServiceListenerImpl::Delegate::SetState(from); } -void MdnsResponderService::RunTasksListener() { - InternalServices::RunEventLoopOnce(); -} - -void MdnsResponderService::StartPublisher() { +void MdnsResponderService::StartPublisherInternal() { if (!mdns_responder_) mdns_responder_ = mdns_responder_factory_->Create(); StartService(); ServicePublisherImpl::Delegate::SetState(ServicePublisher::State::kRunning); + platform::RepeatingFunction::Post( + network_runner_, + std::bind(&mdns::MdnsResponderAdapter::RunTasks, mdns_responder_.get())); } -void MdnsResponderService::StartAndSuspendPublisher() { +void MdnsResponderService::StartAndSuspendPublisherInternal() { mdns_responder_ = mdns_responder_factory_->Create(); ServicePublisherImpl::Delegate::SetState(ServicePublisher::State::kSuspended); } -void MdnsResponderService::StopPublisher() { +void MdnsResponderService::StopPublisherInternal() { StopService(); if (!listener_ || listener_->state() == ServiceListener::State::kStopped || listener_->state() == ServiceListener::State::kSuspended) { @@ -135,20 +187,16 @@ void MdnsResponderService::StopPublisher() { ServicePublisherImpl::Delegate::SetState(ServicePublisher::State::kStopped); } -void MdnsResponderService::SuspendPublisher() { +void MdnsResponderService::SuspendPublisherInternal() { StopService(); ServicePublisherImpl::Delegate::SetState(ServicePublisher::State::kSuspended); } -void MdnsResponderService::ResumePublisher() { +void MdnsResponderService::ResumePublisherInternal() { StartService(); ServicePublisherImpl::Delegate::SetState(ServicePublisher::State::kRunning); } -void MdnsResponderService::RunTasksPublisher() { - InternalServices::RunEventLoopOnce(); -} - bool MdnsResponderService::NetworkScopedDomainNameComparator::operator()( const NetworkScopedDomainName& a, const NetworkScopedDomainName& b) const { @@ -159,6 +207,7 @@ bool MdnsResponderService::NetworkScopedDomainNameComparator::operator()( } void MdnsResponderService::HandleMdnsEvents() { + TRACE_SCOPED(TraceCategory::mDNS, "MdnsResponderService::HandleMdnsEvents"); // NOTE: In the common case, we will get a single combined packet for // PTR/SRV/TXT/A and then no other packets. If we don't loop here, we would // start SRV/TXT queries based on the PTR response, but never check for events @@ -193,8 +242,11 @@ void MdnsResponderService::HandleMdnsEvents() { events_possible = HandleAaaaEvent(aaaa_event, &modified_instance_names) || events_possible; } - if (events_possible) + if (events_possible) { + // NOTE: This still needs to be called here, even though it runs in the + // background regularly, because we just finished processing MDNS events. mdns_responder_->RunTasks(); + } } while (events_possible); for (const auto& instance_name : modified_instance_names) { @@ -256,8 +308,9 @@ void MdnsResponderService::StartListening() { ErrorOr<mdns::DomainName> service_type = mdns::DomainName::FromLabels(service_type_.begin(), service_type_.end()); OSP_CHECK(service_type); - for (const auto& interface : bound_interfaces_) + for (const auto& interface : bound_interfaces_) { mdns_responder_->StartPtrQuery(interface.socket, service_type.value()); + } } void MdnsResponderService::StopListening() { @@ -279,8 +332,9 @@ void MdnsResponderService::StopListening() { mdns_responder_->StopTxtQuery(socket, service.first); } service_by_name_.clear(); - for (const auto& interface : bound_interfaces_) + for (const auto& interface : bound_interfaces_) { mdns_responder_->StopPtrQuery(interface.socket, service_type.value()); + } RemoveAllReceivers(); } @@ -315,6 +369,7 @@ void MdnsResponderService::StartService() { interface.subnet, interface.socket); } } + ErrorOr<mdns::DomainName> domain_name = mdns::DomainName::FromLabels(&service_hostname_, &service_hostname_ + 1); OSP_CHECK(domain_name) << "bad hostname configured: " << service_hostname_; diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.h b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.h index d9d8a62e4a2..bcb24e7aac4 100644 --- a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.h +++ b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service.h @@ -17,8 +17,8 @@ #include "osp/impl/service_listener_impl.h" #include "osp/impl/service_publisher_impl.h" #include "platform/api/network_interface.h" +#include "platform/api/network_runner.h" #include "platform/base/ip_address.h" -#include "platform/impl/event_loop.h" namespace openscreen { @@ -29,15 +29,17 @@ class MdnsResponderAdapterFactory { virtual std::unique_ptr<mdns::MdnsResponderAdapter> Create() = 0; }; -class MdnsResponderService final : public ServiceListenerImpl::Delegate, - public ServicePublisherImpl::Delegate { +class MdnsResponderService : public ServiceListenerImpl::Delegate, + public ServicePublisherImpl::Delegate, + public platform::UdpReadCallback { public: - explicit MdnsResponderService( + MdnsResponderService( + platform::NetworkRunner* network_runner, const std::string& service_name, const std::string& service_protocol, std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory, std::unique_ptr<MdnsPlatformService> platform); - ~MdnsResponderService() override; + virtual ~MdnsResponderService() override; void SetServiceConfig( const std::string& hostname, @@ -46,7 +48,9 @@ class MdnsResponderService final : public ServiceListenerImpl::Delegate, const std::vector<platform::NetworkInterfaceIndex> whitelist, const std::map<std::string, std::string>& txt_data); - void HandleNewEvents(const std::vector<platform::UdpPacket>& packets); + // UdpReadCallback overrides. + void OnRead(platform::UdpPacket packet, + platform::NetworkRunner* network_runner) override; // ServiceListenerImpl::Delegate overrides. void StartListener() override; @@ -55,7 +59,6 @@ class MdnsResponderService final : public ServiceListenerImpl::Delegate, void SuspendListener() override; void ResumeListener() override; void SearchNow(ServiceListener::State from) override; - void RunTasksListener() override; // ServicePublisherImpl::Delegate overrides. void StartPublisher() override; @@ -63,9 +66,30 @@ class MdnsResponderService final : public ServiceListenerImpl::Delegate, void StopPublisher() override; void SuspendPublisher() override; void ResumePublisher() override; - void RunTasksPublisher() override; + + protected: + void HandleMdnsEvents(); + + std::unique_ptr<mdns::MdnsResponderAdapter> mdns_responder_; private: + // Create internal versions of all public methods. These are used to push all + // calls to these methods to the task runner. + // TODO(rwkeane): Clean up these methods. Some result in multiple pushes to + // the task runner when just one would suffice. + // ServiceListenerImpl::Delegate overrides. + void StartListenerInternal(); + void StartAndSuspendListenerInternal(); + void StopListenerInternal(); + void SuspendListenerInternal(); + void ResumeListenerInternal(); + void SearchNowInternal(ServiceListener::State from); + void StartPublisherInternal(); + void StartAndSuspendPublisherInternal(); + void StopPublisherInternal(); + void SuspendPublisherInternal(); + void ResumePublisherInternal(); + // NOTE: service_instance implicit in map key. struct ServiceInstance { platform::UdpSocket* ptr_socket = nullptr; @@ -98,7 +122,6 @@ class MdnsResponderService final : public ServiceListenerImpl::Delegate, using InstanceNameSet = std::set<mdns::DomainName, mdns::DomainNameComparator>; - void HandleMdnsEvents(); void StartListening(); void StopListening(); void StartService(); @@ -148,7 +171,6 @@ class MdnsResponderService final : public ServiceListenerImpl::Delegate, std::map<std::string, std::string> service_txt_data_; std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory_; - std::unique_ptr<mdns::MdnsResponderAdapter> mdns_responder_; std::unique_ptr<MdnsPlatformService> platform_; std::vector<MdnsPlatformService::BoundInterface> bound_interfaces_; @@ -170,6 +192,10 @@ class MdnsResponderService final : public ServiceListenerImpl::Delegate, network_scoped_domain_to_host_; std::map<std::string, ServiceInfo> receiver_info_; + + platform::NetworkRunner* network_runner_; + + friend class TestingMdnsResponderService; }; } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc index cebf6890676..867967be2b2 100644 --- a/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/mdns_responder_service_unittest.cc @@ -5,6 +5,7 @@ #include "osp/impl/mdns_responder_service.h" #include <cstdint> +#include <iostream> #include <memory> #include "gmock/gmock.h" @@ -12,17 +13,62 @@ #include "osp/impl/service_listener_impl.h" #include "osp/impl/testing/fake_mdns_platform_service.h" #include "osp/impl/testing/fake_mdns_responder_adapter.h" +#include "platform/test/fake_network_runner.h" namespace openscreen { -namespace { -using ::testing::_; +// Child of the MdnsResponderService for testing purposes. Only difference +// betweeen this and the base class is that methods on this class are executed +// synchronously, rather than pushed to the task runner for later execution. +class TestingMdnsResponderService final : public MdnsResponderService { + public: + TestingMdnsResponderService( + platform::FakeNetworkRunner* network_runner, + const std::string& service_name, + const std::string& service_protocol, + std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory, + std::unique_ptr<MdnsPlatformService> platform_service) + : MdnsResponderService(network_runner, + service_name, + service_protocol, + std::move(mdns_responder_factory), + std::move(platform_service)) {} + ~TestingMdnsResponderService() = default; + + // Override the default ServiceListenerImpl and ServicePublisherImpl + // implementations. These call the internal implementations of each of the + // methods provided, meaning that the end result of the call is the same, but + // without pushing to the task runner and waiting for it to be pulled off + // again. + // ServiceListenerImpl::Delegate overrides. + void StartListener() override { StartListenerInternal(); } + void StartAndSuspendListener() override { StartAndSuspendListenerInternal(); } + void StopListener() override { StopListenerInternal(); } + void SuspendListener() override { SuspendListenerInternal(); } + void ResumeListener() override { ResumeListenerInternal(); } + void SearchNow(ServiceListener::State from) override { + SearchNowInternal(from); + } -constexpr char kTestServiceInstance[] = "turtle"; -constexpr char kTestServiceName[] = "_foo"; -constexpr char kTestServiceProtocol[] = "_udp"; -constexpr char kTestHostname[] = "hostname"; -constexpr uint16_t kTestPort = 12345; + // ServicePublisherImpl::Delegate overrides. + void StartPublisher() override { StartPublisherInternal(); } + void StartAndSuspendPublisher() override { + StartAndSuspendPublisherInternal(); + } + void StopPublisher() override { StopPublisherInternal(); } + void SuspendPublisher() override { SuspendPublisherInternal(); } + void ResumePublisher() override { ResumePublisherInternal(); } + + // Handles new events as OnRead does, but without the need of a NetworkRunner. + void HandleNewEvents() { + if (!mdns_responder_) { + return; + } + + mdns_responder_->RunTasks(); + HandleMdnsEvents(); + } +}; class FakeMdnsResponderAdapterFactory final : public MdnsResponderAdapterFactory, @@ -62,6 +108,39 @@ class FakeMdnsResponderAdapterFactory final size_t last_registered_services_size_ = 0; }; +namespace { + +using ::testing::_; + +constexpr char kTestServiceInstance[] = "turtle"; +constexpr char kTestServiceName[] = "_foo"; +constexpr char kTestServiceProtocol[] = "_udp"; +constexpr char kTestHostname[] = "hostname"; +constexpr uint16_t kTestPort = 12345; + +// Wrapper around the above class. In MdnsResponderServiceTest, we need to both +// pass a unique_ptr to the created MdnsResponderService and to maintain a +// local pointer as well. Doing this with the same object causes a race +// condition, where ~FakeMdnsResponderAdapter() calls observer_->OnDestroyed() +// after the object is already deleted, resulting in a seg fault. This is to +// prevent that race condition. +class WrapperMdnsResponderAdapterFactory final + : public MdnsResponderAdapterFactory, + public FakeMdnsResponderAdapter::LifetimeObserver { + public: + WrapperMdnsResponderAdapterFactory(FakeMdnsResponderAdapterFactory* ptr) + : other_(ptr) {} + + std::unique_ptr<mdns::MdnsResponderAdapter> Create() override { + return other_->Create(); + } + + void OnDestroyed() override { other_->OnDestroyed(); } + + private: + FakeMdnsResponderAdapterFactory* other_; +}; + class MockServiceListenerObserver final : public ServiceListener::Observer { public: ~MockServiceListenerObserver() override = default; @@ -99,15 +178,17 @@ platform::UdpSocket* const kSecondSocket = class MdnsResponderServiceTest : public ::testing::Test { protected: void SetUp() override { - auto mdns_responder_factory = + mdns_responder_factory_ = std::make_unique<FakeMdnsResponderAdapterFactory>(); - mdns_responder_factory_ = mdns_responder_factory.get(); + auto wrapper_factory = std::make_unique<WrapperMdnsResponderAdapterFactory>( + mdns_responder_factory_.get()); + network_runner_ = std::make_unique<platform::FakeNetworkRunner>(); auto platform_service = std::make_unique<FakeMdnsPlatformService>(); fake_platform_service_ = platform_service.get(); fake_platform_service_->set_interfaces(bound_interfaces_); - mdns_service_ = std::make_unique<MdnsResponderService>( - kTestServiceName, kTestServiceProtocol, - std::move(mdns_responder_factory), std::move(platform_service)); + mdns_service_ = std::make_unique<TestingMdnsResponderService>( + network_runner_.get(), kTestServiceName, kTestServiceProtocol, + std::move(wrapper_factory), std::move(platform_service)); service_listener_ = std::make_unique<ServiceListenerImpl>(mdns_service_.get()); service_listener_->AddObserver(&observer_); @@ -118,10 +199,11 @@ class MdnsResponderServiceTest : public ::testing::Test { &publisher_observer_, mdns_service_.get()); } + std::unique_ptr<platform::FakeNetworkRunner> network_runner_; MockServiceListenerObserver observer_; FakeMdnsPlatformService* fake_platform_service_; - FakeMdnsResponderAdapterFactory* mdns_responder_factory_; - std::unique_ptr<MdnsResponderService> mdns_service_; + std::unique_ptr<FakeMdnsResponderAdapterFactory> mdns_responder_factory_; + std::unique_ptr<TestingMdnsResponderService> mdns_service_; std::unique_ptr<ServiceListenerImpl> service_listener_; MockServicePublisherObserver publisher_observer_; std::unique_ptr<ServicePublisherImpl> service_publisher_; @@ -162,7 +244,7 @@ TEST_F(MdnsResponderServiceTest, BasicServiceStates) { EXPECT_EQ((IPEndpoint{{192, 168, 3, 7}, kTestPort}), info.v4_endpoint); EXPECT_FALSE(info.v6_endpoint.address); })); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); mdns_responder->AddAEvent(MakeAEvent( "gigliorononomicon", IPAddress{192, 168, 3, 8}, kDefaultSocket)); @@ -174,7 +256,7 @@ TEST_F(MdnsResponderServiceTest, BasicServiceStates) { EXPECT_EQ((IPEndpoint{{192, 168, 3, 8}, kTestPort}), info.v4_endpoint); EXPECT_FALSE(info.v6_endpoint.address); })); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); auto ptr_remove = MakePtrEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, kDefaultSocket); @@ -185,7 +267,7 @@ TEST_F(MdnsResponderServiceTest, BasicServiceStates) { .WillOnce(::testing::Invoke([&service_id](const ServiceInfo& info) { EXPECT_EQ(service_id, info.service_id); })); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); } TEST_F(MdnsResponderServiceTest, NetworkNetworkInterfaceIndex) { @@ -211,7 +293,7 @@ TEST_F(MdnsResponderServiceTest, NetworkNetworkInterfaceIndex) { .WillOnce(::testing::Invoke([](const ServiceInfo& info) { EXPECT_EQ(2, info.network_interface_index); })); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); } TEST_F(MdnsResponderServiceTest, SimultaneousFieldChanges) { @@ -228,7 +310,7 @@ TEST_F(MdnsResponderServiceTest, SimultaneousFieldChanges) { kDefaultSocket); EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); mdns_responder->AddSrvEvent( MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, @@ -246,7 +328,7 @@ TEST_F(MdnsResponderServiceTest, SimultaneousFieldChanges) { EXPECT_EQ(54321, info.v4_endpoint.port); EXPECT_FALSE(info.v6_endpoint.address); })); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); } TEST_F(MdnsResponderServiceTest, SimultaneousHostAndAddressChange) { @@ -263,7 +345,7 @@ TEST_F(MdnsResponderServiceTest, SimultaneousHostAndAddressChange) { kDefaultSocket); EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); auto srv_remove = MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, @@ -283,7 +365,7 @@ TEST_F(MdnsResponderServiceTest, SimultaneousHostAndAddressChange) { EXPECT_EQ((IPAddress{192, 168, 3, 10}), info.v4_endpoint.address); EXPECT_FALSE(info.v6_endpoint.address); })); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); } TEST_F(MdnsResponderServiceTest, ListenerStateTransitions) { @@ -530,7 +612,7 @@ TEST_F(MdnsResponderServiceTest, AddressQueryStopped) { kDefaultSocket); EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); auto srv_remove = MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, @@ -539,7 +621,7 @@ TEST_F(MdnsResponderServiceTest, AddressQueryStopped) { mdns_responder->AddSrvEvent(std::move(srv_remove)); EXPECT_CALL(observer_, OnReceiverRemoved(_)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); EXPECT_FALSE(mdns_responder->ptr_queries_empty()); EXPECT_FALSE(mdns_responder->srv_queries_empty()); @@ -564,7 +646,7 @@ TEST_F(MdnsResponderServiceTest, AddressQueryRefCount) { IPAddress{192, 168, 3, 7}, kDefaultSocket); EXPECT_CALL(observer_, OnReceiverAdded(_)).Times(2); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); auto srv_remove = MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, @@ -573,7 +655,7 @@ TEST_F(MdnsResponderServiceTest, AddressQueryRefCount) { mdns_responder->AddSrvEvent(std::move(srv_remove)); EXPECT_CALL(observer_, OnReceiverRemoved(_)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); EXPECT_FALSE(mdns_responder->ptr_queries_empty()); EXPECT_FALSE(mdns_responder->srv_queries_empty()); @@ -588,7 +670,7 @@ TEST_F(MdnsResponderServiceTest, AddressQueryRefCount) { mdns_responder->AddSrvEvent(std::move(srv_remove)); EXPECT_CALL(observer_, OnReceiverRemoved(_)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); EXPECT_FALSE(mdns_responder->ptr_queries_empty()); EXPECT_FALSE(mdns_responder->srv_queries_empty()); @@ -609,7 +691,7 @@ TEST_F(MdnsResponderServiceTest, ServiceQueriesStoppedSrvFirst) { kDefaultSocket); EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); auto srv_remove = MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, @@ -618,7 +700,7 @@ TEST_F(MdnsResponderServiceTest, ServiceQueriesStoppedSrvFirst) { mdns_responder->AddSrvEvent(std::move(srv_remove)); EXPECT_CALL(observer_, OnReceiverRemoved(_)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); EXPECT_FALSE(mdns_responder->ptr_queries_empty()); EXPECT_FALSE(mdns_responder->srv_queries_empty()); @@ -630,7 +712,7 @@ TEST_F(MdnsResponderServiceTest, ServiceQueriesStoppedSrvFirst) { kTestServiceProtocol, kDefaultSocket); ptr_remove.header.response_type = mdns::QueryEventHeader::Type::kRemoved; mdns_responder->AddPtrEvent(std::move(ptr_remove)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); EXPECT_FALSE(mdns_responder->ptr_queries_empty()); EXPECT_TRUE(mdns_responder->srv_queries_empty()); @@ -651,7 +733,7 @@ TEST_F(MdnsResponderServiceTest, ServiceQueriesStoppedPtrFirst) { kDefaultSocket); EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); auto ptr_remove = MakePtrEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, kDefaultSocket); @@ -659,7 +741,7 @@ TEST_F(MdnsResponderServiceTest, ServiceQueriesStoppedPtrFirst) { mdns_responder->AddPtrEvent(std::move(ptr_remove)); EXPECT_CALL(observer_, OnReceiverRemoved(_)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); EXPECT_FALSE(mdns_responder->ptr_queries_empty()); EXPECT_FALSE(mdns_responder->srv_queries_empty()); @@ -672,7 +754,7 @@ TEST_F(MdnsResponderServiceTest, ServiceQueriesStoppedPtrFirst) { "gigliorononomicon", kTestPort, kDefaultSocket); srv_remove.header.response_type = mdns::QueryEventHeader::Type::kRemoved; mdns_responder->AddSrvEvent(std::move(srv_remove)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); EXPECT_FALSE(mdns_responder->ptr_queries_empty()); EXPECT_TRUE(mdns_responder->srv_queries_empty()); @@ -697,7 +779,7 @@ TEST_F(MdnsResponderServiceTest, MultipleInterfaceRemove) { kSecondSocket); EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); auto srv_remove1 = MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, @@ -706,7 +788,7 @@ TEST_F(MdnsResponderServiceTest, MultipleInterfaceRemove) { mdns_responder->AddSrvEvent(std::move(srv_remove1)); EXPECT_CALL(observer_, OnReceiverChanged(_)).Times(0); EXPECT_CALL(observer_, OnReceiverRemoved(_)).Times(0); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); auto srv_remove2 = MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, @@ -714,14 +796,14 @@ TEST_F(MdnsResponderServiceTest, MultipleInterfaceRemove) { srv_remove2.header.response_type = mdns::QueryEventHeader::Type::kRemoved; mdns_responder->AddSrvEvent(std::move(srv_remove2)); EXPECT_CALL(observer_, OnReceiverRemoved(_)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); EXPECT_TRUE(mdns_responder->a_queries_empty()); auto ptr_remove = MakePtrEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, kDefaultSocket); ptr_remove.header.response_type = mdns::QueryEventHeader::Type::kRemoved; mdns_responder->AddPtrEvent(std::move(ptr_remove)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); EXPECT_FALSE(mdns_responder->ptr_queries_empty()); EXPECT_TRUE(mdns_responder->srv_queries_empty()); @@ -766,7 +848,7 @@ TEST_F(MdnsResponderServiceTest, RestorePtrNotifiesObserver) { kDefaultSocket); EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); auto ptr_remove = MakePtrEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, kDefaultSocket); @@ -774,14 +856,14 @@ TEST_F(MdnsResponderServiceTest, RestorePtrNotifiesObserver) { mdns_responder->AddPtrEvent(std::move(ptr_remove)); EXPECT_CALL(observer_, OnReceiverRemoved(_)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); auto ptr_add = MakePtrEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, kDefaultSocket); mdns_responder->AddPtrEvent(std::move(ptr_add)); EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents({}); + mdns_service_->HandleNewEvents(); } } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_service_listener_factory.cc b/chromium/third_party/openscreen/src/osp/impl/mdns_service_listener_factory.cc index d6b7d0f2e84..6be2184ae5a 100644 --- a/chromium/third_party/openscreen/src/osp/impl/mdns_service_listener_factory.cc +++ b/chromium/third_party/openscreen/src/osp/impl/mdns_service_listener_factory.cc @@ -7,12 +7,16 @@ #include "osp/impl/internal_services.h" namespace openscreen { +namespace platform { +class NetworkRunner; +} // namespace platform // static std::unique_ptr<ServiceListener> MdnsServiceListenerFactory::Create( const MdnsServiceListenerConfig& config, - ServiceListener::Observer* observer) { - return InternalServices::CreateListener(config, observer); + ServiceListener::Observer* observer, + platform::NetworkRunner* network_runner) { + return InternalServices::CreateListener(config, observer, network_runner); } } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/impl/mdns_service_publisher_factory.cc b/chromium/third_party/openscreen/src/osp/impl/mdns_service_publisher_factory.cc index 05cc6952a84..3808fe8f67a 100644 --- a/chromium/third_party/openscreen/src/osp/impl/mdns_service_publisher_factory.cc +++ b/chromium/third_party/openscreen/src/osp/impl/mdns_service_publisher_factory.cc @@ -7,12 +7,16 @@ #include "osp/impl/internal_services.h" namespace openscreen { +namespace platform { +class NetworkRunner; +} // namespace platform // static std::unique_ptr<ServicePublisher> MdnsServicePublisherFactory::Create( const ServicePublisher::Config& config, - ServicePublisher::Observer* observer) { - return InternalServices::CreatePublisher(config, observer); + ServicePublisher::Observer* observer, + platform::NetworkRunner* network_runner) { + return InternalServices::CreatePublisher(config, observer, network_runner); } } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/impl/network_service_manager.cc b/chromium/third_party/openscreen/src/osp/impl/network_service_manager.cc index 92a43384342..19414d2caae 100644 --- a/chromium/third_party/openscreen/src/osp/impl/network_service_manager.cc +++ b/chromium/third_party/openscreen/src/osp/impl/network_service_manager.cc @@ -44,17 +44,6 @@ void NetworkServiceManager::Dispose() { g_network_service_manager_instance = nullptr; } -void NetworkServiceManager::RunEventLoopOnce() { - if (mdns_listener_) - mdns_listener_->RunTasks(); - if (mdns_publisher_) - mdns_publisher_->RunTasks(); - if (connection_client_) - connection_client_->RunTasks(); - if (connection_server_) - connection_server_->RunTasks(); -} - ServiceListener* NetworkServiceManager::GetMdnsServiceListener() { return mdns_listener_.get(); } diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection.cc index 6bb789a28f9..cee8a60efe8 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection.cc @@ -199,7 +199,7 @@ ErrorOr<size_t> ConnectionManager::OnStreamMessage( Connection* connection = GetConnection(message.connection_id); if (!connection) { - return Error::Code::kNoItemFound; + return Error::Code::kItemNotFound; } switch (message.message.which) { diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection_unittest.cc index 0c5096cf972..666ef53a2a7 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_connection_unittest.cc @@ -16,6 +16,7 @@ #include "osp/public/network_service_manager.h" #include "osp/public/presentation/presentation_controller.h" #include "platform/test/fake_clock.h" +#include "platform/test/fake_network_runner.h" namespace openscreen { namespace presentation { @@ -55,11 +56,22 @@ class MockConnectRequest final } // namespace class ConnectionTest : public ::testing::Test { + public: + ConnectionTest() { + network_runner_ = std::make_unique<platform::FakeNetworkRunner>(); + quic_bridge_ = std::make_unique<FakeQuicBridge>(network_runner_.get(), + platform::FakeClock::now); + controller_connection_manager_ = std::make_unique<ConnectionManager>( + quic_bridge_->controller_demuxer.get()); + receiver_connection_manager_ = std::make_unique<ConnectionManager>( + quic_bridge_->receiver_demuxer.get()); + } + protected: void SetUp() override { NetworkServiceManager::Create(nullptr, nullptr, - std::move(quic_bridge_.quic_client), - std::move(quic_bridge_.quic_server)); + std::move(quic_bridge_->quic_client), + std::move(quic_bridge_->quic_server)); } void TearDown() override { NetworkServiceManager::Dispose(); } @@ -74,13 +86,12 @@ class ConnectionTest : public ::testing::Test { return response; } + std::unique_ptr<platform::FakeNetworkRunner> network_runner_; platform::FakeClock fake_clock_{ platform::Clock::time_point(std::chrono::milliseconds(1298424))}; - FakeQuicBridge quic_bridge_{platform::FakeClock::now}; - ConnectionManager controller_connection_manager_{ - quic_bridge_.controller_demuxer.get()}; - ConnectionManager receiver_connection_manager_{ - quic_bridge_.receiver_demuxer.get()}; + std::unique_ptr<FakeQuicBridge> quic_bridge_; + std::unique_ptr<ConnectionManager> controller_connection_manager_; + std::unique_ptr<ConnectionManager> receiver_connection_manager_; NiceMock<MockParentDelegate> mock_controller_; NiceMock<MockParentDelegate> mock_receiver_; }; @@ -128,24 +139,28 @@ TEST_F(ConnectionTest, ConnectAndSend) { EXPECT_EQ(Connection::State::kConnecting, controller.state()); EXPECT_EQ(Connection::State::kConnecting, receiver.state()); + std::cout << "1\n"; + std::cout.flush(); MockConnectRequest mock_connect_request; std::unique_ptr<ProtocolConnection> controller_stream; std::unique_ptr<ProtocolConnection> receiver_stream; NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect( - quic_bridge_.kReceiverEndpoint, &mock_connect_request); + quic_bridge_->kReceiverEndpoint, &mock_connect_request); EXPECT_CALL(mock_connect_request, OnConnectionOpenedMock(_, _)) .WillOnce(Invoke([&controller_stream](uint64_t request_id, ProtocolConnection* stream) { controller_stream.reset(stream); })); - EXPECT_CALL(quic_bridge_.mock_server_observer, OnIncomingConnectionMock(_)) + std::cout << "2\n"; + std::cout.flush(); + EXPECT_CALL(quic_bridge_->mock_server_observer, OnIncomingConnectionMock(_)) .WillOnce(testing::WithArgs<0>(testing::Invoke( [&receiver_stream](std::unique_ptr<ProtocolConnection>& connection) { receiver_stream = std::move(connection); }))); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); ASSERT_TRUE(controller_stream); ASSERT_TRUE(receiver_stream); @@ -157,8 +172,8 @@ TEST_F(ConnectionTest, ConnectAndSend) { std::move(controller_stream)); receiver.OnConnected(connection_id, controller_endpoint_id, std::move(receiver_stream)); - controller_connection_manager_.AddConnection(&controller); - receiver_connection_manager_.AddConnection(&receiver); + controller_connection_manager_->AddConnection(&controller); + receiver_connection_manager_->AddConnection(&receiver); EXPECT_EQ(Connection::State::kConnected, controller.state()); EXPECT_EQ(Connection::State::kConnected, receiver.state()); @@ -174,7 +189,7 @@ TEST_F(ConnectionTest, ConnectAndSend) { OnStringMessage(static_cast<absl::string_view>(expected_message))) .WillOnce(Invoke( [&received](absl::string_view s) { received = std::string(s); })); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); std::string string_response = MakeEchoResponse(received); receiver.SendString(string_response); @@ -182,7 +197,7 @@ TEST_F(ConnectionTest, ConnectAndSend) { EXPECT_CALL( mock_controller_delegate, OnStringMessage(static_cast<absl::string_view>(expected_response))); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); std::vector<uint8_t> data{0, 3, 2, 4, 4, 6, 1}; const std::vector<uint8_t> expected_data = data; @@ -196,20 +211,20 @@ TEST_F(ConnectionTest, ConnectAndSend) { .WillOnce(Invoke([&received_data](std::vector<uint8_t> d) { received_data = std::move(d); })); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); receiver.SendBinary(MakeEchoResponse(received_data)); EXPECT_CALL(mock_controller_delegate, OnBinaryMessage(expected_response_data)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); EXPECT_CALL(mock_controller_delegate, OnClosedByRemote()); receiver.Close(Connection::CloseReason::kClosed); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); EXPECT_EQ(Connection::State::kClosed, controller.state()); EXPECT_EQ(Connection::State::kClosed, receiver.state()); - controller_connection_manager_.RemoveConnection(&controller); - receiver_connection_manager_.RemoveConnection(&receiver); + controller_connection_manager_->RemoveConnection(&controller); + receiver_connection_manager_->RemoveConnection(&receiver); } } // namespace presentation diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc index 5b09d10cf03..991bc90dece 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_controller_unittest.cc @@ -16,6 +16,7 @@ #include "osp/public/network_service_manager.h" #include "osp/public/testing/message_demuxer_test_support.h" #include "platform/test/fake_clock.h" +#include "platform/test/fake_network_runner.h" namespace openscreen { namespace presentation { @@ -74,22 +75,31 @@ class MockRequestDelegate final : public RequestDelegate { } // namespace class ControllerTest : public ::testing::Test { + public: + ControllerTest() { + network_runner_ = std::make_unique<platform::FakeNetworkRunner>(); + quic_bridge_ = std::make_unique<FakeQuicBridge>(network_runner_.get(), + platform::FakeClock::now); + receiver_info1 = { + "service-id1", "lucas-auer", 1, quic_bridge_->kReceiverEndpoint, {}}; + } + protected: void SetUp() override { auto service_listener = std::make_unique<ServiceListenerImpl>(&mock_listener_delegate_); NetworkServiceManager::Create(std::move(service_listener), nullptr, - std::move(quic_bridge_.quic_client), - std::move(quic_bridge_.quic_server)); + std::move(quic_bridge_->quic_client), + std::move(quic_bridge_->quic_server)); controller_ = std::make_unique<Controller>(platform::FakeClock::now); - ON_CALL(quic_bridge_.mock_server_observer, OnIncomingConnectionMock(_)) + ON_CALL(quic_bridge_->mock_server_observer, OnIncomingConnectionMock(_)) .WillByDefault( Invoke([this](std::unique_ptr<ProtocolConnection>& connection) { controller_endpoint_id_ = connection->endpoint_id(); })); availability_watch_ = - quic_bridge_.receiver_demuxer->SetDefaultMessageTypeWatch( + quic_bridge_->receiver_demuxer->SetDefaultMessageTypeWatch( msgs::Type::kPresentationUrlAvailabilityRequest, &mock_callback_); } @@ -114,7 +124,7 @@ class ControllerTest : public ::testing::Test { buffer, buffer_size, request); return decode_result; })); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); ASSERT_EQ(msg_type, msgs::Type::kPresentationUrlAvailabilityRequest); ASSERT_GT(decode_result, 0); } @@ -205,7 +215,7 @@ class ControllerTest : public ::testing::Test { })); connection->Close(Connection::CloseReason::kClosed); EXPECT_EQ(connection->state(), Connection::State::kClosed); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); ASSERT_EQ(msg_type, msgs::Type::kPresentationConnectionCloseRequest); ASSERT_GT(decode_result, 0); } @@ -242,10 +252,10 @@ class ControllerTest : public ::testing::Test { MockConnectionDelegate* mock_connection_delegate, std::unique_ptr<Connection>* connection) { MessageDemuxer::MessageWatch start_presentation_watch = - quic_bridge_.receiver_demuxer->SetDefaultMessageTypeWatch( + quic_bridge_->receiver_demuxer->SetDefaultMessageTypeWatch( msgs::Type::kPresentationStartRequest, mock_callback); mock_listener_delegate_.listener()->OnReceiverAdded(receiver_info1); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); MockRequestDelegate mock_request_delegate; msgs::PresentationStartRequest request; @@ -265,7 +275,7 @@ class ControllerTest : public ::testing::Test { "https://example.com/receiver.html", receiver_info1.service_id, &mock_request_delegate, mock_connection_delegate); ASSERT_TRUE(connect_request); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); ASSERT_EQ(msgs::Type::kPresentationStartRequest, msg_type); msgs::PresentationStartResponse response; @@ -279,22 +289,19 @@ class ControllerTest : public ::testing::Test { *connection = std::move(c); })); EXPECT_CALL(*mock_connection_delegate, OnConnected()); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); ASSERT_TRUE(*connection); } + std::unique_ptr<platform::FakeNetworkRunner> network_runner_; MessageDemuxer::MessageWatch availability_watch_; MockMessageCallback mock_callback_; platform::FakeClock fake_clock_{platform::Clock::time_point(seconds(11111))}; - FakeQuicBridge quic_bridge_{platform::FakeClock::now}; + std::unique_ptr<FakeQuicBridge> quic_bridge_; MockServiceListenerDelegate mock_listener_delegate_; std::unique_ptr<Controller> controller_; - ServiceInfo receiver_info1{"service-id1", - "lucas-auer", - 1, - quic_bridge_.kReceiverEndpoint, - {}}; + ServiceInfo receiver_info1; MockReceiverObserver mock_receiver_observer_; uint64_t controller_endpoint_id_{0}; }; @@ -345,7 +352,7 @@ TEST_F(ControllerTest, ReceiverAvailable) { response.url_availabilities.push_back(msgs::UrlAvailability::kAvailable); SendAvailabilityResponse(response); EXPECT_CALL(mock_receiver_observer_, OnReceiverAvailable(_, _)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); MockReceiverObserver mock_receiver_observer2; EXPECT_CALL(mock_receiver_observer2, OnReceiverAvailable(_, _)); @@ -366,7 +373,7 @@ TEST_F(ControllerTest, ReceiverWatchCancel) { response.url_availabilities.push_back(msgs::UrlAvailability::kAvailable); SendAvailabilityResponse(response); EXPECT_CALL(mock_receiver_observer_, OnReceiverAvailable(_, _)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); MockReceiverObserver mock_receiver_observer2; EXPECT_CALL(mock_receiver_observer2, OnReceiverAvailable(_, _)); @@ -381,7 +388,7 @@ TEST_F(ControllerTest, ReceiverWatchCancel) { EXPECT_CALL(mock_receiver_observer2, OnReceiverUnavailable(_, _)); EXPECT_CALL(mock_receiver_observer_, OnReceiverUnavailable(_, _)).Times(0); SendAvailabilityEvent(event); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); } TEST_F(ControllerTest, StartPresentation) { @@ -398,7 +405,7 @@ TEST_F(ControllerTest, TerminatePresentationFromController) { StartPresentation(&mock_callback, &mock_connection_delegate, &connection); MessageDemuxer::MessageWatch terminate_presentation_watch = - quic_bridge_.receiver_demuxer->SetDefaultMessageTypeWatch( + quic_bridge_->receiver_demuxer->SetDefaultMessageTypeWatch( msgs::Type::kPresentationTerminationRequest, &mock_callback); msgs::PresentationTerminationRequest termination_request; msgs::Type msg_type; @@ -414,7 +421,7 @@ TEST_F(ControllerTest, TerminatePresentationFromController) { return result; })); connection->Terminate(TerminationReason::kControllerTerminateCalled); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); ASSERT_EQ(msgs::Type::kPresentationTerminationRequest, msg_type); msgs::PresentationTerminationResponse termination_response; @@ -425,7 +432,7 @@ TEST_F(ControllerTest, TerminatePresentationFromController) { // TODO(btolsch): Check OnTerminated of other connections when reconnect // lands. - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); } TEST_F(ControllerTest, TerminatePresentationFromReceiver) { @@ -441,7 +448,7 @@ TEST_F(ControllerTest, TerminatePresentationFromReceiver) { SendTerminationEvent(termination_event); EXPECT_CALL(mock_connection_delegate, OnTerminated()); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); } TEST_F(ControllerTest, CloseConnection) { @@ -451,7 +458,7 @@ TEST_F(ControllerTest, CloseConnection) { StartPresentation(&mock_callback, &mock_connection_delegate, &connection); MessageDemuxer::MessageWatch close_request_watch = - quic_bridge_.receiver_demuxer->SetDefaultMessageTypeWatch( + quic_bridge_->receiver_demuxer->SetDefaultMessageTypeWatch( msgs::Type::kPresentationConnectionCloseRequest, &mock_callback); msgs::PresentationConnectionCloseRequest close_request; ExpectCloseRequest(&mock_callback, &close_request, connection.get()); @@ -461,7 +468,7 @@ TEST_F(ControllerTest, CloseConnection) { close_response.result = msgs::PresentationConnectionCloseResponse_result::kSuccess; SendCloseResponse(close_response); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); } TEST_F(ControllerTest, Reconnect) { @@ -471,7 +478,7 @@ TEST_F(ControllerTest, Reconnect) { StartPresentation(&mock_callback, &mock_connection_delegate, &connection); MessageDemuxer::MessageWatch close_request_watch = - quic_bridge_.receiver_demuxer->SetDefaultMessageTypeWatch( + quic_bridge_->receiver_demuxer->SetDefaultMessageTypeWatch( msgs::Type::kPresentationConnectionCloseRequest, &mock_callback); msgs::PresentationConnectionCloseRequest close_request; ExpectCloseRequest(&mock_callback, &close_request, connection.get()); @@ -481,10 +488,10 @@ TEST_F(ControllerTest, Reconnect) { close_response.result = msgs::PresentationConnectionCloseResponse_result::kSuccess; SendCloseResponse(close_response); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); MessageDemuxer::MessageWatch connection_open_watch = - quic_bridge_.receiver_demuxer->SetDefaultMessageTypeWatch( + quic_bridge_->receiver_demuxer->SetDefaultMessageTypeWatch( msgs::Type::kPresentationConnectionOpenRequest, &mock_callback); msgs::PresentationConnectionOpenRequest open_request; MockRequestDelegate reconnect_delegate; @@ -505,7 +512,7 @@ TEST_F(ControllerTest, Reconnect) { buffer, buffer_size, &open_request); return decode_result; })); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); ASSERT_FALSE(connection); ASSERT_EQ(msg_type, msgs::Type::kPresentationConnectionOpenRequest); @@ -522,7 +529,7 @@ TEST_F(ControllerTest, Reconnect) { connection = std::move(c); })); EXPECT_CALL(mock_connection_delegate, OnConnected()); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); ASSERT_TRUE(connection); EXPECT_EQ(connection->state(), Connection::State::kConnected); } diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver.cc index 95888ff5b29..5c3cd0d1175 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver.cc @@ -14,6 +14,7 @@ #include "osp/public/protocol_connection_server.h" #include "platform/api/logging.h" #include "platform/api/time.h" +#include "platform/api/trace_logging.h" namespace openscreen { namespace presentation { @@ -105,8 +106,11 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, const uint8_t* buffer, size_t buffer_size, platform::Clock::time_point now) { + TRACE_SCOPED(TraceCategory::Presentation, "Receiver::OnStreamMessage"); switch (message_type) { case msgs::Type::kPresentationUrlAvailabilityRequest: { + TRACE_SCOPED(TraceCategory::Presentation, + "kPresentationUrlAvailabilityRequest"); OSP_VLOG << "got presentation-url-availability-request"; msgs::PresentationUrlAvailabilityRequest request; ssize_t decode_result = msgs::DecodePresentationUrlAvailabilityRequest( @@ -114,6 +118,7 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, if (decode_result < 0) { OSP_LOG_WARN << "Presentation-url-availability-request parse error: " << decode_result; + TRACE_SET_RESULT(Error::Code::kParseError); return Error::Code::kParseError; } @@ -130,6 +135,7 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, } case msgs::Type::kPresentationStartRequest: { + TRACE_SCOPED(TraceCategory::Presentation, "kPresentationStartRequest"); OSP_VLOG << "got presentation-start-request"; msgs::PresentationStartRequest request; const ssize_t result = @@ -137,6 +143,7 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, if (result < 0) { OSP_LOG_WARN << "Presentation-initiation-request parse error: " << result; + TRACE_SET_RESULT(Error::Code::kParseError); return Error::Code::kParseError; } @@ -151,8 +158,11 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, Error write_error = WritePresentationInitiationResponse( response, GetProtocolConnection(endpoint_id).get()); - if (!write_error.ok()) + if (!write_error.ok()) { + TRACE_SET_RESULT(write_error); return write_error; + } + return result; } @@ -177,12 +187,17 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, response.result = msgs::PresentationStartResponse_result::kUnknownError; Error write_error = WritePresentationInitiationResponse( response, GetProtocolConnection(endpoint_id).get()); - if (!write_error.ok()) + if (!write_error.ok()) { + TRACE_SET_RESULT(write_error); return write_error; + } + return result; } case msgs::Type::kPresentationConnectionOpenRequest: { + TRACE_SCOPED(TraceCategory::Presentation, + "kPresentationConnectionOpenRequest"); OSP_VLOG << "Got a presentation-connection-open-request"; msgs::PresentationConnectionOpenRequest request; const ssize_t result = msgs::DecodePresentationConnectionOpenRequest( @@ -190,6 +205,7 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, if (result < 0) { OSP_LOG_WARN << "Presentation-connection-open-request parse error: " << result; + TRACE_SET_RESULT(Error::Code::kParseError); return Error::Code::kParseError; } @@ -207,8 +223,11 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, kInvalidPresentationId; Error write_error = WritePresentationConnectionOpenResponse( response, GetProtocolConnection(endpoint_id).get()); - if (!write_error.ok()) + if (!write_error.ok()) { + TRACE_SET_RESULT(write_error); return write_error; + } + return result; } @@ -236,12 +255,17 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, msgs::PresentationConnectionOpenResponse_result::kUnknownError; Error write_error = WritePresentationConnectionOpenResponse( response, GetProtocolConnection(endpoint_id).get()); - if (!write_error.ok()) + if (!write_error.ok()) { + TRACE_SET_RESULT(write_error); return write_error; + } + return result; } case msgs::Type::kPresentationTerminationRequest: { + TRACE_SCOPED(TraceCategory::Presentation, + "kPresentationTerminationRequest"); OSP_VLOG << "got presentation-termination-request"; msgs::PresentationTerminationRequest request; const ssize_t result = msgs::DecodePresentationTerminationRequest( @@ -249,6 +273,7 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, if (result < 0) { OSP_LOG_WARN << "Presentation-termination-request parse error: " << result; + TRACE_SET_RESULT(Error::Code::kParseError); return Error::Code::kParseError; } @@ -272,8 +297,10 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, kInvalidPresentationId; Error write_error = WritePresentationTerminationResponse( response, GetProtocolConnection(endpoint_id).get()); - if (!write_error.ok()) + if (!write_error.ok()) { + TRACE_SET_RESULT(write_error); return write_error; + } return result; } @@ -289,6 +316,7 @@ ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id, } default: + TRACE_SET_RESULT(Error::Code::kUnknownMessageType); return Error::Code::kUnknownMessageType; } } diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver_unittest.cc index 51e01aa9c44..ac92a7ea907 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/presentation_receiver_unittest.cc @@ -17,6 +17,7 @@ #include "osp/public/protocol_connection_server.h" #include "osp/public/testing/message_demuxer_test_support.h" #include "platform/test/fake_clock.h" +#include "platform/test/fake_network_runner.h" namespace openscreen { namespace presentation { @@ -66,22 +67,29 @@ class MockReceiverDelegate final : public ReceiverDelegate { }; class PresentationReceiverTest : public ::testing::Test { + public: + PresentationReceiverTest() { + network_runner_ = std::make_unique<platform::FakeNetworkRunner>(); + quic_bridge_ = std::make_unique<FakeQuicBridge>(network_runner_.get(), + platform::FakeClock::now); + } + protected: std::unique_ptr<ProtocolConnection> MakeClientStream() { MockConnectRequest mock_connect_request; NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect( - quic_bridge_.kReceiverEndpoint, &mock_connect_request); + quic_bridge_->kReceiverEndpoint, &mock_connect_request); ProtocolConnection* stream; EXPECT_CALL(mock_connect_request, OnConnectionOpenedMock(_, _)) .WillOnce(::testing::SaveArg<1>(&stream)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); return std::unique_ptr<ProtocolConnection>(stream); } void SetUp() override { NetworkServiceManager::Create(nullptr, nullptr, - std::move(quic_bridge_.quic_client), - std::move(quic_bridge_.quic_server)); + std::move(quic_bridge_->quic_client), + std::move(quic_bridge_->quic_server)); Receiver::Get()->Init(); Receiver::Get()->SetReceiverDelegate(&mock_receiver_delegate_); } @@ -92,10 +100,11 @@ class PresentationReceiverTest : public ::testing::Test { NetworkServiceManager::Dispose(); } + std::unique_ptr<platform::FakeNetworkRunner> network_runner_; const std::string url1_{"https://www.example.com/receiver.html"}; platform::FakeClock fake_clock_{ platform::Clock::time_point(std::chrono::milliseconds(1298424))}; - FakeQuicBridge quic_bridge_{platform::FakeClock::now}; + std::unique_ptr<FakeQuicBridge> quic_bridge_; MockReceiverDelegate mock_receiver_delegate_; }; @@ -106,7 +115,7 @@ class PresentationReceiverTest : public ::testing::Test { TEST_F(PresentationReceiverTest, QueryAvailability) { MockMessageCallback mock_callback; MessageDemuxer::MessageWatch availability_watch = - quic_bridge_.controller_demuxer->SetDefaultMessageTypeWatch( + quic_bridge_->controller_demuxer->SetDefaultMessageTypeWatch( msgs::Type::kPresentationUrlAvailabilityResponse, &mock_callback); std::unique_ptr<ProtocolConnection> stream = MakeClientStream(); @@ -139,7 +148,7 @@ TEST_F(PresentationReceiverTest, QueryAvailability) { buffer, buffer_size, &response); return result; })); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); EXPECT_EQ(request.request_id, response.request_id); EXPECT_EQ( (std::vector<msgs::UrlAvailability>{msgs::UrlAvailability::kAvailable}), @@ -149,7 +158,7 @@ TEST_F(PresentationReceiverTest, QueryAvailability) { TEST_F(PresentationReceiverTest, StartPresentation) { MockMessageCallback mock_callback; MessageDemuxer::MessageWatch initiation_watch = - quic_bridge_.controller_demuxer->SetDefaultMessageTypeWatch( + quic_bridge_->controller_demuxer->SetDefaultMessageTypeWatch( msgs::Type::kPresentationStartResponse, &mock_callback); std::unique_ptr<ProtocolConnection> stream = MakeClientStream(); @@ -168,7 +177,7 @@ TEST_F(PresentationReceiverTest, StartPresentation) { EXPECT_CALL(mock_receiver_delegate_, StartPresentation(_, _, request.headers)) .WillOnce(::testing::DoAll(::testing::SaveArg<0>(&info), ::testing::Return(true))); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); EXPECT_EQ(presentation_id, info.id); EXPECT_EQ(url1_, info.url); @@ -187,7 +196,7 @@ TEST_F(PresentationReceiverTest, StartPresentation) { buffer, buffer_size, &response); return result; })); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); EXPECT_EQ(msgs::Result::kSuccess, response.result); EXPECT_EQ(connection.connection_id(), response.connection_id); } diff --git a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc index 9732fc38d6b..8678265d1b5 100644 --- a/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/presentation/url_availability_requester_unittest.cc @@ -15,6 +15,7 @@ #include "osp/public/testing/message_demuxer_test_support.h" #include "platform/api/logging.h" #include "platform/test/fake_clock.h" +#include "platform/test/fake_network_runner.h" using std::chrono::milliseconds; using std::chrono::seconds; @@ -44,12 +45,19 @@ class MockReceiverObserver : public ReceiverObserver { class UrlAvailabilityRequesterTest : public Test { public: + UrlAvailabilityRequesterTest() { + network_runner_ = std::make_unique<platform::FakeNetworkRunner>(); + quic_bridge_ = std::make_unique<FakeQuicBridge>(network_runner_.get(), + platform::FakeClock::now); + info1_ = {service_id_, friendly_name_, 1, quic_bridge_->kReceiverEndpoint}; + } + void SetUp() override { NetworkServiceManager::Create(nullptr, nullptr, - std::move(quic_bridge_.quic_client), - std::move(quic_bridge_.quic_server)); + std::move(quic_bridge_->quic_client), + std::move(quic_bridge_->quic_server)); availability_watch_ = - quic_bridge_.receiver_demuxer->SetDefaultMessageTypeWatch( + quic_bridge_->receiver_demuxer->SetDefaultMessageTypeWatch( msgs::Type::kPresentationUrlAvailabilityRequest, &mock_callback_); } @@ -62,12 +70,12 @@ class UrlAvailabilityRequesterTest : public Test { std::unique_ptr<ProtocolConnection> ExpectIncomingConnection() { std::unique_ptr<ProtocolConnection> stream; - EXPECT_CALL(quic_bridge_.mock_server_observer, OnIncomingConnectionMock(_)) + EXPECT_CALL(quic_bridge_->mock_server_observer, OnIncomingConnectionMock(_)) .WillOnce( Invoke([&stream](std::unique_ptr<ProtocolConnection>& connection) { stream = std::move(connection); })); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); return stream; } @@ -115,19 +123,19 @@ class UrlAvailabilityRequesterTest : public Test { stream->Write(buffer.data(), buffer.size()); } + std::unique_ptr<platform::FakeNetworkRunner> network_runner_; MockMessageCallback mock_callback_; MessageDemuxer::MessageWatch availability_watch_; platform::FakeClock fake_clock_{ platform::Clock::time_point(milliseconds(1298424))}; - FakeQuicBridge quic_bridge_{platform::FakeClock::now}; + std::unique_ptr<FakeQuicBridge> quic_bridge_; UrlAvailabilityRequester listener_{platform::FakeClock::now}; std::string url1_{"https://example.com/foo.html"}; std::string url2_{"https://example.com/bar.html"}; std::string service_id_{"asdf"}; std::string friendly_name_{"turtle"}; - ServiceInfo info1_{service_id_, friendly_name_, 1, - quic_bridge_.kReceiverEndpoint}; + ServiceInfo info1_; }; TEST_F(UrlAvailabilityRequesterTest, AvailableObserverFirst) { @@ -151,7 +159,7 @@ TEST_F(UrlAvailabilityRequesterTest, AvailableObserverFirst) { EXPECT_CALL(mock_observer, OnReceiverAvailable(url1_, service_id_)); EXPECT_CALL(mock_observer, OnReceiverUnavailable(url1_, service_id_)) .Times(0); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); } TEST_F(UrlAvailabilityRequesterTest, AvailableReceiverFirst) { @@ -175,7 +183,7 @@ TEST_F(UrlAvailabilityRequesterTest, AvailableReceiverFirst) { EXPECT_CALL(mock_observer, OnReceiverAvailable(url1_, service_id_)); EXPECT_CALL(mock_observer, OnReceiverUnavailable(url1_, service_id_)) .Times(0); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); } TEST_F(UrlAvailabilityRequesterTest, Unavailable) { @@ -198,7 +206,7 @@ TEST_F(UrlAvailabilityRequesterTest, Unavailable) { EXPECT_CALL(mock_observer, OnReceiverAvailable(url1_, service_id_)).Times(0); EXPECT_CALL(mock_observer, OnReceiverUnavailable(url1_, service_id_)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); } TEST_F(UrlAvailabilityRequesterTest, AvailabilityIsCached) { @@ -221,7 +229,7 @@ TEST_F(UrlAvailabilityRequesterTest, AvailabilityIsCached) { EXPECT_CALL(mock_observer1, OnReceiverAvailable(url1_, service_id_)).Times(0); EXPECT_CALL(mock_observer1, OnReceiverUnavailable(url1_, service_id_)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); MockReceiverObserver mock_observer2; EXPECT_CALL(mock_observer2, OnReceiverAvailable(url1_, service_id_)).Times(0); @@ -249,7 +257,7 @@ TEST_F(UrlAvailabilityRequesterTest, AvailabilityCacheIsTransient) { EXPECT_CALL(mock_observer1, OnReceiverAvailable(url1_, service_id_)).Times(0); EXPECT_CALL(mock_observer1, OnReceiverUnavailable(url1_, service_id_)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); listener_.RemoveObserverUrls({url1_}, &mock_observer1); MockReceiverObserver mock_observer2; @@ -279,7 +287,7 @@ TEST_F(UrlAvailabilityRequesterTest, PartiallyCachedAnswer) { EXPECT_CALL(mock_observer1, OnReceiverAvailable(url1_, service_id_)).Times(0); EXPECT_CALL(mock_observer1, OnReceiverUnavailable(url1_, service_id_)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); MockReceiverObserver mock_observer2; EXPECT_CALL(mock_observer2, OnReceiverAvailable(url1_, service_id_)).Times(0); @@ -287,7 +295,7 @@ TEST_F(UrlAvailabilityRequesterTest, PartiallyCachedAnswer) { listener_.AddObserver({url1_, url2_}, &mock_observer2); ExpectStreamMessage(&mock_callback_, &request); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); EXPECT_EQ(std::vector<std::string>{url2_}, request.urls); SendAvailabilityResponse( @@ -297,7 +305,7 @@ TEST_F(UrlAvailabilityRequesterTest, PartiallyCachedAnswer) { EXPECT_CALL(mock_observer2, OnReceiverAvailable(url2_, service_id_)).Times(0); EXPECT_CALL(mock_observer2, OnReceiverUnavailable(url2_, service_id_)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); } TEST_F(UrlAvailabilityRequesterTest, MultipleOverlappingObservers) { @@ -321,13 +329,13 @@ TEST_F(UrlAvailabilityRequesterTest, MultipleOverlappingObservers) { EXPECT_CALL(mock_observer1, OnReceiverAvailable(url1_, service_id_)); EXPECT_CALL(mock_observer1, OnReceiverUnavailable(url1_, service_id_)) .Times(0); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); MockReceiverObserver mock_observer2; EXPECT_CALL(mock_observer2, OnReceiverAvailable(url1_, service_id_)); listener_.AddObserver({url1_, url2_}, &mock_observer2); ExpectStreamMessage(&mock_callback_, &request); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); EXPECT_EQ(std::vector<std::string>{url2_}, request.urls); SendAvailabilityResponse( @@ -338,7 +346,7 @@ TEST_F(UrlAvailabilityRequesterTest, MultipleOverlappingObservers) { EXPECT_CALL(mock_observer1, OnReceiverUnavailable(_, service_id_)).Times(0); EXPECT_CALL(mock_observer2, OnReceiverAvailable(_, service_id_)).Times(0); EXPECT_CALL(mock_observer2, OnReceiverUnavailable(url2_, service_id_)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); } TEST_F(UrlAvailabilityRequesterTest, RemoveObserverUrls) { @@ -362,7 +370,7 @@ TEST_F(UrlAvailabilityRequesterTest, RemoveObserverUrls) { EXPECT_CALL(mock_observer1, OnReceiverAvailable(url1_, service_id_)); EXPECT_CALL(mock_observer1, OnReceiverUnavailable(url1_, service_id_)) .Times(0); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); Mock::VerifyAndClearExpectations(&mock_observer1); MockReceiverObserver mock_observer2; @@ -370,7 +378,7 @@ TEST_F(UrlAvailabilityRequesterTest, RemoveObserverUrls) { listener_.AddObserver({url1_, url2_}, &mock_observer2); ExpectStreamMessage(&mock_callback_, &request); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); Mock::VerifyAndClearExpectations(&mock_observer2); EXPECT_EQ(std::vector<std::string>{url2_}, request.urls); @@ -383,7 +391,7 @@ TEST_F(UrlAvailabilityRequesterTest, RemoveObserverUrls) { EXPECT_CALL(mock_observer2, OnReceiverAvailable(_, service_id_)).Times(0); EXPECT_CALL(mock_observer2, OnReceiverUnavailable(url2_, service_id_)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); Mock::VerifyAndClearExpectations(&mock_observer1); Mock::VerifyAndClearExpectations(&mock_observer2); @@ -395,7 +403,7 @@ TEST_F(UrlAvailabilityRequesterTest, RemoveObserverUrls) { EXPECT_CALL(mock_observer1, OnReceiverUnavailable(url1_, service_id_)) .Times(0); EXPECT_CALL(mock_observer2, OnReceiverUnavailable(url1_, service_id_)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); Mock::VerifyAndClearExpectations(&mock_observer1); Mock::VerifyAndClearExpectations(&mock_observer2); } @@ -423,7 +431,7 @@ TEST_F(UrlAvailabilityRequesterTest, RemoveObserver) { EXPECT_CALL(mock_observer1, OnReceiverAvailable(url1_, service_id_)); EXPECT_CALL(mock_observer1, OnReceiverUnavailable(url1_, service_id_)) .Times(0); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); Mock::VerifyAndClearExpectations(&mock_observer1); MockReceiverObserver mock_observer2; @@ -431,7 +439,7 @@ TEST_F(UrlAvailabilityRequesterTest, RemoveObserver) { listener_.AddObserver({url1_, url2_}, &mock_observer2); ExpectStreamMessage(&mock_callback_, &request); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); Mock::VerifyAndClearExpectations(&mock_observer2); uint64_t url2_watch_id = request.watch_id; @@ -445,7 +453,7 @@ TEST_F(UrlAvailabilityRequesterTest, RemoveObserver) { EXPECT_CALL(mock_observer2, OnReceiverAvailable(_, service_id_)).Times(0); EXPECT_CALL(mock_observer2, OnReceiverUnavailable(url2_, service_id_)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); Mock::VerifyAndClearExpectations(&mock_observer1); Mock::VerifyAndClearExpectations(&mock_observer2); @@ -457,7 +465,7 @@ TEST_F(UrlAvailabilityRequesterTest, RemoveObserver) { EXPECT_CALL(mock_observer1, OnReceiverUnavailable(url1_, service_id_)) .Times(0); EXPECT_CALL(mock_observer2, OnReceiverUnavailable(url1_, service_id_)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); Mock::VerifyAndClearExpectations(&mock_observer1); Mock::VerifyAndClearExpectations(&mock_observer2); @@ -475,7 +483,7 @@ TEST_F(UrlAvailabilityRequesterTest, RemoveObserver) { EXPECT_CALL(mock_observer1, OnReceiverUnavailable(_, service_id_)).Times(0); EXPECT_CALL(mock_observer2, OnReceiverUnavailable(_, service_id_)).Times(0); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); Mock::VerifyAndClearExpectations(&mock_observer1); Mock::VerifyAndClearExpectations(&mock_observer2); } @@ -502,7 +510,7 @@ TEST_F(UrlAvailabilityRequesterTest, EventUpdate) { EXPECT_CALL(mock_observer1, OnReceiverAvailable(url1_, service_id_)); EXPECT_CALL(mock_observer1, OnReceiverAvailable(url2_, service_id_)); EXPECT_CALL(mock_observer1, OnReceiverUnavailable(_, service_id_)).Times(0); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); EXPECT_CALL(mock_callback_, OnStreamMessage(_, _, _, _, _, _)).Times(0); SendAvailabilityEvent( @@ -512,7 +520,7 @@ TEST_F(UrlAvailabilityRequesterTest, EventUpdate) { stream.get()); EXPECT_CALL(mock_observer1, OnReceiverUnavailable(url2_, service_id_)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); } TEST_F(UrlAvailabilityRequesterTest, RefreshWatches) { @@ -535,13 +543,13 @@ TEST_F(UrlAvailabilityRequesterTest, RefreshWatches) { EXPECT_CALL(mock_observer1, OnReceiverAvailable(url1_, service_id_)); EXPECT_CALL(mock_observer1, OnReceiverUnavailable(_, service_id_)).Times(0); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); fake_clock_.Advance(seconds(60)); ExpectStreamMessage(&mock_callback_, &request); listener_.RefreshWatches(); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); EXPECT_EQ(std::vector<std::string>{url1_}, request.urls); SendAvailabilityResponse( @@ -550,7 +558,7 @@ TEST_F(UrlAvailabilityRequesterTest, RefreshWatches) { stream.get()); EXPECT_CALL(mock_observer1, OnReceiverUnavailable(url1_, service_id_)); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); } TEST_F(UrlAvailabilityRequesterTest, ResponseAfterRemoveObserver) { @@ -575,7 +583,7 @@ TEST_F(UrlAvailabilityRequesterTest, ResponseAfterRemoveObserver) { EXPECT_CALL(mock_observer1, OnReceiverAvailable(url1_, service_id_)).Times(0); EXPECT_CALL(mock_observer1, OnReceiverUnavailable(url1_, service_id_)) .Times(0); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); MockReceiverObserver mock_observer2; EXPECT_CALL(mock_observer2, OnReceiverAvailable(url1_, service_id_)).Times(0); @@ -606,7 +614,7 @@ TEST_F(UrlAvailabilityRequesterTest, EXPECT_CALL(mock_observer1, OnReceiverAvailable(url1_, service_id_)); EXPECT_CALL(mock_observer1, OnReceiverUnavailable(url1_, service_id_)) .Times(0); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); listener_.RemoveObserverUrls({url1_}, &mock_observer1); listener_.RemoveReceiver(info1_); @@ -655,14 +663,14 @@ TEST_F(UrlAvailabilityRequesterTest, RemoveObserverInSteps) { // immediately, this still went out on the wire. ExpectStreamMessage(&mock_callback_, &request); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); EXPECT_EQ((std::vector<std::string>{url2_}), request.urls); fake_clock_.Advance(seconds(60)); listener_.RefreshWatches(); EXPECT_CALL(mock_callback_, OnStreamMessage(_, _, _, _, _, _)).Times(0); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); } } // namespace presentation diff --git a/chromium/third_party/openscreen/src/osp/impl/protocol_connection_client_factory.cc b/chromium/third_party/openscreen/src/osp/impl/protocol_connection_client_factory.cc index 11eafcced8d..63abc04bfc5 100644 --- a/chromium/third_party/openscreen/src/osp/impl/protocol_connection_client_factory.cc +++ b/chromium/third_party/openscreen/src/osp/impl/protocol_connection_client_factory.cc @@ -8,6 +8,8 @@ #include "osp/impl/quic/quic_client.h" #include "osp/impl/quic/quic_connection_factory_impl.h" +#include "osp/public/network_service_manager.h" +#include "platform/api/network_runner.h" namespace openscreen { @@ -15,9 +17,11 @@ namespace openscreen { std::unique_ptr<ProtocolConnectionClient> ProtocolConnectionClientFactory::Create( MessageDemuxer* demuxer, - ProtocolConnectionServiceObserver* observer) { + ProtocolConnectionServiceObserver* observer, + platform::NetworkRunner* network_runner) { return std::make_unique<QuicClient>( - demuxer, std::make_unique<QuicConnectionFactoryImpl>(), observer); + demuxer, std::make_unique<QuicConnectionFactoryImpl>(network_runner), + observer, network_runner); } } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/impl/protocol_connection_server_factory.cc b/chromium/third_party/openscreen/src/osp/impl/protocol_connection_server_factory.cc index f3bf17b2f80..7fb61ed7682 100644 --- a/chromium/third_party/openscreen/src/osp/impl/protocol_connection_server_factory.cc +++ b/chromium/third_party/openscreen/src/osp/impl/protocol_connection_server_factory.cc @@ -8,6 +8,8 @@ #include "osp/impl/quic/quic_connection_factory_impl.h" #include "osp/impl/quic/quic_server.h" +#include "osp/public/network_service_manager.h" +#include "platform/api/network_runner.h" namespace openscreen { @@ -16,9 +18,12 @@ std::unique_ptr<ProtocolConnectionServer> ProtocolConnectionServerFactory::Create( const ServerConfig& config, MessageDemuxer* demuxer, - ProtocolConnectionServer::Observer* observer) { + ProtocolConnectionServer::Observer* observer, + platform::NetworkRunner* network_runner) { return std::make_unique<QuicServer>( - config, demuxer, std::make_unique<QuicConnectionFactoryImpl>(), observer); + config, demuxer, + std::make_unique<QuicConnectionFactoryImpl>(network_runner), observer, + network_runner); } } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.cc index 34b5281f62b..ce193071a83 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.cc @@ -5,18 +5,28 @@ #include "osp/impl/quic/quic_client.h" #include <algorithm> +#include <functional> #include <memory> +#include "absl/types/optional.h" #include "platform/api/logging.h" +#include "platform/api/task_runner.h" +#include "platform/api/time.h" namespace openscreen { QuicClient::QuicClient( MessageDemuxer* demuxer, std::unique_ptr<QuicConnectionFactory> connection_factory, - ProtocolConnectionServiceObserver* observer) + ProtocolConnectionServiceObserver* observer, + platform::TaskRunner* task_runner) : ProtocolConnectionClient(demuxer, observer), - connection_factory_(std::move(connection_factory)) {} + connection_factory_(std::move(connection_factory)) { + if (task_runner != nullptr) { + platform::RepeatingFunction::Post(task_runner, + std::bind(&QuicClient::Cleanup, this)); + } +} QuicClient::~QuicClient() { CloseAllConnections(); @@ -39,8 +49,7 @@ bool QuicClient::Stop() { return true; } -void QuicClient::RunTasks() { - connection_factory_->RunTasks(); +absl::optional<platform::Clock::duration> QuicClient::Cleanup() { for (auto& entry : connections_) { entry.second.delegate->DestroyClosedStreams(); if (!entry.second.delegate->has_streams()) @@ -51,6 +60,12 @@ void QuicClient::RunTasks() { connections_.erase(entry); delete_connections_.clear(); + + constexpr platform::Clock::duration kQuicCleanupFrequency = + std::chrono::milliseconds(500); + return state_ == State::kStopped + ? absl::optional<platform::Clock::duration>(absl::nullopt) + : absl::optional<platform::Clock::duration>(kQuicCleanupFrequency); } QuicClient::ConnectRequest QuicClient::Connect( diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.h index 605f199d237..f819ae463d9 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client.h @@ -13,6 +13,8 @@ #include "osp/impl/quic/quic_connection_factory.h" #include "osp/impl/quic/quic_service_common.h" #include "osp/public/protocol_connection_client.h" +#include "platform/api/task_runner.h" +#include "platform/api/time.h" #include "platform/base/ip_address.h" namespace openscreen { @@ -38,13 +40,13 @@ class QuicClient final : public ProtocolConnectionClient, public: QuicClient(MessageDemuxer* demuxer, std::unique_ptr<QuicConnectionFactory> connection_factory, - ProtocolConnectionServiceObserver* observer); + ProtocolConnectionServiceObserver* observer, + platform::TaskRunner* task_runner); ~QuicClient() override; // ProtocolConnectionClient overrides. bool Start() override; bool Stop() override; - void RunTasks() override; ConnectRequest Connect(const IPEndpoint& endpoint, ConnectionRequestCallback* request) override; std::unique_ptr<ProtocolConnection> CreateProtocolConnection( @@ -90,6 +92,10 @@ class QuicClient final : public ProtocolConnectionClient, void CancelConnectRequest(uint64_t request_id) override; + // Deletes dead QUIC connections then returns the time interval before this + // method should be run again. + absl::optional<platform::Clock::duration> Cleanup(); + std::unique_ptr<QuicConnectionFactory> connection_factory_; // Maps an IPEndpoint to a generated endpoint ID. This is used to insulate diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client_unittest.cc index 9ee1938ff76..9a87c6572ae 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_client_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_client_unittest.cc @@ -17,6 +17,7 @@ #include "platform/api/logging.h" #include "platform/base/error.h" #include "platform/test/fake_clock.h" +#include "platform/test/fake_network_runner.h" namespace openscreen { namespace { @@ -56,12 +57,19 @@ class ConnectionCallback final }; class QuicClientTest : public ::testing::Test { + public: + QuicClientTest() { + network_runner_ = std::make_unique<platform::FakeNetworkRunner>(); + quic_bridge_ = std::make_unique<FakeQuicBridge>(network_runner_.get(), + platform::FakeClock::now); + } + protected: void SetUp() override { - client_ = quic_bridge_.quic_client.get(); + client_ = quic_bridge_->quic_client.get(); NetworkServiceManager::Create(nullptr, nullptr, - std::move(quic_bridge_.quic_client), - std::move(quic_bridge_.quic_server)); + std::move(quic_bridge_->quic_client), + std::move(quic_bridge_->quic_server)); } void TearDown() override { NetworkServiceManager::Dispose(); } @@ -69,7 +77,7 @@ class QuicClientTest : public ::testing::Test { void SendTestMessage(ProtocolConnection* connection) { MockMessageCallback mock_message_callback; MessageDemuxer::MessageWatch message_watch = - quic_bridge_.receiver_demuxer->WatchMessageType( + quic_bridge_->receiver_demuxer->WatchMessageType( 0, msgs::Type::kPresentationConnectionMessage, &mock_message_callback); @@ -99,7 +107,7 @@ class QuicClientTest : public ::testing::Test { return ErrorOr<size_t>(Error::Code::kCborParsing); return ErrorOr<size_t>(decode_result); })); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); ASSERT_GT(decode_result, 0); EXPECT_EQ(decode_result, static_cast<ssize_t>(buffer.size() - 1)); @@ -109,9 +117,10 @@ class QuicClientTest : public ::testing::Test { EXPECT_EQ(received_message.message.str, message.message.str); } + std::unique_ptr<platform::FakeNetworkRunner> network_runner_; platform::FakeClock fake_clock_{ platform::Clock::time_point(std::chrono::milliseconds(1298424))}; - FakeQuicBridge quic_bridge_{platform::FakeClock::now}; + std::unique_ptr<FakeQuicBridge> quic_bridge_; QuicClient* client_; }; @@ -123,10 +132,10 @@ TEST_F(QuicClientTest, Connect) { std::unique_ptr<ProtocolConnection> connection; ConnectionCallback connection_callback(&connection); ProtocolConnectionClient::ConnectRequest request = - client_->Connect(quic_bridge_.kReceiverEndpoint, &connection_callback); + client_->Connect(quic_bridge_->kReceiverEndpoint, &connection_callback); ASSERT_TRUE(request); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); ASSERT_TRUE(connection); SendTestMessage(connection.get()); @@ -140,17 +149,17 @@ TEST_F(QuicClientTest, DoubleConnect) { std::unique_ptr<ProtocolConnection> connection1; ConnectionCallback connection_callback1(&connection1); ProtocolConnectionClient::ConnectRequest request1 = - client_->Connect(quic_bridge_.kReceiverEndpoint, &connection_callback1); + client_->Connect(quic_bridge_->kReceiverEndpoint, &connection_callback1); ASSERT_TRUE(request1); ASSERT_FALSE(connection1); std::unique_ptr<ProtocolConnection> connection2; ConnectionCallback connection_callback2(&connection2); ProtocolConnectionClient::ConnectRequest request2 = - client_->Connect(quic_bridge_.kReceiverEndpoint, &connection_callback2); + client_->Connect(quic_bridge_->kReceiverEndpoint, &connection_callback2); ASSERT_TRUE(request2); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); ASSERT_TRUE(connection1); ASSERT_TRUE(connection2); @@ -170,13 +179,13 @@ TEST_F(QuicClientTest, OpenImmediate) { ConnectionCallback connection_callback(&connection1); ProtocolConnectionClient::ConnectRequest request = - client_->Connect(quic_bridge_.kReceiverEndpoint, &connection_callback); + client_->Connect(quic_bridge_->kReceiverEndpoint, &connection_callback); ASSERT_TRUE(request); connection2 = client_->CreateProtocolConnection(1); EXPECT_FALSE(connection2); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); ASSERT_TRUE(connection1); connection2 = client_->CreateProtocolConnection(connection1->endpoint_id()); @@ -192,20 +201,20 @@ TEST_F(QuicClientTest, States) { std::unique_ptr<ProtocolConnection> connection1; ConnectionCallback connection_callback(&connection1); ProtocolConnectionClient::ConnectRequest request = - client_->Connect(quic_bridge_.kReceiverEndpoint, &connection_callback); + client_->Connect(quic_bridge_->kReceiverEndpoint, &connection_callback); EXPECT_FALSE(request); std::unique_ptr<ProtocolConnection> connection2 = client_->CreateProtocolConnection(1); EXPECT_FALSE(connection2); - EXPECT_CALL(quic_bridge_.mock_client_observer, OnRunning()); + EXPECT_CALL(quic_bridge_->mock_client_observer, OnRunning()); EXPECT_TRUE(client_->Start()); EXPECT_FALSE(client_->Start()); request = - client_->Connect(quic_bridge_.kReceiverEndpoint, &connection_callback); + client_->Connect(quic_bridge_->kReceiverEndpoint, &connection_callback); ASSERT_TRUE(request); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); ASSERT_TRUE(connection1); MockConnectionObserver mock_connection_observer1; connection1->SetObserver(&mock_connection_observer1); @@ -217,12 +226,12 @@ TEST_F(QuicClientTest, States) { EXPECT_CALL(mock_connection_observer1, OnConnectionClosed(_)); EXPECT_CALL(mock_connection_observer2, OnConnectionClosed(_)); - EXPECT_CALL(quic_bridge_.mock_client_observer, OnStopped()); + EXPECT_CALL(quic_bridge_->mock_client_observer, OnStopped()); EXPECT_TRUE(client_->Stop()); EXPECT_FALSE(client_->Stop()); request = - client_->Connect(quic_bridge_.kReceiverEndpoint, &connection_callback); + client_->Connect(quic_bridge_->kReceiverEndpoint, &connection_callback); EXPECT_FALSE(request); connection2 = client_->CreateProtocolConnection(1); EXPECT_FALSE(connection2); @@ -231,17 +240,17 @@ TEST_F(QuicClientTest, States) { TEST_F(QuicClientTest, RequestIds) { client_->Start(); - EXPECT_CALL(quic_bridge_.mock_server_observer, OnIncomingConnectionMock(_)) + EXPECT_CALL(quic_bridge_->mock_server_observer, OnIncomingConnectionMock(_)) .WillOnce(Invoke([](std::unique_ptr<ProtocolConnection>& connection) { connection->CloseWriteEnd(); })); std::unique_ptr<ProtocolConnection> connection; ConnectionCallback connection_callback(&connection); ProtocolConnectionClient::ConnectRequest request = - client_->Connect(quic_bridge_.kReceiverEndpoint, &connection_callback); + client_->Connect(quic_bridge_->kReceiverEndpoint, &connection_callback); ASSERT_TRUE(request); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); ASSERT_TRUE(connection); const uint64_t endpoint_id = connection->endpoint_id(); @@ -250,8 +259,8 @@ TEST_F(QuicClientTest, RequestIds) { connection->CloseWriteEnd(); connection.reset(); - quic_bridge_.RunTasksUntilIdle(); - EXPECT_EQ(0u, client_->endpoint_request_ids()->GetNextRequestId(endpoint_id)); + quic_bridge_->RunTasksUntilIdle(); + EXPECT_EQ(4u, client_->endpoint_request_ids()->GetNextRequestId(endpoint_id)); client_->Stop(); EXPECT_EQ(0u, client_->endpoint_request_ids()->GetNextRequestId(endpoint_id)); diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection.h index 069b06af7d3..a75fc815177 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection.h @@ -9,7 +9,6 @@ #include <vector> #include "platform/api/udp_socket.h" -#include "platform/impl/event_loop.h" namespace openscreen { @@ -37,7 +36,7 @@ class QuicStream { uint64_t id_; }; -class QuicConnection { +class QuicConnection : public platform::UdpReadCallback { public: class Delegate { public: @@ -68,11 +67,6 @@ class QuicConnection { explicit QuicConnection(Delegate* delegate) : delegate_(delegate) {} virtual ~QuicConnection() = default; - // Passes a received UDP packet to the QUIC implementation. If this contains - // any stream data, it will be passed automatically to the relevant - // QuicStream::Delegate objects. - virtual void OnDataReceived(const platform::UdpPacket& packet) = 0; - virtual std::unique_ptr<QuicStream> MakeOutgoingStream( QuicStream::Delegate* delegate) = 0; virtual void Close() = 0; diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory.h index 5668f9dc5a2..5d001892071 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory.h @@ -8,14 +8,17 @@ #include <memory> #include <vector> +#include "absl/types/optional.h" #include "osp/impl/quic/quic_connection.h" +#include "platform/api/time.h" #include "platform/base/ip_address.h" namespace openscreen { // This interface provides a way to make new QUIC connections to endpoints. It // also provides a way to receive incoming QUIC connections (as a server). -class QuicConnectionFactory { +class QuicConnectionFactory : public platform::UdpReadCallback, + public platform::UdpSocket::Client { public: class ServerDelegate { public: @@ -34,10 +37,6 @@ class QuicConnectionFactory { virtual void SetServerDelegate(ServerDelegate* delegate, const std::vector<IPEndpoint>& endpoints) = 0; - // Listen for incoming network packets on both client and server sockets and - // dispatch any results. - virtual void RunTasks() = 0; - virtual std::unique_ptr<QuicConnection> Connect( const IPEndpoint& endpoint, QuicConnection::Delegate* connection_delegate) = 0; diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.cc index b176d24656b..8e455e707d5 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.cc @@ -9,24 +9,20 @@ #include "osp/impl/quic/quic_connection_impl.h" #include "platform/api/logging.h" +#include "platform/api/network_runner.h" +#include "platform/api/task_runner.h" +#include "platform/api/time.h" +#include "platform/api/trace_logging.h" #include "platform/base/error.h" -#include "platform/impl/event_loop.h" #include "third_party/chromium_quic/src/base/location.h" #include "third_party/chromium_quic/src/base/task_runner.h" #include "third_party/chromium_quic/src/net/third_party/quic/core/quic_constants.h" #include "third_party/chromium_quic/src/net/third_party/quic/platform/impl/quic_chromium_clock.h" namespace openscreen { - -struct Task { - ::base::Location whence; - ::base::OnceClosure task; - ::base::TimeDelta delay; -}; - class QuicTaskRunner final : public ::base::TaskRunner { public: - QuicTaskRunner(); + explicit QuicTaskRunner(platform::TaskRunner* task_runner); ~QuicTaskRunner() override; void RunTasks(); @@ -39,35 +35,23 @@ class QuicTaskRunner final : public ::base::TaskRunner { bool RunsTasksInCurrentSequence() const override; private: - uint64_t last_run_unix_; - std::list<Task> tasks_; + platform::TaskRunner* const task_runner_; }; -QuicTaskRunner::QuicTaskRunner() = default; +QuicTaskRunner::QuicTaskRunner(platform::TaskRunner* task_runner) + : task_runner_(task_runner) {} QuicTaskRunner::~QuicTaskRunner() = default; -void QuicTaskRunner::RunTasks() { - auto* clock = ::quic::QuicChromiumClock::GetInstance(); - ::quic::QuicWallTime now = clock->WallNow(); - uint64_t now_unix = now.ToUNIXMicroseconds(); - for (auto it = tasks_.begin(); it != tasks_.end();) { - Task& next_task = *it; - next_task.delay -= - ::base::TimeDelta::FromMicroseconds(now_unix - last_run_unix_); - if (next_task.delay.InMicroseconds() < 0) { - std::move(next_task.task).Run(); - it = tasks_.erase(it); - } else { - ++it; - } - } - last_run_unix_ = now_unix; -} +void QuicTaskRunner::RunTasks() {} bool QuicTaskRunner::PostDelayedTask(const ::base::Location& whence, ::base::OnceClosure task, ::base::TimeDelta delay) { - tasks_.push_back({whence, std::move(task), delay}); + platform::Clock::duration wait = + platform::Clock::duration(delay.InMilliseconds()); + task_runner_->PostTaskWithDelay( + [closure = std::move(task)]() mutable { std::move(closure).Run(); }, + wait); return true; } @@ -75,20 +59,20 @@ bool QuicTaskRunner::RunsTasksInCurrentSequence() const { return true; } -QuicConnectionFactoryImpl::QuicConnectionFactoryImpl() { - task_runner_ = ::base::MakeRefCounted<QuicTaskRunner>(); +QuicConnectionFactoryImpl::QuicConnectionFactoryImpl( + platform::NetworkRunner* network_runner) + : network_runner_(network_runner) { + task_runner_ = ::base::MakeRefCounted<QuicTaskRunner>(network_runner); alarm_factory_ = std::make_unique<::net::QuicChromiumAlarmFactory>( task_runner_.get(), ::quic::QuicChromiumClock::GetInstance()); ::quic::QuartcFactoryConfig factory_config; factory_config.alarm_factory = alarm_factory_.get(); factory_config.clock = ::quic::QuicChromiumClock::GetInstance(); quartc_factory_ = std::make_unique<::quic::QuartcFactory>(factory_config); - waiter_ = platform::CreateEventWaiter(); } QuicConnectionFactoryImpl::~QuicConnectionFactoryImpl() { OSP_DCHECK(connections_.empty()); - platform::DestroyEventWaiter(waiter_); } void QuicConnectionFactoryImpl::SetServerDelegate( @@ -105,68 +89,70 @@ void QuicConnectionFactoryImpl::SetServerDelegate( // partial progress (i.e. "unwatch" all the sockets and call // sockets_.clear() to close the sockets)? auto create_result = - platform::UdpSocket::Create(endpoint.address.version()); + platform::UdpSocket::Create(network_runner_, this, endpoint); if (!create_result) { OSP_LOG_ERROR << "failed to create socket (for " << endpoint << "): " << create_result.error().message(); continue; } platform::UdpSocketUniquePtr server_socket = create_result.MoveValue(); - Error bind_result = server_socket->Bind(endpoint); + Error bind_result = server_socket->Bind(); if (!bind_result.ok()) { OSP_LOG_ERROR << "failed to bind socket (for " << endpoint << "): " << bind_result.message(); continue; } - platform::WatchUdpSocketReadable(waiter_, server_socket.get()); + network_runner_->ReadRepeatedly(server_socket.get(), this); sockets_.emplace_back(std::move(server_socket)); } } -void QuicConnectionFactoryImpl::RunTasks() { - for (const auto& packet : platform::OnePlatformLoopIteration(waiter_)) { - // Ensure that |packet.socket| is one of the instances owned by - // QuicConnectionFactoryImpl. - OSP_DCHECK(std::find_if(sockets_.begin(), sockets_.end(), - [&packet](const platform::UdpSocketUniquePtr& s) { - return s.get() == packet.socket(); - }) != sockets_.end()); - - // TODO(btolsch): We will need to rethink this both for ICE and connection - // migration support. - auto conn_it = connections_.find(packet.source()); - if (conn_it == connections_.end()) { - if (server_delegate_) { - OSP_VLOG << __func__ << ": spawning connection from " - << packet.source(); - auto transport = - std::make_unique<UdpTransport>(packet.socket(), packet.source()); - ::quic::QuartcSessionConfig session_config; - session_config.perspective = ::quic::Perspective::IS_SERVER; - session_config.packet_transport = transport.get(); - - auto result = std::make_unique<QuicConnectionImpl>( - this, server_delegate_->NextConnectionDelegate(packet.source()), - std::move(transport), - quartc_factory_->CreateQuartcSession(session_config)); - auto* result_ptr = result.get(); - connections_.emplace(packet.source(), - OpenConnection{result_ptr, packet.socket()}); - server_delegate_->OnIncomingConnection(std::move(result)); - result_ptr->OnDataReceived(packet); - } - } else { - OSP_VLOG << __func__ << ": data for existing connection from " - << packet.source(); - conn_it->second.connection->OnDataReceived(packet); +void QuicConnectionFactoryImpl::OnRead( + platform::UdpPacket packet, + platform::NetworkRunner* network_runner) { + TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionFactoryImpl::OnRead"); + // Ensure that |packet.socket| is one of the instances owned by + // QuicConnectionFactoryImpl. + auto packet_ptr = &packet; + OSP_DCHECK(std::find_if(sockets_.begin(), sockets_.end(), + [packet_ptr](const platform::UdpSocketUniquePtr& s) { + return s.get() == packet_ptr->socket(); + }) != sockets_.end()); + + // TODO(btolsch): We will need to rethink this both for ICE and connection + // migration support. + auto conn_it = connections_.find(packet.source()); + if (conn_it == connections_.end()) { + if (server_delegate_) { + OSP_VLOG << __func__ << ": spawning connection from " << packet.source(); + auto transport = + std::make_unique<UdpTransport>(packet.socket(), packet.source()); + ::quic::QuartcSessionConfig session_config; + session_config.perspective = ::quic::Perspective::IS_SERVER; + session_config.packet_transport = transport.get(); + + auto result = std::make_unique<QuicConnectionImpl>( + this, server_delegate_->NextConnectionDelegate(packet.source()), + std::move(transport), + quartc_factory_->CreateQuartcSession(session_config)); + auto* result_ptr = result.get(); + connections_.emplace(packet.source(), + OpenConnection{result_ptr, packet.socket()}); + server_delegate_->OnIncomingConnection(std::move(result)); + result_ptr->OnRead(std::move(packet), network_runner); } + } else { + OSP_VLOG << __func__ << ": data for existing connection from " + << packet.source(); + conn_it->second.connection->OnRead(std::move(packet), network_runner); } } std::unique_ptr<QuicConnection> QuicConnectionFactoryImpl::Connect( const IPEndpoint& endpoint, QuicConnection::Delegate* connection_delegate) { - auto create_result = platform::UdpSocket::Create(endpoint.address.version()); + auto create_result = + platform::UdpSocket::Create(network_runner_, this, endpoint); if (!create_result) { OSP_LOG_ERROR << "failed to create socket: " << create_result.error().message(); @@ -187,7 +173,7 @@ std::unique_ptr<QuicConnection> QuicConnectionFactoryImpl::Connect( this, connection_delegate, std::move(transport), quartc_factory_->CreateQuartcSession(session_config)); - platform::WatchUdpSocketReadable(waiter_, socket.get()); + network_runner_->ReadRepeatedly(socket.get(), this); // TODO(btolsch): This presents a problem for multihomed receivers, which may // register as a different endpoint in their response. I think QUIC is @@ -215,7 +201,7 @@ void QuicConnectionFactoryImpl::OnConnectionClosed(QuicConnection* connection) { [socket](const decltype(connections_)::value_type& entry) { return entry.second.socket == socket; }) == connections_.end()) { - platform::StopWatchingUdpSocketReadable(waiter_, socket); + network_runner_->CancelRead(socket); auto socket_it = std::find_if(sockets_.begin(), sockets_.end(), [socket](const platform::UdpSocketUniquePtr& s) { @@ -226,4 +212,19 @@ void QuicConnectionFactoryImpl::OnConnectionClosed(QuicConnection* connection) { } } +void QuicConnectionFactoryImpl::OnError(platform::UdpSocket* socket, + Error error) { + OSP_UNIMPLEMENTED(); +} + +void QuicConnectionFactoryImpl::OnSendError(platform::UdpSocket* socket, + Error error) { + OSP_UNIMPLEMENTED(); +} + +void QuicConnectionFactoryImpl::OnRead(platform::UdpSocket* socket, + ErrorOr<platform::UdpPacket> packet) { + OSP_UNIMPLEMENTED(); +} + } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h index 8771eacd42a..4bc62452ffa 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_factory_impl.h @@ -9,7 +9,6 @@ #include <memory> #include "osp/impl/quic/quic_connection_factory.h" -#include "platform/api/event_waiter.h" #include "platform/api/udp_socket.h" #include "platform/base/ip_address.h" #include "third_party/chromium_quic/src/base/at_exit.h" @@ -22,13 +21,22 @@ class QuicTaskRunner; class QuicConnectionFactoryImpl final : public QuicConnectionFactory { public: - QuicConnectionFactoryImpl(); + QuicConnectionFactoryImpl(platform::NetworkRunner* network_runner); ~QuicConnectionFactoryImpl() override; + // UdpReadCallback overrides. + void OnRead(platform::UdpPacket data, + platform::NetworkRunner* network_runner) override; + + // UdpSocket::Client overrides. + void OnError(platform::UdpSocket* socket, Error error) override; + void OnSendError(platform::UdpSocket* socket, Error error) override; + void OnRead(platform::UdpSocket* socket, + ErrorOr<platform::UdpPacket> packet) override; + // QuicConnectionFactory overrides. void SetServerDelegate(ServerDelegate* delegate, const std::vector<IPEndpoint>& endpoints) override; - void RunTasks() override; std::unique_ptr<QuicConnection> Connect( const IPEndpoint& endpoint, QuicConnection::Delegate* connection_delegate) override; @@ -45,13 +53,17 @@ class QuicConnectionFactoryImpl final : public QuicConnectionFactory { std::vector<platform::UdpSocketUniquePtr> sockets_; - platform::EventWaiterPtr waiter_; - struct OpenConnection { QuicConnection* connection; platform::UdpSocket* socket; // References one of the owned |sockets_|. }; std::map<IPEndpoint, OpenConnection, IPEndpointComparator> connections_; + + // Network runner to use for network operations. + // NOTE: Must be provided in constructor and stored as an instance variable + // rather than using the static accessor method to allow for UTs to mock this + // layer. + platform::NetworkRunner* const network_runner_; }; } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.cc index f07c5bf9728..36f42e27d80 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.cc @@ -10,6 +10,7 @@ #include "absl/types/optional.h" #include "osp/impl/quic/quic_connection_factory_impl.h" #include "platform/api/logging.h" +#include "platform/api/trace_logging.h" #include "platform/base/error.h" #include "third_party/chromium_quic/src/net/third_party/quic/platform/impl/quic_chromium_clock.h" @@ -29,6 +30,7 @@ UdpTransport& UdpTransport::operator=(UdpTransport&&) noexcept = default; int UdpTransport::Write(const char* buffer, size_t buffer_length, const PacketInfo& info) { + TRACE_SCOPED(TraceCategory::Quic, "UdpTransport::Write"); switch (socket_->SendMessage(buffer, buffer_length, destination_).code()) { case Error::Code::kNone: OSP_DCHECK_LE(buffer_length, @@ -50,6 +52,7 @@ QuicStreamImpl::QuicStreamImpl(QuicStream::Delegate* delegate, QuicStreamImpl::~QuicStreamImpl() = default; void QuicStreamImpl::Write(const uint8_t* data, size_t data_size) { + TRACE_SCOPED(TraceCategory::Quic, "QuicStreamImpl::Write"); OSP_DCHECK(!stream_->write_side_closed()); stream_->WriteOrBufferData( ::quic::QuicStringPiece(reinterpret_cast<const char*>(data), data_size), @@ -57,6 +60,7 @@ void QuicStreamImpl::Write(const uint8_t* data, size_t data_size) { } void QuicStreamImpl::CloseWriteEnd() { + TRACE_SCOPED(TraceCategory::Quic, "QuicStreamImpl::CloseWriteEnd"); if (!stream_->write_side_closed()) stream_->FinishWriting(); } @@ -64,15 +68,27 @@ void QuicStreamImpl::CloseWriteEnd() { void QuicStreamImpl::OnReceived(::quic::QuartcStream* stream, const char* data, size_t data_size) { + TRACE_SCOPED(TraceCategory::Quic, "QuicStreamImpl::OnReceived"); delegate_->OnReceived(this, data, data_size); } void QuicStreamImpl::OnClose(::quic::QuartcStream* stream) { + TRACE_SCOPED(TraceCategory::Quic, "QuicStreamImpl::OnClose"); delegate_->OnClose(stream->id()); } void QuicStreamImpl::OnBufferChanged(::quic::QuartcStream* stream) {} +// Passes a received UDP packet to the QUIC implementation. If this contains +// any stream data, it will be passed automatically to the relevant +// QuicStream::Delegate objects. +void QuicConnectionImpl::OnRead(platform::UdpPacket data, + platform::NetworkRunner* network_runner) { + TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::OnRead"); + session_->OnTransportReceived(reinterpret_cast<const char*>(data.data()), + data.size()); +} + QuicConnectionImpl::QuicConnectionImpl( QuicConnectionFactoryImpl* parent_factory, QuicConnection::Delegate* delegate, @@ -82,6 +98,7 @@ QuicConnectionImpl::QuicConnectionImpl( parent_factory_(parent_factory), session_(std::move(session)), udp_transport_(std::move(udp_transport)) { + TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::QuicConnectionImpl"); session_->SetDelegate(this); session_->OnTransportCanWrite(); session_->StartCryptoHandshake(); @@ -89,26 +106,26 @@ QuicConnectionImpl::QuicConnectionImpl( QuicConnectionImpl::~QuicConnectionImpl() = default; -void QuicConnectionImpl::OnDataReceived(const platform::UdpPacket& packet) { - session_->OnTransportReceived(reinterpret_cast<const char*>(packet.data()), - packet.size()); -} - std::unique_ptr<QuicStream> QuicConnectionImpl::MakeOutgoingStream( QuicStream::Delegate* delegate) { + TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::MakeOutgoingStream"); ::quic::QuartcStream* stream = session_->CreateOutgoingDynamicStream(); return std::make_unique<QuicStreamImpl>(delegate, stream); } void QuicConnectionImpl::Close() { + TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::Close"); session_->CloseConnection("closed"); } void QuicConnectionImpl::OnCryptoHandshakeComplete() { + TRACE_SCOPED(TraceCategory::Quic, + "QuicConnectionImpl::OnCryptoHandshakeComplete"); delegate_->OnCryptoHandshakeComplete(session_->connection_id()); } void QuicConnectionImpl::OnIncomingStream(::quic::QuartcStream* stream) { + TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::OnIncomingStream"); auto public_stream = std::make_unique<QuicStreamImpl>( delegate_->NextStreamDelegate(session_->connection_id(), stream->id()), stream); @@ -121,6 +138,7 @@ void QuicConnectionImpl::OnConnectionClosed( ::quic::QuicErrorCode error_code, const ::quic::QuicString& error_details, ::quic::ConnectionCloseSource source) { + TRACE_SCOPED(TraceCategory::Quic, "QuicConnectionImpl::OnConnectionClosed"); parent_factory_->OnConnectionClosed(this); delegate_->OnConnectionClosed(session_->connection_id()); } diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.h index 2288fafdbdd..479b3f0e0c7 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_connection_impl.h @@ -74,8 +74,11 @@ class QuicConnectionImpl final : public QuicConnection, ~QuicConnectionImpl() override; + // UdpReadCallback overrides. + void OnRead(platform::UdpPacket data, + platform::NetworkRunner* network_runner) override; + // QuicConnection overrides. - void OnDataReceived(const platform::UdpPacket& packet) override; std::unique_ptr<QuicStream> MakeOutgoingStream( QuicStream::Delegate* delegate) override; void Close() override; diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.cc index 3f94a136fce..7aa00a9df7c 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.cc @@ -4,9 +4,13 @@ #include "osp/impl/quic/quic_server.h" +#include <functional> #include <memory> +#include "absl/types/optional.h" #include "platform/api/logging.h" +#include "platform/api/task_runner.h" +#include "platform/api/time.h" namespace openscreen { @@ -14,10 +18,16 @@ QuicServer::QuicServer( const ServerConfig& config, MessageDemuxer* demuxer, std::unique_ptr<QuicConnectionFactory> connection_factory, - ProtocolConnectionServer::Observer* observer) + ProtocolConnectionServer::Observer* observer, + platform::TaskRunner* task_runner) : ProtocolConnectionServer(demuxer, observer), connection_endpoints_(config.connection_endpoints), - connection_factory_(std::move(connection_factory)) {} + connection_factory_(std::move(connection_factory)) { + if (task_runner != nullptr) { + platform::RepeatingFunction::Post(task_runner, + std::bind(&QuicServer::Cleanup, this)); + } +} QuicServer::~QuicServer() { CloseAllConnections(); @@ -59,9 +69,7 @@ bool QuicServer::Resume() { return true; } -void QuicServer::RunTasks() { - if (state_ == State::kRunning) - connection_factory_->RunTasks(); +absl::optional<platform::Clock::duration> QuicServer::Cleanup() { for (auto& entry : connections_) entry.second.delegate->DestroyClosedStreams(); @@ -69,15 +77,23 @@ void QuicServer::RunTasks() { connections_.erase(entry); delete_connections_.clear(); + + constexpr platform::Clock::duration kQuicCleanupFrequency = + std::chrono::milliseconds(500); + return state_ == State::kStopped + ? absl::optional<platform::Clock::duration>(absl::nullopt) + : absl::optional<platform::Clock::duration>(kQuicCleanupFrequency); } std::unique_ptr<ProtocolConnection> QuicServer::CreateProtocolConnection( uint64_t endpoint_id) { - if (state_ != State::kRunning) + if (state_ != State::kRunning) { return nullptr; + } auto connection_entry = connections_.find(endpoint_id); - if (connection_entry == connections_.end()) + if (connection_entry == connections_.end()) { return nullptr; + } return QuicProtocolConnection::FromExisting( this, connection_entry->second.connection.get(), connection_entry->second.delegate.get(), endpoint_id); diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.h b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.h index 289ccbf1cb3..964a0ff2961 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server.h @@ -12,6 +12,8 @@ #include "osp/impl/quic/quic_connection_factory.h" #include "osp/impl/quic/quic_service_common.h" #include "osp/public/protocol_connection_server.h" +#include "platform/api/task_runner.h" +#include "platform/api/time.h" #include "platform/base/ip_address.h" namespace openscreen { @@ -32,7 +34,8 @@ class QuicServer final : public ProtocolConnectionServer, QuicServer(const ServerConfig& config, MessageDemuxer* demuxer, std::unique_ptr<QuicConnectionFactory> connection_factory, - ProtocolConnectionServer::Observer* observer); + ProtocolConnectionServer::Observer* observer, + platform::TaskRunner* task_runner); ~QuicServer() override; // ProtocolConnectionServer overrides. @@ -40,7 +43,6 @@ class QuicServer final : public ProtocolConnectionServer, bool Stop() override; bool Suspend() override; bool Resume() override; - void RunTasks() override; std::unique_ptr<ProtocolConnection> CreateProtocolConnection( uint64_t endpoint_id) override; @@ -68,6 +70,10 @@ class QuicServer final : public ProtocolConnectionServer, void OnIncomingConnection( std::unique_ptr<QuicConnection> connection) override; + // Deletes dead QUIC connections then returns the time interval before this + // method should be run again. + absl::optional<platform::Clock::duration> Cleanup(); + const std::vector<IPEndpoint> connection_endpoints_; std::unique_ptr<QuicConnectionFactory> connection_factory_; diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server_unittest.cc b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server_unittest.cc index 7442f44b0f9..d433e8d463c 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/quic_server_unittest.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/quic_server_unittest.cc @@ -15,6 +15,7 @@ #include "osp/public/testing/message_demuxer_test_support.h" #include "platform/base/error.h" #include "platform/test/fake_clock.h" +#include "platform/test/fake_network_runner.h" namespace openscreen { namespace { @@ -45,27 +46,34 @@ class MockConnectionObserver final : public ProtocolConnection::Observer { }; class QuicServerTest : public Test { + public: + QuicServerTest() { + network_runner_ = std::make_unique<platform::FakeNetworkRunner>(); + quic_bridge_ = std::make_unique<FakeQuicBridge>(network_runner_.get(), + platform::FakeClock::now); + } + protected: std::unique_ptr<ProtocolConnection> ExpectIncomingConnection() { MockConnectRequest mock_connect_request; NetworkServiceManager::Get()->GetProtocolConnectionClient()->Connect( - quic_bridge_.kReceiverEndpoint, &mock_connect_request); + quic_bridge_->kReceiverEndpoint, &mock_connect_request); std::unique_ptr<ProtocolConnection> stream; EXPECT_CALL(mock_connect_request, OnConnectionOpenedMock()); - EXPECT_CALL(quic_bridge_.mock_server_observer, OnIncomingConnectionMock(_)) + EXPECT_CALL(quic_bridge_->mock_server_observer, OnIncomingConnectionMock(_)) .WillOnce( Invoke([&stream](std::unique_ptr<ProtocolConnection>& connection) { stream = std::move(connection); })); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); return stream; } void SetUp() override { - server_ = quic_bridge_.quic_server.get(); + server_ = quic_bridge_->quic_server.get(); NetworkServiceManager::Create(nullptr, nullptr, - std::move(quic_bridge_.quic_client), - std::move(quic_bridge_.quic_server)); + std::move(quic_bridge_->quic_client), + std::move(quic_bridge_->quic_server)); } void TearDown() override { NetworkServiceManager::Dispose(); } @@ -73,7 +81,7 @@ class QuicServerTest : public Test { void SendTestMessage(ProtocolConnection* connection) { MockMessageCallback mock_message_callback; MessageDemuxer::MessageWatch message_watch = - quic_bridge_.controller_demuxer->WatchMessageType( + quic_bridge_->controller_demuxer->WatchMessageType( 0, msgs::Type::kPresentationConnectionMessage, &mock_message_callback); @@ -102,7 +110,7 @@ class QuicServerTest : public Test { return ErrorOr<size_t>(Error::Code::kCborParsing); return ErrorOr<size_t>(decode_result); })); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); ASSERT_GT(decode_result, 0); EXPECT_EQ(decode_result, static_cast<ssize_t>(buffer.size() - 1)); @@ -112,9 +120,10 @@ class QuicServerTest : public Test { EXPECT_EQ(received_message.message.str, message.message.str); } + std::unique_ptr<platform::FakeNetworkRunner> network_runner_; platform::FakeClock fake_clock_{ platform::Clock::time_point(std::chrono::milliseconds(1298424))}; - FakeQuicBridge quic_bridge_{platform::FakeClock::now}; + std::unique_ptr<FakeQuicBridge> quic_bridge_; QuicServer* server_; }; @@ -145,7 +154,7 @@ TEST_F(QuicServerTest, OpenImmediate) { TEST_F(QuicServerTest, States) { server_->Stop(); - EXPECT_CALL(quic_bridge_.mock_server_observer, OnRunning()); + EXPECT_CALL(quic_bridge_->mock_server_observer, OnRunning()); EXPECT_TRUE(server_->Start()); EXPECT_FALSE(server_->Start()); @@ -155,27 +164,27 @@ TEST_F(QuicServerTest, States) { connection->SetObserver(&mock_connection_observer); EXPECT_CALL(mock_connection_observer, OnConnectionClosed(_)); - EXPECT_CALL(quic_bridge_.mock_server_observer, OnStopped()); + EXPECT_CALL(quic_bridge_->mock_server_observer, OnStopped()); EXPECT_TRUE(server_->Stop()); EXPECT_FALSE(server_->Stop()); - EXPECT_CALL(quic_bridge_.mock_server_observer, OnRunning()); + EXPECT_CALL(quic_bridge_->mock_server_observer, OnRunning()); EXPECT_TRUE(server_->Start()); - EXPECT_CALL(quic_bridge_.mock_server_observer, OnSuspended()); + EXPECT_CALL(quic_bridge_->mock_server_observer, OnSuspended()); EXPECT_TRUE(server_->Suspend()); EXPECT_FALSE(server_->Suspend()); EXPECT_FALSE(server_->Start()); - EXPECT_CALL(quic_bridge_.mock_server_observer, OnRunning()); + EXPECT_CALL(quic_bridge_->mock_server_observer, OnRunning()); EXPECT_TRUE(server_->Resume()); EXPECT_FALSE(server_->Resume()); EXPECT_FALSE(server_->Start()); - EXPECT_CALL(quic_bridge_.mock_server_observer, OnSuspended()); + EXPECT_CALL(quic_bridge_->mock_server_observer, OnSuspended()); EXPECT_TRUE(server_->Suspend()); - EXPECT_CALL(quic_bridge_.mock_server_observer, OnStopped()); + EXPECT_CALL(quic_bridge_->mock_server_observer, OnStopped()); EXPECT_TRUE(server_->Stop()); } @@ -189,7 +198,7 @@ TEST_F(QuicServerTest, RequestIds) { connection->CloseWriteEnd(); connection.reset(); - quic_bridge_.RunTasksUntilIdle(); + quic_bridge_->RunTasksUntilIdle(); EXPECT_EQ(5u, server_->endpoint_request_ids()->GetNextRequestId(endpoint_id)); server_->Stop(); diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.cc b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.cc index 06b827954eb..ec206e85532 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.cc @@ -60,7 +60,8 @@ std::unique_ptr<FakeQuicStream> FakeQuicConnection::MakeIncomingStream() { return result; } -void FakeQuicConnection::OnDataReceived(const platform::UdpPacket& packet) { +void FakeQuicConnection::OnRead(platform::UdpPacket data, + platform::NetworkRunner* network_runner) { OSP_DCHECK(false) << "data should go directly to fake streams"; } diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.h b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.h index 90833c48044..0c154527c50 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection.h @@ -57,7 +57,8 @@ class FakeQuicConnection final : public QuicConnection { std::unique_ptr<FakeQuicStream> MakeIncomingStream(); // QuicConnection overrides. - void OnDataReceived(const platform::UdpPacket& packet) override; + void OnRead(platform::UdpPacket data, + platform::NetworkRunner* network_runner) override; std::unique_ptr<QuicStream> MakeOutgoingStream( QuicStream::Delegate* delegate) override; void Close() override; diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.cc b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.cc index abb3e13eef7..24b501fc127 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.cc @@ -51,13 +51,16 @@ void FakeQuicConnectionFactoryBridge::SetServerDelegate( receiver_endpoint_ = endpoint; } -void FakeQuicConnectionFactoryBridge::RunTasks() { - idle_ = true; - if (!connections_.controller || !connections_.receiver) +void FakeQuicConnectionFactoryBridge::RunTasks(bool is_client) { + bool* idle_flag = is_client ? &client_idle_ : &server_idle_; + *idle_flag = true; + + if (!connections_.controller || !connections_.receiver) { return; + } if (connections_pending_) { - idle_ = false; + *idle_flag = false; connections_.receiver->delegate()->OnCryptoHandshakeComplete( connections_.receiver->id()); connections_.controller->delegate()->OnCryptoHandshakeComplete( @@ -80,9 +83,8 @@ void FakeQuicConnectionFactoryBridge::RunTasks() { std::vector<uint8_t> written_data = controller_stream->TakeWrittenData(); OSP_DCHECK(controller_stream->TakeReceivedData().empty()); - // TODO(jophba): Move to a task runner here. if (!written_data.empty()) { - idle_ = false; + *idle_flag = false; receiver_stream->delegate()->OnReceived( receiver_stream, reinterpret_cast<const char*>(written_data.data()), written_data.size()); @@ -92,7 +94,7 @@ void FakeQuicConnectionFactoryBridge::RunTasks() { OSP_DCHECK(receiver_stream->TakeReceivedData().empty()); if (written_data.size()) { - idle_ = false; + *idle_flag = false; controller_stream->delegate()->OnReceived( controller_stream, reinterpret_cast<const char*>(written_data.data()), written_data.size()); @@ -158,8 +160,11 @@ void FakeClientQuicConnectionFactory::SetServerDelegate( OSP_DCHECK(false) << "don't call SetServerDelegate from QuicClient side"; } -void FakeClientQuicConnectionFactory::RunTasks() { - bridge_->RunTasks(); +void FakeClientQuicConnectionFactory::OnRead( + platform::UdpPacket data, + platform::NetworkRunner* network_runner) { + bridge_->RunTasks(true); + idle_ = bridge_->client_idle(); } std::unique_ptr<QuicConnection> FakeClientQuicConnectionFactory::Connect( @@ -168,6 +173,22 @@ std::unique_ptr<QuicConnection> FakeClientQuicConnectionFactory::Connect( return bridge_->Connect(endpoint, connection_delegate); } +void FakeClientQuicConnectionFactory::OnError(platform::UdpSocket* socket, + Error error) { + OSP_UNIMPLEMENTED(); +} + +void FakeClientQuicConnectionFactory::OnSendError(platform::UdpSocket* socket, + Error error) { + OSP_UNIMPLEMENTED(); +} + +void FakeClientQuicConnectionFactory::OnRead( + platform::UdpSocket* socket, + ErrorOr<platform::UdpPacket> packet) { + OSP_UNIMPLEMENTED(); +} + FakeServerQuicConnectionFactory::FakeServerQuicConnectionFactory( FakeQuicConnectionFactoryBridge* bridge) : bridge_(bridge) {} @@ -184,8 +205,11 @@ void FakeServerQuicConnectionFactory::SetServerDelegate( endpoints.empty() ? IPEndpoint{} : endpoints[0]); } -void FakeServerQuicConnectionFactory::RunTasks() { - bridge_->RunTasks(); +void FakeServerQuicConnectionFactory::OnRead( + platform::UdpPacket data, + platform::NetworkRunner* network_runner) { + bridge_->RunTasks(false); + idle_ = bridge_->server_idle(); } std::unique_ptr<QuicConnection> FakeServerQuicConnectionFactory::Connect( @@ -195,4 +219,20 @@ std::unique_ptr<QuicConnection> FakeServerQuicConnectionFactory::Connect( return nullptr; } +void FakeServerQuicConnectionFactory::OnError(platform::UdpSocket* socket, + Error error) { + OSP_UNIMPLEMENTED(); +} + +void FakeServerQuicConnectionFactory::OnSendError(platform::UdpSocket* socket, + Error error) { + OSP_UNIMPLEMENTED(); +} + +void FakeServerQuicConnectionFactory::OnRead( + platform::UdpSocket* socket, + ErrorOr<platform::UdpPacket> packet) { + OSP_UNIMPLEMENTED(); +} + } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h index c2fa364c546..f8458922dd7 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/fake_quic_connection_factory.h @@ -7,6 +7,7 @@ #include <vector> +#include "gmock/gmock.h" #include "osp/impl/quic/quic_connection_factory.h" #include "osp/impl/quic/testing/fake_quic_connection.h" #include "osp/public/message_demuxer.h" @@ -15,17 +16,17 @@ namespace openscreen { class FakeQuicConnectionFactoryBridge { public: - explicit FakeQuicConnectionFactoryBridge( - const IPEndpoint& controller_endpoint); + FakeQuicConnectionFactoryBridge(const IPEndpoint& controller_endpoint); - bool idle() const { return idle_; } + bool server_idle() const { return server_idle_; } + bool client_idle() const { return client_idle_; } void OnConnectionClosed(QuicConnection* connection); void OnOutgoingStream(QuicConnection* connection, QuicStream* stream); void SetServerDelegate(QuicConnectionFactory::ServerDelegate* delegate, const IPEndpoint& endpoint); - void RunTasks(); + void RunTasks(bool is_client); std::unique_ptr<QuicConnection> Connect( const IPEndpoint& endpoint, QuicConnection::Delegate* connection_delegate); @@ -38,7 +39,8 @@ class FakeQuicConnectionFactoryBridge { const IPEndpoint controller_endpoint_; IPEndpoint receiver_endpoint_; - bool idle_ = true; + bool client_idle_ = true; + bool server_idle_ = true; uint64_t next_connection_id_ = 0; bool connections_pending_ = true; ConnectionPair connections_ = {}; @@ -51,16 +53,30 @@ class FakeClientQuicConnectionFactory final : public QuicConnectionFactory { FakeQuicConnectionFactoryBridge* bridge); ~FakeClientQuicConnectionFactory() override; + // UdpReadCallback overrides. + void OnRead(platform::UdpPacket data, + platform::NetworkRunner* network_runner) override; + + // UdpSocket::Client overrides. + void OnError(platform::UdpSocket* socket, Error error) override; + void OnSendError(platform::UdpSocket* socket, Error error) override; + void OnRead(platform::UdpSocket* socket, + ErrorOr<platform::UdpPacket> packet) override; + // QuicConnectionFactory overrides. void SetServerDelegate(ServerDelegate* delegate, const std::vector<IPEndpoint>& endpoints) override; - void RunTasks() override; std::unique_ptr<QuicConnection> Connect( const IPEndpoint& endpoint, QuicConnection::Delegate* connection_delegate) override; + bool idle() const { return idle_; } + + std::unique_ptr<platform::UdpSocket> socket_; + private: FakeQuicConnectionFactoryBridge* bridge_; + bool idle_ = true; }; class FakeServerQuicConnectionFactory final : public QuicConnectionFactory { @@ -69,16 +85,28 @@ class FakeServerQuicConnectionFactory final : public QuicConnectionFactory { FakeQuicConnectionFactoryBridge* bridge); ~FakeServerQuicConnectionFactory() override; + // UdpReadCallback overrides. + void OnRead(platform::UdpPacket data, + platform::NetworkRunner* network_runner) override; + + // UdpSocket::Client overrides. + void OnError(platform::UdpSocket* socket, Error error) override; + void OnSendError(platform::UdpSocket* socket, Error error) override; + void OnRead(platform::UdpSocket* socket, + ErrorOr<platform::UdpPacket> packet) override; + // QuicConnectionFactory overrides. void SetServerDelegate(ServerDelegate* delegate, const std::vector<IPEndpoint>& endpoints) override; - void RunTasks() override; std::unique_ptr<QuicConnection> Connect( const IPEndpoint& endpoint, QuicConnection::Delegate* connection_delegate) override; + bool idle() const { return idle_; } + private: FakeQuicConnectionFactoryBridge* bridge_; + bool idle_ = true; }; } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.cc b/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.cc index 290e46fbdce..c9997119771 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.cc +++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.cc @@ -12,7 +12,9 @@ namespace openscreen { -FakeQuicBridge::FakeQuicBridge(platform::ClockNowFunctionPtr now_function) { +FakeQuicBridge::FakeQuicBridge(platform::FakeNetworkRunner* network_runner, + platform::ClockNowFunctionPtr now_function) + : network_runner_(network_runner) { fake_bridge = std::make_unique<FakeQuicConnectionFactoryBridge>(kControllerEndpoint); @@ -23,17 +25,28 @@ FakeQuicBridge::FakeQuicBridge(platform::ClockNowFunctionPtr now_function) { auto fake_client_factory = std::make_unique<FakeClientQuicConnectionFactory>(fake_bridge.get()); + client_socket_ = std::make_unique<platform::MockUdpSocket>( + network_runner_, fake_client_factory.get()); + network_runner_->ReadRepeatedly(client_socket_.get(), + fake_client_factory.get()); + + // TODO(rwkeane): Pass actual network runner instead of nullptr once the fake + // network runner correctly respects the time delay for delayed tasks. quic_client = std::make_unique<QuicClient>(controller_demuxer.get(), std::move(fake_client_factory), - &mock_client_observer); + &mock_client_observer, nullptr); auto fake_server_factory = std::make_unique<FakeServerQuicConnectionFactory>(fake_bridge.get()); + server_socket_ = std::make_unique<platform::MockUdpSocket>( + network_runner_, fake_server_factory.get()); + network_runner_->ReadRepeatedly(server_socket_.get(), + fake_server_factory.get()); ServerConfig config; config.connection_endpoints.push_back(kReceiverEndpoint); quic_server = std::make_unique<QuicServer>(config, receiver_demuxer.get(), std::move(fake_server_factory), - &mock_server_observer); + &mock_server_observer, nullptr); quic_client->Start(); quic_server->Start(); @@ -41,15 +54,34 @@ FakeQuicBridge::FakeQuicBridge(platform::ClockNowFunctionPtr now_function) { FakeQuicBridge::~FakeQuicBridge() = default; +void FakeQuicBridge::PostClientPacket() { + platform::UdpPacket packet; + packet.set_socket(client_socket_.get()); + network_runner_->PostNewPacket(std::move(packet)); +} + +void FakeQuicBridge::PostServerPacket() { + platform::UdpPacket packet; + packet.set_socket(server_socket_.get()); + network_runner_->PostNewPacket(std::move(packet)); +} + +void FakeQuicBridge::PostPacketsUntilIdle() { + bool client_idle = fake_bridge->client_idle(); + bool server_idle = fake_bridge->server_idle(); + if (!client_idle || !server_idle) { + PostClientPacket(); + PostServerPacket(); + network_runner_->PostTask([this]() { this->PostPacketsUntilIdle(); }); + } +} + void FakeQuicBridge::RunTasksUntilIdle() { - bool client_idle = true; - bool server_idle = true; - do { - NetworkServiceManager::Get()->GetProtocolConnectionClient()->RunTasks(); - client_idle = fake_bridge->idle(); - NetworkServiceManager::Get()->GetProtocolConnectionServer()->RunTasks(); - server_idle = fake_bridge->idle(); - } while (!client_idle || !server_idle); + PostClientPacket(); + PostServerPacket(); + network_runner_->PostTask( + std::bind(&FakeQuicBridge::PostPacketsUntilIdle, this)); + network_runner_->RunTasksUntilIdle(); } } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.h b/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.h index e3c261bf161..64cdca04fe8 100644 --- a/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.h +++ b/chromium/third_party/openscreen/src/osp/impl/quic/testing/quic_test_support.h @@ -16,8 +16,12 @@ #include "osp/public/protocol_connection_client.h" #include "osp/public/protocol_connection_server.h" #include "platform/api/time.h" +#include "platform/api/udp_socket.h" #include "platform/base/ip_address.h" +#include "platform/impl/network_runner.h" #include "platform/test/fake_clock.h" +#include "platform/test/fake_network_runner.h" +#include "platform/test/mock_udp_socket.h" namespace openscreen { @@ -52,11 +56,10 @@ class MockServerObserver : public ProtocolConnectionServer::Observer { class FakeQuicBridge { public: - explicit FakeQuicBridge(platform::ClockNowFunctionPtr now_function); + FakeQuicBridge(platform::FakeNetworkRunner* network_runner, + platform::ClockNowFunctionPtr now_function); ~FakeQuicBridge(); - void RunTasksUntilIdle(); - const IPEndpoint kControllerEndpoint{{192, 168, 1, 3}, 4321}; const IPEndpoint kReceiverEndpoint{{192, 168, 1, 17}, 1234}; @@ -67,6 +70,19 @@ class FakeQuicBridge { std::unique_ptr<FakeQuicConnectionFactoryBridge> fake_bridge; ::testing::NiceMock<MockServiceObserver> mock_client_observer; ::testing::NiceMock<MockServerObserver> mock_server_observer; + + void RunTasksUntilIdle(); + + private: + void PostClientPacket(); + void PostServerPacket(); + void PostPacketsUntilIdle(); + FakeClientQuicConnectionFactory* GetClientFactory(); + FakeServerQuicConnectionFactory* GetServerFactory(); + platform::FakeNetworkRunner* network_runner_; + + std::unique_ptr<platform::MockUdpSocket> client_socket_; + std::unique_ptr<platform::MockUdpSocket> server_socket_; }; } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/impl/receiver_list.cc b/chromium/third_party/openscreen/src/osp/impl/receiver_list.cc index a4f11047989..522a3f0e430 100644 --- a/chromium/third_party/openscreen/src/osp/impl/receiver_list.cc +++ b/chromium/third_party/openscreen/src/osp/impl/receiver_list.cc @@ -21,7 +21,7 @@ Error ReceiverList::OnReceiverChanged(const ServiceInfo& info) { return x.service_id == info.service_id; }); if (existing_info == receivers_.end()) - return Error::Code::kNoItemFound; + return Error::Code::kItemNotFound; *existing_info = info; return Error::None(); @@ -30,7 +30,7 @@ Error ReceiverList::OnReceiverChanged(const ServiceInfo& info) { Error ReceiverList::OnReceiverRemoved(const ServiceInfo& info) { const auto it = std::remove(receivers_.begin(), receivers_.end(), info); if (it == receivers_.end()) - return Error::Code::kNoItemFound; + return Error::Code::kItemNotFound; receivers_.erase(it, receivers_.end()); return Error::None(); @@ -39,7 +39,7 @@ Error ReceiverList::OnReceiverRemoved(const ServiceInfo& info) { Error ReceiverList::OnAllReceiversRemoved() { const auto empty = receivers_.empty(); receivers_.clear(); - return empty ? Error::Code::kNoItemFound : Error::None(); + return empty ? Error::Code::kItemNotFound : Error::None(); } } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.cc b/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.cc index 8ba3473ffb5..a32ae9f7ca0 100644 --- a/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.cc +++ b/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.cc @@ -150,10 +150,6 @@ bool ServiceListenerImpl::SearchNow() { return true; } -void ServiceListenerImpl::RunTasks() { - delegate_->RunTasksListener(); -} - void ServiceListenerImpl::AddObserver(Observer* observer) { OSP_DCHECK(observer); observers_.push_back(observer); diff --git a/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.h b/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.h index 04f9dd04ede..5763366dfd4 100644 --- a/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.h +++ b/chromium/third_party/openscreen/src/osp/impl/service_listener_impl.h @@ -31,7 +31,6 @@ class ServiceListenerImpl final : public ServiceListener, virtual void SuspendListener() = 0; virtual void ResumeListener() = 0; virtual void SearchNow(State from) = 0; - virtual void RunTasksListener() = 0; protected: void SetState(State state) { listener_->SetState(state); } @@ -60,8 +59,6 @@ class ServiceListenerImpl final : public ServiceListener, bool Resume() override; bool SearchNow() override; - void RunTasks() override; - void AddObserver(Observer* observer) override; void RemoveObserver(Observer* observer) override; diff --git a/chromium/third_party/openscreen/src/osp/impl/service_publisher_impl.cc b/chromium/third_party/openscreen/src/osp/impl/service_publisher_impl.cc index 808c3967930..6c2dcaf5385 100644 --- a/chromium/third_party/openscreen/src/osp/impl/service_publisher_impl.cc +++ b/chromium/third_party/openscreen/src/osp/impl/service_publisher_impl.cc @@ -87,10 +87,6 @@ bool ServicePublisherImpl::Resume() { return true; } -void ServicePublisherImpl::RunTasks() { - delegate_->RunTasksPublisher(); -} - void ServicePublisherImpl::SetState(State state) { OSP_DCHECK(IsTransitionValid(state_, state)); state_ = state; diff --git a/chromium/third_party/openscreen/src/osp/impl/service_publisher_impl.h b/chromium/third_party/openscreen/src/osp/impl/service_publisher_impl.h index b2805fc96e8..d08b1a0543b 100644 --- a/chromium/third_party/openscreen/src/osp/impl/service_publisher_impl.h +++ b/chromium/third_party/openscreen/src/osp/impl/service_publisher_impl.h @@ -26,7 +26,6 @@ class ServicePublisherImpl final : public ServicePublisher, virtual void StopPublisher() = 0; virtual void SuspendPublisher() = 0; virtual void ResumePublisher() = 0; - virtual void RunTasksPublisher() = 0; protected: void SetState(State state) { publisher_->SetState(state); } @@ -47,8 +46,6 @@ class ServicePublisherImpl final : public ServicePublisher, bool Suspend() override; bool Resume() override; - void RunTasks() override; - private: // Called by |delegate_| to transition the state machine (except kStarting and // kStopping which are done automatically). diff --git a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.cc b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.cc index 9f2064fe0af..56565ef2f7f 100644 --- a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.cc +++ b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.cc @@ -210,13 +210,13 @@ Error FakeMdnsResponderAdapter::RegisterInterface( const platform::IPSubnet& interface_address, platform::UdpSocket* socket) { if (!running_) - return Error::Code::kNotRunning; + return Error::Code::kOperationInvalid; if (std::find_if(registered_interfaces_.begin(), registered_interfaces_.end(), [&socket](const RegisteredInterface& interface) { return interface.socket == socket; }) != registered_interfaces_.end()) { - return Error::Code::kNoItemFound; + return Error::Code::kItemNotFound; } registered_interfaces_.push_back({interface_info, interface_address, socket}); return Error::None(); @@ -230,23 +230,19 @@ Error FakeMdnsResponderAdapter::DeregisterInterface( return interface.socket == socket; }); if (it == registered_interfaces_.end()) - return Error::Code::kNoItemFound; + return Error::Code::kItemNotFound; registered_interfaces_.erase(it); return Error::None(); } -void FakeMdnsResponderAdapter::OnDataReceived( - const IPEndpoint& source, - const IPEndpoint& original_destination, - const uint8_t* data, - size_t length, - platform::UdpSocket* receiving_socket) { +void FakeMdnsResponderAdapter::OnRead(platform::UdpPacket packet, + platform::NetworkRunner* network_runner) { OSP_CHECK(false) << "Tests should not drive this class with packets"; } -int FakeMdnsResponderAdapter::RunTasks() { - return 1; +absl::optional<platform::Clock::duration> FakeMdnsResponderAdapter::RunTasks() { + return absl::nullopt; } std::vector<mdns::PtrEvent> FakeMdnsResponderAdapter::TakePtrResponses() { diff --git a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.h b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.h index c4855974a06..8a8a7e50a0f 100644 --- a/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.h +++ b/chromium/third_party/openscreen/src/osp/impl/testing/fake_mdns_responder_adapter.h @@ -97,6 +97,10 @@ class FakeMdnsResponderAdapter final : public mdns::MdnsResponderAdapter { bool aaaa_queries_empty() const; bool running() const { return running_; } + // UdpReadCallback overrides. + void OnRead(platform::UdpPacket packet, + platform::NetworkRunner* network_runner) override; + // mdns::MdnsResponderAdapter overrides. Error Init() override; void Close() override; @@ -110,13 +114,7 @@ class FakeMdnsResponderAdapter final : public mdns::MdnsResponderAdapter { platform::UdpSocket* socket) override; Error DeregisterInterface(platform::UdpSocket* socket) override; - void OnDataReceived(const IPEndpoint& source, - const IPEndpoint& original_destination, - const uint8_t* data, - size_t length, - platform::UdpSocket* receiving_socket) override; - - int RunTasks() override; + absl::optional<platform::Clock::duration> RunTasks() override; std::vector<mdns::PtrEvent> TakePtrResponses() override; std::vector<mdns::SrvEvent> TakeSrvResponses() override; diff --git a/chromium/third_party/openscreen/src/osp/msgs/BUILD.gn b/chromium/third_party/openscreen/src/osp/msgs/BUILD.gn index 324e838d357..898ce915675 100644 --- a/chromium/third_party/openscreen/src/osp/msgs/BUILD.gn +++ b/chromium/third_party/openscreen/src/osp/msgs/BUILD.gn @@ -40,19 +40,27 @@ action("cddl_gen") { outputs += [ root_gen_dir + "/" + o ] } - args = [ - "--header", - outputs_src[0], - "--cc", - outputs_src[1], - "--gen-dir", - rebase_path(root_gen_dir, root_build_dir), - "--log", - rebase_path("cddl.log", "//"), - ] + rebase_path(sources, root_build_dir) + cddl_label = "../../tools/cddl:cddl($host_toolchain)" + cddl_path = get_label_info(cddl_label, "root_out_dir") + "/cddl" + args = + [ + "--cddl", + + # Path should be rebased because |root_build_dir| for current toolchain + # may be different from |root_out_dir| of cddl built on host toolchain. + "./" + rebase_path(cddl_path, root_build_dir), + "--header", + outputs_src[0], + "--cc", + outputs_src[1], + "--gen-dir", + rebase_path(root_gen_dir, root_build_dir), + "--log", + rebase_path("cddl.log", "//"), + ] + rebase_path(sources, root_build_dir) deps = [ - "../../tools/cddl", + cddl_label, ] } diff --git a/chromium/third_party/openscreen/src/osp/public/mdns_service_listener_factory.h b/chromium/third_party/openscreen/src/osp/public/mdns_service_listener_factory.h index de4242678e1..df471fdb908 100644 --- a/chromium/third_party/openscreen/src/osp/public/mdns_service_listener_factory.h +++ b/chromium/third_party/openscreen/src/osp/public/mdns_service_listener_factory.h @@ -10,6 +10,9 @@ #include "osp/public/service_listener.h" namespace openscreen { +namespace platform { +class NetworkRunner; +} // namespace platform struct MdnsServiceListenerConfig { // TODO(mfoltz): Populate with actual parameters as implementation progresses. @@ -20,7 +23,8 @@ class MdnsServiceListenerFactory { public: static std::unique_ptr<ServiceListener> Create( const MdnsServiceListenerConfig& config, - ServiceListener::Observer* observer); + ServiceListener::Observer* observer, + platform::NetworkRunner* network_runner); }; } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/public/mdns_service_publisher_factory.h b/chromium/third_party/openscreen/src/osp/public/mdns_service_publisher_factory.h index a2623404456..cdbb037763a 100644 --- a/chromium/third_party/openscreen/src/osp/public/mdns_service_publisher_factory.h +++ b/chromium/third_party/openscreen/src/osp/public/mdns_service_publisher_factory.h @@ -8,14 +8,19 @@ #include <memory> #include "osp/public/service_publisher.h" +#include "platform/api/network_runner.h" namespace openscreen { +namespace platform { +class NetworkRunner; +} // namespace platform class MdnsServicePublisherFactory { public: static std::unique_ptr<ServicePublisher> Create( const ServicePublisher::Config& config, - ServicePublisher::Observer* observer); + ServicePublisher::Observer* observer, + platform::NetworkRunner* network_runner); }; } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/public/message_demuxer.cc b/chromium/third_party/openscreen/src/osp/public/message_demuxer.cc index c9ec82f1751..c1afed19d6a 100644 --- a/chromium/third_party/openscreen/src/osp/public/message_demuxer.cc +++ b/chromium/third_party/openscreen/src/osp/public/message_demuxer.cc @@ -159,14 +159,16 @@ MessageDemuxer::MessageWatch MessageDemuxer::SetDefaultMessageTypeWatch( if (!emplace_result.second) return MessageWatch(); for (auto& endpoint_buffers : buffers_) { - for (auto& buffer : endpoint_buffers.second) { - if (buffer.second.empty()) + auto endpoint_id = endpoint_buffers.first; + for (auto& stream_map : endpoint_buffers.second) { + if (stream_map.second.empty()) continue; - auto buffered_type = static_cast<msgs::Type>(buffer.second[0]); + auto buffered_type = static_cast<msgs::Type>(stream_map.second[0]); if (message_type == buffered_type) { - auto callbacks_entry = message_callbacks_.find(endpoint_buffers.first); - HandleStreamBufferLoop(endpoint_buffers.first, buffer.first, - callbacks_entry, &buffer.second); + auto connection_id = stream_map.first; + auto callbacks_entry = message_callbacks_.find(endpoint_id); + HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry, + &stream_map.second); } } } diff --git a/chromium/third_party/openscreen/src/osp/public/message_demuxer.h b/chromium/third_party/openscreen/src/osp/public/message_demuxer.h index 1cbfcc5988d..5ff38c34337 100644 --- a/chromium/third_party/openscreen/src/osp/public/message_demuxer.h +++ b/chromium/third_party/openscreen/src/osp/public/message_demuxer.h @@ -112,6 +112,8 @@ class MessageDemuxer { const size_t buffer_limit_; std::map<uint64_t, std::map<msgs::Type, MessageCallback*>> message_callbacks_; std::map<msgs::Type, MessageCallback*> default_callbacks_; + + // Map<endpoint_id, Map<connection_id, data_buffer>> std::map<uint64_t, std::map<uint64_t, std::vector<uint8_t>>> buffers_; }; diff --git a/chromium/third_party/openscreen/src/osp/public/network_service_manager.h b/chromium/third_party/openscreen/src/osp/public/network_service_manager.h index 83972840324..e5f64a38694 100644 --- a/chromium/third_party/openscreen/src/osp/public/network_service_manager.h +++ b/chromium/third_party/openscreen/src/osp/public/network_service_manager.h @@ -40,25 +40,20 @@ class NetworkServiceManager final { // by the service instance destructors. static void Dispose(); - // Runs the event loop once for all of its owned services. This mostly - // consists of check for available network events and passing that data to the - // listening services. - void RunEventLoopOnce(); - - // Returns an instance of the mDNS receiver listener, or nullptr if - // not provided. + // Returns an instance of the mDNS receiver listener, or nullptr if not + // provided. ServiceListener* GetMdnsServiceListener(); - // Returns an instance of the mDNS receiver publisher, or nullptr - // if not provided. + // Returns an instance of the mDNS receiver publisher, or nullptr if not + // provided. ServicePublisher* GetMdnsServicePublisher(); - // Returns an instance of the protocol connection client, or nullptr - // if not provided. + // Returns an instance of the protocol connection client, or nullptr if not + // provided. ProtocolConnectionClient* GetProtocolConnectionClient(); - // Returns an instance of the protocol connection server, or nullptr if - // not provided. + // Returns an instance of the protocol connection server, or nullptr if not + // provided. ProtocolConnectionServer* GetProtocolConnectionServer(); private: diff --git a/chromium/third_party/openscreen/src/osp/public/protocol_connection_client.h b/chromium/third_party/openscreen/src/osp/public/protocol_connection_client.h index b53f74a6516..d4e47a9a11f 100644 --- a/chromium/third_party/openscreen/src/osp/public/protocol_connection_client.h +++ b/chromium/third_party/openscreen/src/osp/public/protocol_connection_client.h @@ -76,8 +76,6 @@ class ProtocolConnectionClient { // Returns true if state() != (kStopped|kStopping). virtual bool Stop() = 0; - virtual void RunTasks() = 0; - // Open a new connection to |endpoint|. This may succeed synchronously if // there are already connections open to |endpoint|, otherwise it will be // asynchronous. diff --git a/chromium/third_party/openscreen/src/osp/public/protocol_connection_client_factory.h b/chromium/third_party/openscreen/src/osp/public/protocol_connection_client_factory.h index f1190f77ae3..b2d16af8fe2 100644 --- a/chromium/third_party/openscreen/src/osp/public/protocol_connection_client_factory.h +++ b/chromium/third_party/openscreen/src/osp/public/protocol_connection_client_factory.h @@ -10,12 +10,16 @@ #include "osp/public/protocol_connection_client.h" namespace openscreen { +namespace platform { +class NetworkRunner; +} // namespace platform class ProtocolConnectionClientFactory { public: static std::unique_ptr<ProtocolConnectionClient> Create( MessageDemuxer* demuxer, - ProtocolConnectionServiceObserver* observer); + ProtocolConnectionServiceObserver* observer, + platform::NetworkRunner* network_runner); }; } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/public/protocol_connection_server.h b/chromium/third_party/openscreen/src/osp/public/protocol_connection_server.h index 8eafe965ec6..2390c5587a0 100644 --- a/chromium/third_party/openscreen/src/osp/public/protocol_connection_server.h +++ b/chromium/third_party/openscreen/src/osp/public/protocol_connection_server.h @@ -65,8 +65,6 @@ class ProtocolConnectionServer { // connections. virtual bool Resume() = 0; - virtual void RunTasks() = 0; - // Synchronously open a new connection to an endpoint identified by // |endpoint_id|. Returns nullptr if it can't be completed synchronously // (e.g. there are no existing open connections to that endpoint). diff --git a/chromium/third_party/openscreen/src/osp/public/protocol_connection_server_factory.h b/chromium/third_party/openscreen/src/osp/public/protocol_connection_server_factory.h index 1a825926769..03fb58e3ce2 100644 --- a/chromium/third_party/openscreen/src/osp/public/protocol_connection_server_factory.h +++ b/chromium/third_party/openscreen/src/osp/public/protocol_connection_server_factory.h @@ -11,13 +11,17 @@ #include "osp/public/server_config.h" namespace openscreen { +namespace platform { +class NetworkRunner; +} class ProtocolConnectionServerFactory { public: static std::unique_ptr<ProtocolConnectionServer> Create( const ServerConfig& config, MessageDemuxer* demuxer, - ProtocolConnectionServer::Observer* observer); + ProtocolConnectionServer::Observer* observer, + platform::NetworkRunner* network_runner); }; } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/public/service_listener.cc b/chromium/third_party/openscreen/src/osp/public/service_listener.cc index 97698631e21..914f407b8c2 100644 --- a/chromium/third_party/openscreen/src/osp/public/service_listener.cc +++ b/chromium/third_party/openscreen/src/osp/public/service_listener.cc @@ -20,7 +20,7 @@ ServiceListenerError& ServiceListenerError::operator=( ServiceListener::Metrics::Metrics() = default; ServiceListener::Metrics::~Metrics() = default; -ServiceListener::ServiceListener() = default; +ServiceListener::ServiceListener() : state_(State::kStopped) {} ServiceListener::~ServiceListener() = default; } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/public/service_listener.h b/chromium/third_party/openscreen/src/osp/public/service_listener.h index 41d88604ecc..360436b5af3 100644 --- a/chromium/third_party/openscreen/src/osp/public/service_listener.h +++ b/chromium/third_party/openscreen/src/osp/public/service_listener.h @@ -5,6 +5,7 @@ #ifndef OSP_PUBLIC_SERVICE_LISTENER_H_ #define OSP_PUBLIC_SERVICE_LISTENER_H_ +#include <atomic> #include <cstdint> #include <string> #include <vector> @@ -133,8 +134,6 @@ class ServiceListener { // (kRunning|kSuspended). virtual bool SearchNow() = 0; - virtual void RunTasks() = 0; - virtual void AddObserver(Observer* observer) = 0; virtual void RemoveObserver(Observer* observer) = 0; @@ -150,7 +149,7 @@ class ServiceListener { protected: ServiceListener(); - State state_ = State::kStopped; + std::atomic<State> state_; ServiceListenerError last_error_; std::vector<Observer*> observers_; diff --git a/chromium/third_party/openscreen/src/osp/public/service_publisher.cc b/chromium/third_party/openscreen/src/osp/public/service_publisher.cc index b34940b65da..85de2fcd293 100644 --- a/chromium/third_party/openscreen/src/osp/public/service_publisher.cc +++ b/chromium/third_party/openscreen/src/osp/public/service_publisher.cc @@ -23,7 +23,8 @@ ServicePublisher::Metrics::~Metrics() = default; ServicePublisher::Config::Config() = default; ServicePublisher::Config::~Config() = default; -ServicePublisher::ServicePublisher(Observer* observer) : observer_(observer) {} +ServicePublisher::ServicePublisher(Observer* observer) + : state_(State::kStopped), observer_(observer) {} ServicePublisher::~ServicePublisher() = default; } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/osp/public/service_publisher.h b/chromium/third_party/openscreen/src/osp/public/service_publisher.h index d17e08ed484..0234c552254 100644 --- a/chromium/third_party/openscreen/src/osp/public/service_publisher.h +++ b/chromium/third_party/openscreen/src/osp/public/service_publisher.h @@ -5,6 +5,7 @@ #ifndef OSP_PUBLIC_SERVICE_PUBLISHER_H_ #define OSP_PUBLIC_SERVICE_PUBLISHER_H_ +#include <atomic> #include <cstdint> #include <string> #include <vector> @@ -134,8 +135,6 @@ class ServicePublisher { // Resumes publishing. Returns true if state() == kSuspended. virtual bool Resume() = 0; - virtual void RunTasks() = 0; - // Returns the current state of the publisher. State state() const { return state_; } @@ -145,7 +144,7 @@ class ServicePublisher { protected: explicit ServicePublisher(Observer* observer); - State state_ = State::kStopped; + std::atomic<State> state_; ServicePublisherError last_error_; Observer* observer_; diff --git a/chromium/third_party/openscreen/src/platform/BUILD.gn b/chromium/third_party/openscreen/src/platform/BUILD.gn index e891b1b236d..bc709b6d5f0 100644 --- a/chromium/third_party/openscreen/src/platform/BUILD.gn +++ b/chromium/third_party/openscreen/src/platform/BUILD.gn @@ -5,9 +5,9 @@ import("//build_overrides/build.gni") source_set("platform") { + defines = [] + sources = [ - "api/event_waiter.cc", - "api/event_waiter.h", "api/internal/trace_logging_internal.cc", "api/internal/trace_logging_internal.h", "api/internal/trace_logging_macros_internal.h", @@ -16,6 +16,7 @@ source_set("platform") { "api/network_interface.cc", "api/network_interface.h", "api/network_runner.h", + "api/network_runner_lifetime_manager.h", "api/network_waiter.h", "api/scoped_wake_lock.cc", "api/scoped_wake_lock.h", @@ -52,12 +53,12 @@ source_set("platform") { if (!build_with_chromium) { sources += [ - "impl/event_loop.cc", - "impl/event_loop.h", "impl/network_reader.cc", "impl/network_reader.h", "impl/network_runner.cc", "impl/network_runner.h", + "impl/network_runner_lifetime_manager.cc", + "impl/network_runner_lifetime_manager.h", "impl/task_runner.cc", "impl/task_runner.h", "impl/text_trace_logging_platform.cc", @@ -68,7 +69,7 @@ source_set("platform") { if (is_linux) { sources += [ "impl/network_interface_linux.cc" ] } else if (is_mac) { - defines = [ + defines += [ # Required, to use the new IPv6 Sockets options introduced by RFC 3542. "__APPLE_USE_RFC_3542", ] @@ -78,7 +79,6 @@ source_set("platform") { if (is_posix) { sources += [ - "impl/event_waiter_posix.cc", "impl/network_waiter_posix.cc", "impl/network_waiter_posix.h", "impl/scoped_pipe.h", @@ -100,10 +100,13 @@ source_set("test") { sources = [ "test/fake_clock.cc", "test/fake_clock.h", + "test/fake_network_runner.cc", + "test/fake_network_runner.h", "test/fake_task_runner.cc", "test/fake_task_runner.h", "test/mock_udp_socket.cc", "test/mock_udp_socket.h", + "test/trace_logging_helpers.h", ] configs += [ "../build:allow_build_from_embedder" ] @@ -142,17 +145,21 @@ source_set("platform_unittests") { "base/error_unittest.cc", "base/ip_address_unittest.cc", "base/location_unittest.cc", - "impl/ssl_context_unittest.cc", ] # The unit tests in impl/ assume the standalone implementation is being used. # Exclude them if an embedder is providing the implementation. if (!build_with_chromium) { sources += [ + # TODO(jophba): move over to general sources when UDP socket create + # is implemented in Chromium, as part of the NetworkRunner work. + "api/socket_integration_unittest.cc", "impl/network_reader_unittest.cc", + "impl/ssl_context_unittest.cc", "impl/task_runner_unittest.cc", "impl/time_unittest.cc", ] + if (is_posix) { sources += [ "impl/network_waiter_posix_unittest.cc", diff --git a/chromium/third_party/openscreen/src/platform/api/event_waiter.cc b/chromium/third_party/openscreen/src/platform/api/event_waiter.cc deleted file mode 100644 index 4a68de52255..00000000000 --- a/chromium/third_party/openscreen/src/platform/api/event_waiter.cc +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "platform/api/event_waiter.h" - -namespace openscreen { -namespace platform { - -Events::Events() = default; -Events::~Events() = default; -Events::Events(Events&& o) = default; -Events& Events::operator=(Events&& o) = default; - -} // namespace platform -} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/api/event_waiter.h b/chromium/third_party/openscreen/src/platform/api/event_waiter.h deleted file mode 100644 index ff4fead2067..00000000000 --- a/chromium/third_party/openscreen/src/platform/api/event_waiter.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef PLATFORM_API_EVENT_WAITER_H_ -#define PLATFORM_API_EVENT_WAITER_H_ - -#include <vector> - -#include "platform/api/time.h" -#include "platform/api/udp_socket.h" -#include "platform/base/error.h" - -namespace openscreen { -namespace platform { - -struct UdpSocketReadableEvent { - UdpSocket* socket; -}; - -struct UdpSocketWritableEvent { - UdpSocket* socket; -}; - -struct EventWaiterPrivate; -using EventWaiterPtr = EventWaiterPrivate*; - -// This struct represents a set of events associated with a particular -// EventWaiterPtr and is created by WaitForEvents. Any combination and number -// of events may be present, depending on how the platform implements event -// waiting and what has occured since the last WaitForEvents call. -struct Events { - Events(); - ~Events(); - Events(Events&& o); - Events& operator=(Events&& o); - - std::vector<UdpSocketReadableEvent> udp_readable_events; - std::vector<UdpSocketWritableEvent> udp_writable_events; -}; - -// TODO(miu): This should be a std::unique_ptr<> instead of two separate -// methods, so that code structure auto-scopes the lifetime of the instance. -EventWaiterPtr CreateEventWaiter(); -void DestroyEventWaiter(EventWaiterPtr waiter); - -Error WatchUdpSocketReadable(EventWaiterPtr waiter, UdpSocket* socket); -Error StopWatchingUdpSocketReadable(EventWaiterPtr waiter, UdpSocket* socket); - -Error WatchUdpSocketWritable(EventWaiterPtr waiter, UdpSocket* socket); -Error StopWatchingUdpSocketWritable(EventWaiterPtr waiter, UdpSocket* socket); - -Error WatchNetworkChange(EventWaiterPtr waiter); -Error StopWatchingNetworkChange(EventWaiterPtr waiter); - -// Returns the number of events that were added to |events| if there were any, 0 -// if there were no events, and -1 if an error occurred. -ErrorOr<Events> WaitForEvents(EventWaiterPtr waiter); - -} // namespace platform -} // namespace openscreen - -#endif // PLATFORM_API_EVENT_WAITER_H_ diff --git a/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_internal.cc b/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_internal.cc index 36230fd984b..2618b41551a 100644 --- a/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_internal.cc +++ b/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_internal.cc @@ -107,34 +107,24 @@ TraceLoggerBase::TraceLoggerBase(TraceCategory::Value category, ids.root) {} SynchronousTraceLogger::~SynchronousTraceLogger() { - // If this object has an instance variable platform, use that. Otherwise, - // use the static variable for the shared class. In practice, the instance - // variable should only be set when testing, so branch prediction will - // always pick the correct path in production code and it should be of - // negligable cost. auto* current_platform = TraceLoggingPlatform::GetDefaultTracingPlatform(); if (current_platform == nullptr) { return; } auto end_time = Clock::now(); current_platform->LogTrace(this->name_, this->line_number_, this->file_name_, - this->start_time_, end_time, this->trace_id_, - this->parent_id_, this->root_id_, this->result_); + this->start_time_, end_time, this->to_hierarchy(), + this->result_); } AsynchronousTraceLogger::~AsynchronousTraceLogger() { - // If this object has an instance variable platform, use that. Otherwise, - // use the static variable for the shared class. In practice, the instance - // variable should only be set when testing, so branch prediction will - // always pick the correct path in production code and it should be of - // negligable cost. auto* current_platform = TraceLoggingPlatform::GetDefaultTracingPlatform(); if (current_platform == nullptr) { return; } - current_platform->LogAsyncStart( - this->name_, this->line_number_, this->file_name_, this->start_time_, - this->trace_id_, this->parent_id_, this->root_id_); + current_platform->LogAsyncStart(this->name_, this->line_number_, + this->file_name_, this->start_time_, + this->to_hierarchy()); } TraceIdSetter::~TraceIdSetter() = default; diff --git a/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_internal.h b/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_internal.h index a4d62a24996..d1682648ba7 100644 --- a/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_internal.h +++ b/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_internal.h @@ -10,6 +10,7 @@ #include <stack> #include <vector> +#include "platform/api/logging.h" #include "platform/api/time.h" #include "platform/api/trace_logging_platform.h" #include "platform/api/trace_logging_types.h" @@ -58,12 +59,10 @@ class ScopedTraceOperation : public TraceBase { static TraceIdHierarchy hierarchy() { if (traces_ == nullptr) { - return {kEmptyTraceId, kEmptyTraceId, kEmptyTraceId}; + return TraceIdHierarchy::Empty(); } - auto* top_of_stack = traces_->top(); - return {top_of_stack->trace_id_, top_of_stack->parent_id_, - top_of_stack->root_id_}; + return traces_->top()->to_hierarchy(); } // Static method to set the result of the most recent trace. @@ -92,6 +91,8 @@ class ScopedTraceOperation : public TraceBase { TraceId parent_id_; TraceId root_id_; + TraceIdHierarchy to_hierarchy() { return {trace_id_, parent_id_, root_id_}; } + private: // NOTE: A std::vector is used for backing the stack because it provides the // best perf. Further perf improvement could be achieved later by swapping @@ -131,7 +132,7 @@ class TraceLoggerBase : public ScopedTraceOperation { protected: // Set the result. - void SetTraceResult(Error::Code error) { result_ = error; } + void SetTraceResult(Error::Code error) override { result_ = error; } // Timestamp for when the object was created. Clock::time_point start_time_; diff --git a/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_internal_unittest.cc b/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_internal_unittest.cc index 94552b46f94..b8e65079fbc 100644 --- a/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_internal_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_internal_unittest.cc @@ -10,6 +10,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "platform/api/trace_logging.h" +#include "platform/test/trace_logging_helpers.h" // TODO(issue/52): Remove duplicate code from trace logging+internal unit tests namespace openscreen { @@ -20,67 +21,9 @@ using ::testing::_; using ::testing::DoAll; using ::testing::Invoke; -class MockLoggingPlatform : public TraceLoggingPlatform { - public: - MOCK_METHOD9(LogTrace, - void(const char*, - const uint32_t, - const char* file, - Clock::time_point, - Clock::time_point, - TraceId, - TraceId, - TraceId, - Error::Code)); - MOCK_METHOD7(LogAsyncStart, - void(const char*, - const uint32_t, - const char* file, - Clock::time_point, - TraceId, - TraceId, - TraceId)); - MOCK_METHOD5(LogAsyncEnd, - void(const uint32_t, - const char* file, - Clock::time_point, - TraceId, - Error::Code)); -}; - -// Methods to validate the results of platform-layer calls. -template <uint64_t milliseconds> -void ValidateTraceTimestampDiff(const char* name, - const uint32_t line, - const char* file, - Clock::time_point start_time, - Clock::time_point end_time, - TraceId trace_id, - TraceId parent_id, - TraceId root_id, - Error error) { - const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>( - end_time - start_time); - EXPECT_GE(static_cast<uint64_t>(elapsed.count()), milliseconds); -} - -template <Error::Code result> -void ValidateTraceErrorCode(const char* name, - const uint32_t line, - const char* file, - Clock::time_point start_time, - Clock::time_point end_time, - TraceId trace_id, - TraceId parent_id, - TraceId root_id, - Error error) { - EXPECT_EQ(error.code(), result); -} - // These tests validate that parameters are passed correctly by using the Trace // Internals. -constexpr TraceCategory::Value category = - TraceCategory::Value::CastPlatformLayer; +constexpr TraceCategory::Value category = TraceCategory::mDNS; constexpr uint32_t line = 10; TEST(TraceLoggingInternalTest, CreatingNoTraceObjectValid) { @@ -91,7 +34,7 @@ TEST(TraceLoggingInternalTest, TestMacroStyleInitializationTrue) { constexpr uint32_t delay_in_ms = 50; MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); - EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _, _, _)) + EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) .Times(1) .WillOnce(DoAll(Invoke(ValidateTraceTimestampDiff<delay_in_ms>), Invoke(ValidateTraceErrorCode<Error::Code::kNone>))); @@ -114,7 +57,7 @@ TEST(TraceLoggingInternalTest, TestMacroStyleInitializationTrue) { TEST(TraceLoggingInternalTest, TestMacroStyleInitializationFalse) { MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); - EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _, _, _)).Times(0); + EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(0); { uint8_t temp[sizeof(SynchronousTraceLogger)]; @@ -136,7 +79,7 @@ TEST(TraceLoggingInternalTest, ExpectParametersPassedToResult) { MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); EXPECT_CALL(platform, LogTrace(testing::StrEq("Name"), line, - testing::StrEq(__FILE__), _, _, _, _, _, _)) + testing::StrEq(__FILE__), _, _, _, _)) .WillOnce(Invoke(ValidateTraceErrorCode<Error::Code::kNone>)); { @@ -150,7 +93,7 @@ TEST(TraceLoggingInternalTest, CheckTraceAsyncStartLogsCorrectly) { MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); EXPECT_CALL(platform, LogAsyncStart(testing::StrEq("Name"), line, - testing::StrEq(__FILE__), _, _, _, _)) + testing::StrEq(__FILE__), _, _)) .Times(1); { AsynchronousTraceLogger{category, "Name", __FILE__, line}; } diff --git a/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_macros_internal.h b/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_macros_internal.h index 70ef2dfb6fb..8ab4ba90434 100644 --- a/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_macros_internal.h +++ b/chromium/third_party/openscreen/src/platform/api/internal/trace_logging_macros_internal.h @@ -8,6 +8,13 @@ #include "platform/api/internal/trace_logging_internal.h" #include "platform/api/trace_logging_types.h" +// Using statements to simplify the below macros. +using openscreen::TraceCategory; +using openscreen::platform::internal::AsynchronousTraceLogger; +using openscreen::platform::internal::SynchronousTraceLogger; +using openscreen::platform::internal::TraceIdSetter; +using openscreen::platform::internal::TraceInstanceHelper; + namespace openscreen { // Helper macros. These are used to simplify the macros below. @@ -30,21 +37,18 @@ namespace openscreen { #define TRACE_INTERNAL_IGNORE_UNUSED_VAR [[maybe_unused]] #endif // defined(__clang__) -// Define a macro to check if tracing is enabled so that unit tests don't break -// when it is not. +// Define a macro to check if tracing is enabled or not for testing and +// compilation reasons. #ifndef TRACE_FORCE_ENABLE #define TRACE_IS_ENABLED(category) \ openscreen::platform::IsTraceLoggingEnabled(TraceCategory::Value::Any) -#else +#ifndef ENABLE_TRACE_LOGGING +#define TRACE_FORCE_DISABLE true +#endif // ENABLE_TRACE_LOGGING +#else // TRACE_FORCE_ENABLE defined #define TRACE_IS_ENABLED(category) true #endif -// Using statements to simplify the below macros. -using openscreen::platform::internal::AsynchronousTraceLogger; -using openscreen::platform::internal::SynchronousTraceLogger; -using openscreen::platform::internal::TraceIdSetter; -using openscreen::platform::internal::TraceInstanceHelper; - // Internal logging macros. #define TRACE_SET_HIERARCHY_INTERNAL(line, ids) \ alignas(32) uint8_t TRACE_INTERNAL_CONCAT_CONST( \ diff --git a/chromium/third_party/openscreen/src/platform/api/network_runner.h b/chromium/third_party/openscreen/src/platform/api/network_runner.h index 7a1742fcee6..67abb5116c9 100644 --- a/chromium/third_party/openscreen/src/platform/api/network_runner.h +++ b/chromium/third_party/openscreen/src/platform/api/network_runner.h @@ -47,12 +47,10 @@ class NetworkRunner : public TaskRunner { virtual Error ReadRepeatedly(UdpSocket* socket, UdpReadCallback* callback) = 0; - // Cancels any pending wait on reading |socket|. Returns false only if the - // socket was not yet being watched, and true if the operation is successful - // and the socket is no longer watched. - // TODO(rwkeane): Make this return void and either include a DCHECK inside of - // the implementation or allow failure with no return code. - virtual bool CancelRead(UdpSocket* socket) = 0; + // Cancels any pending wait on reading |socket|. Returns Error::Code::kNone if + // the operation is successful and the socket is no longer watched, returns an + // error otherwise. + virtual Error CancelRead(UdpSocket* socket) = 0; }; } // namespace platform diff --git a/chromium/third_party/openscreen/src/platform/api/network_runner_lifetime_manager.h b/chromium/third_party/openscreen/src/platform/api/network_runner_lifetime_manager.h new file mode 100644 index 00000000000..1fce71997ee --- /dev/null +++ b/chromium/third_party/openscreen/src/platform/api/network_runner_lifetime_manager.h @@ -0,0 +1,34 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file + +#ifndef PLATFORM_API_NETWORK_RUNNER_LIFETIME_MANAGER_H_ +#define PLATFORM_API_NETWORK_RUNNER_LIFETIME_MANAGER_H_ + +#include "platform/api/network_runner.h" +#include "platform/api/task_runner.h" + +namespace openscreen { +namespace platform { + +class NetworkRunnerLifetimeManager { + public: + virtual ~NetworkRunnerLifetimeManager() = default; + + // Creates a new NetworkRunnerLifetimeManager + // NOTE: The platform must implement this method if + // NetworkRunnerLifetimeManager is to be used. + static std::unique_ptr<NetworkRunnerLifetimeManager> Create(); + + // Creates the NetworkRunner for this factory instance. This method must be + // called no more than once. + virtual void CreateNetworkRunner() = 0; + + // Returns a pointer to the NetworkRunner instance owned by this factory. + virtual NetworkRunner* Get() = 0; +}; + +} // namespace platform +} // namespace openscreen + +#endif // PLATFORM_API_NETWORK_RUNNER_LIFETIME_MANAGER_H_ diff --git a/chromium/third_party/openscreen/src/platform/api/socket_integration_unittest.cc b/chromium/third_party/openscreen/src/platform/api/socket_integration_unittest.cc new file mode 100644 index 00000000000..4d469a926b5 --- /dev/null +++ b/chromium/third_party/openscreen/src/platform/api/socket_integration_unittest.cc @@ -0,0 +1,51 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gtest/gtest.h" +#include "platform/api/time.h" +#include "platform/test/fake_clock.h" +#include "platform/test/fake_task_runner.h" +#include "platform/test/mock_udp_socket.h" + +namespace openscreen { +namespace platform { + +// Tests that a UdpSocket that does not specify any address or port will +// successfully Bind(), and that the operating system will return the +// auto-assigned socket name (i.e., the local endpoint's port will not be zero). +TEST(SocketIntegrationTest, ResolvesLocalEndpoint_IPv4) { + const uint8_t kIpV4AddrAny[4] = {}; + FakeClock clock(Clock::now()); + FakeTaskRunner task_runner(&clock); + MockUdpSocket::MockClient client; + ErrorOr<UdpSocketUniquePtr> create_result = UdpSocket::Create( + &task_runner, &client, IPEndpoint{IPAddress(kIpV4AddrAny), 0}); + ASSERT_TRUE(create_result) << create_result.error(); + const auto socket = create_result.MoveValue(); + const Error bind_result = socket->Bind(); + ASSERT_TRUE(bind_result.ok()) << bind_result; + const IPEndpoint local_endpoint = socket->GetLocalEndpoint(); + EXPECT_NE(local_endpoint.port, 0) << local_endpoint; +} + +// Tests that a UdpSocket that does not specify any address or port will +// successfully Bind(), and that the operating system will return the +// auto-assigned socket name (i.e., the local endpoint's port will not be zero). +TEST(SocketIntegrationTest, ResolvesLocalEndpoint_IPv6) { + const uint8_t kIpV6AddrAny[16] = {}; + FakeClock clock(Clock::now()); + FakeTaskRunner task_runner(&clock); + MockUdpSocket::MockClient client; + ErrorOr<UdpSocketUniquePtr> create_result = UdpSocket::Create( + &task_runner, &client, IPEndpoint{IPAddress(kIpV6AddrAny), 0}); + ASSERT_TRUE(create_result) << create_result.error(); + const auto socket = create_result.MoveValue(); + const Error bind_result = socket->Bind(); + ASSERT_TRUE(bind_result.ok()) << bind_result; + const IPEndpoint local_endpoint = socket->GetLocalEndpoint(); + EXPECT_NE(local_endpoint.port, 0) << local_endpoint; +} + +} // namespace platform +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/api/tls_socket.h b/chromium/third_party/openscreen/src/platform/api/tls_socket.h index 36fe249d240..f3524034bb5 100644 --- a/chromium/third_party/openscreen/src/platform/api/tls_socket.h +++ b/chromium/third_party/openscreen/src/platform/api/tls_socket.h @@ -37,7 +37,7 @@ struct TlsSocketMessage { class TlsSocket { public: - class Delegate { + class Client { public: // Provides a unique ID for use by the TlsSocketFactory. virtual const std::string& GetNewSocketId() = 0; @@ -48,12 +48,15 @@ class TlsSocket { // Called when |socket| is closed. virtual void OnClosed(TlsSocket* socket) = 0; + // Called when |socket| experiences an error, such as a read error. + virtual void OnError(TlsSocket* socket, Error error) = 0; + // Called when a |message| arrives on |socket|. virtual void OnMessage(TlsSocket* socket, const TlsSocketMessage& message) = 0; protected: - virtual ~Delegate() = default; + virtual ~Client() = default; }; enum CloseReason { @@ -68,17 +71,14 @@ class TlsSocket { static ErrorOr<TlsSocketUniquePtr> Create(IPAddress::Version version); // Returns true if |socket| belongs to the IPv4/IPv6 address family. - bool IsIPv4() const; - bool IsIPv6() const; + virtual bool IsIPv4() const = 0; + virtual bool IsIPv6() const = 0; // Closes this socket. Delegate::OnClosed is called when complete. virtual void Close(CloseReason reason) = 0; - // Start reading data. Delegate::OnMessage is called when new data arrives. - virtual Error Read() = 0; - - // Sends a message and returns the number of bytes sent, on success. - virtual Error SendMessage(const TlsSocketMessage& message) = 0; + // Sends a message. + virtual void Write(const TlsSocketMessage& message) = 0; // Returns the unique identifier of the factory that created this socket. virtual const std::string& GetFactoryId() const = 0; @@ -87,14 +87,14 @@ class TlsSocket { const std::string& id() const { return id_; } protected: - Delegate* delegate() const { return delegate_; } + Client* client() const { return client_; } - explicit TlsSocket(Delegate* delegate) {} + explicit TlsSocket(Client* client) : client_(client) {} virtual ~TlsSocket() = 0; private: const std::string id_; - Delegate* const delegate_; + Client* const client_; OSP_DISALLOW_COPY_AND_ASSIGN(TlsSocket); }; @@ -111,7 +111,7 @@ class TlsSocketFactory { // Gets the local address, if set, otherwise nullptr. virtual IPEndpoint* GetLocalAddress() = 0; - // Start accepting new sockets. Should call Delegate::OnAccepted(). + // Start accepting new sockets. Should call Client::OnAccepted(). virtual void Accept() = 0; // Stop accepting new sockets. @@ -121,11 +121,12 @@ class TlsSocketFactory { virtual void SetCredentials(const TlsSocketCreds& creds) = 0; protected: - virtual TlsSocket::Delegate* GetDelegate() const = 0; + virtual TlsSocket::Client* GetClient() const = 0; private: OSP_DISALLOW_COPY_AND_ASSIGN(TlsSocketFactory); -} +}; + } // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/api/trace_logging.h b/chromium/third_party/openscreen/src/platform/api/trace_logging.h index 887f1c95208..8ab4e139642 100644 --- a/chromium/third_party/openscreen/src/platform/api/trace_logging.h +++ b/chromium/third_party/openscreen/src/platform/api/trace_logging.h @@ -23,9 +23,9 @@ namespace openscreen { // // Further details about how these macros are used can be found in // docs/trace_logging.md. -// TODO(rwkeane): Add option to compile these macros out entirely. // TODO(rwkeane): Add support for user-provided properties. +#ifndef TRACE_FORCE_DISABLE #define TRACE_SET_RESULT(result) \ do { \ if (TRACE_IS_ENABLED(TraceCategory::Value::Any)) { \ @@ -61,6 +61,17 @@ namespace openscreen { __LINE__, __FILE__, id, result) \ : false +#else // TRACE_FORCE_DISABLE defined +#define TRACE_SET_RESULT(result) +#define TRACE_SET_HIERARCHY(ids) +#define TRACE_HIERARCHY TraceIdHierarchy::Empty() +#define TRACE_CURRENT_ID kEmptyTraceId +#define TRACE_ROOT_ID kEmptyTraceId +#define TRACE_SCOPED(category, name, ...) +#define TRACE_ASYNC_START(category, name, ...) +#define TRACE_ASYNC_END(category, id, result) false +#endif // TRACE_FORCE_DISABLE + } // namespace openscreen #endif // PLATFORM_API_TRACE_LOGGING_H_ diff --git a/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.h b/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.h index e421a63b583..a65b787f461 100644 --- a/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.h +++ b/chromium/third_party/openscreen/src/platform/api/trace_logging_platform.h @@ -35,9 +35,7 @@ class TraceLoggingPlatform { const char* file, Clock::time_point start_time, Clock::time_point end_time, - TraceId trace_id, - TraceId parent_id, - TraceId root_id, + TraceIdHierarchy ids, Error::Code error) = 0; // Log an asynchronous trace start. @@ -45,9 +43,7 @@ class TraceLoggingPlatform { const uint32_t line, const char* file, Clock::time_point timestamp, - TraceId trace_id, - TraceId parent_id, - TraceId root_id) = 0; + TraceIdHierarchy ids) = 0; // Log an asynchronous trace end. virtual void LogAsyncEnd(const uint32_t line, diff --git a/chromium/third_party/openscreen/src/platform/api/trace_logging_types.h b/chromium/third_party/openscreen/src/platform/api/trace_logging_types.h index 0ec11d38574..59b61a6a02c 100644 --- a/chromium/third_party/openscreen/src/platform/api/trace_logging_types.h +++ b/chromium/third_party/openscreen/src/platform/api/trace_logging_types.h @@ -35,6 +35,15 @@ struct TraceIdHierarchy { bool HasParent() { return parent != kUnsetTraceId; } bool HasRoot() { return root != kUnsetTraceId; } }; +inline bool operator==(const TraceIdHierarchy& lhs, + const TraceIdHierarchy& rhs) { + return lhs.current == rhs.current && lhs.parent == rhs.parent && + lhs.root == rhs.root; +} +inline bool operator!=(const TraceIdHierarchy& lhs, + const TraceIdHierarchy& rhs) { + return !(lhs == rhs); +} // BitFlags to represent the supported tracing categories. // NOTE: These are currently placeholder values and later changes should feel @@ -42,9 +51,9 @@ struct TraceIdHierarchy { struct TraceCategory { enum Value : uint64_t { Any = std::numeric_limits<uint64_t>::max(), - CastPlatformLayer = 0x01, - CastStreaming = 0x01 << 1, - CastFlinging = 0x01 << 2 + mDNS = 0x01 << 0, + Quic = 0x01 << 1, + Presentation = 0x01 << 2, }; }; diff --git a/chromium/third_party/openscreen/src/platform/api/trace_logging_unittest.cc b/chromium/third_party/openscreen/src/platform/api/trace_logging_unittest.cc index 9f83b7b558a..05121f5e335 100644 --- a/chromium/third_party/openscreen/src/platform/api/trace_logging_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/api/trace_logging_unittest.cc @@ -3,7 +3,6 @@ // found in the LICENSE file. #include <chrono> -#include <iostream> #include <thread> #include "absl/types/optional.h" @@ -13,91 +12,43 @@ #define TRACE_FORCE_ENABLE true #include "platform/api/trace_logging.h" +#include "platform/test/trace_logging_helpers.h" // TODO(issue/52): Remove duplicate code from trace logging+internal unit tests namespace openscreen { namespace platform { +namespace { +constexpr TraceHierarchyParts kAllParts = static_cast<TraceHierarchyParts>( + TraceHierarchyParts::kRoot | TraceHierarchyParts::kParent | + TraceHierarchyParts::kCurrent); +constexpr TraceHierarchyParts kParentAndRoot = static_cast<TraceHierarchyParts>( + TraceHierarchyParts::kRoot | TraceHierarchyParts::kParent); +constexpr TraceId kEmptyId = TraceId{0}; +} // namespace using ::testing::_; using ::testing::DoAll; using ::testing::Invoke; -class MockLoggingPlatform : public TraceLoggingPlatform { - public: - MOCK_METHOD9(LogTrace, - void(const char*, - const uint32_t, - const char* file, - Clock::time_point, - Clock::time_point, - TraceId, - TraceId, - TraceId, - Error::Code)); - MOCK_METHOD7(LogAsyncStart, - void(const char*, - const uint32_t, - const char* file, - Clock::time_point, - TraceId, - TraceId, - TraceId)); - MOCK_METHOD5(LogAsyncEnd, - void(const uint32_t, - const char* file, - Clock::time_point, - TraceId, - Error::Code)); -}; - -// Methods to validate the results of platform-layer calls. -template <uint64_t milliseconds> -void ValidateTraceTimestampDiff(const char* name, - const uint32_t line, - const char* file, - Clock::time_point start_time, - Clock::time_point end_time, - TraceId trace_id, - TraceId parent_id, - TraceId root_id, - Error error) { - const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>( - end_time - start_time); - EXPECT_GE(static_cast<uint64_t>(elapsed.count()), milliseconds); -} - -template <Error::Code result> -void ValidateTraceErrorCode(const char* name, - const uint32_t line, - const char* file, - Clock::time_point start_time, - Clock::time_point end_time, - TraceId trace_id, - TraceId parent_id, - TraceId root_id, - Error error) { - EXPECT_EQ(error.code(), result); -} - TEST(TraceLoggingTest, MacroCallScopedDoesntSegFault) { MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); - EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _, _, _)).Times(1); + EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(1); { TRACE_SCOPED(TraceCategory::Value::Any, "test"); } } TEST(TraceLoggingTest, MacroCallUnscopedDoesntSegFault) { MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); - EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _, _, _)).Times(1); + EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _)).Times(1); { TRACE_ASYNC_START(TraceCategory::Value::Any, "test"); } } TEST(TraceLoggingTest, MacroVariablesUniquelyNames) { MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); - EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _, _, _)).Times(2); - EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _, _, _)).Times(2); + EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(2); + EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _)).Times(2); { TRACE_SCOPED(TraceCategory::Value::Any, "test1"); @@ -111,7 +62,7 @@ TEST(TraceLoggingTest, ExpectTimestampsReflectDelay) { constexpr uint32_t delay_in_ms = 50; MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); - EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _, _, _)) + EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) .WillOnce(DoAll(Invoke(ValidateTraceTimestampDiff<delay_in_ms>), Invoke(ValidateTraceErrorCode<Error::Code::kNone>))); @@ -125,7 +76,7 @@ TEST(TraceLoggingTest, ExpectErrorsPassedToResult) { constexpr Error::Code result_code = Error::Code::kParseError; MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); - EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _, _, _)) + EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) .WillOnce(Invoke(ValidateTraceErrorCode<result_code>)); { @@ -137,7 +88,7 @@ TEST(TraceLoggingTest, ExpectErrorsPassedToResult) { TEST(TraceLoggingTest, ExpectUnsetTraceIdNotSet) { MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); - EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _, _, _)).Times(1); + EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)).Times(1); TraceIdHierarchy h = {kUnsetTraceId, kUnsetTraceId, kUnsetTraceId}; { @@ -156,8 +107,11 @@ TEST(TraceLoggingTest, ExpectCreationWithIdsToWork) { constexpr TraceId root = 0x84; MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); - EXPECT_CALL(platform, LogTrace(_, _, _, _, _, current, parent, root, _)) - .WillOnce(Invoke(ValidateTraceErrorCode<Error::Code::kNone>)); + EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) + .WillOnce( + DoAll(Invoke(ValidateTraceErrorCode<Error::Code::kNone>), + Invoke(ValidateTraceIdHierarchyOnSyncTrace<current, parent, + root, kAllParts>))); { TraceIdHierarchy h = {current, parent, root}; @@ -179,10 +133,15 @@ TEST(TraceLoggingTest, ExpectHirearchyToBeApplied) { constexpr TraceId root = 0x84; MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); - EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, current, root, _)) - .WillOnce(Invoke(ValidateTraceErrorCode<Error::Code::kNone>)); - EXPECT_CALL(platform, LogTrace(_, _, _, _, _, current, parent, root, _)) - .WillOnce(Invoke(ValidateTraceErrorCode<Error::Code::kNone>)); + EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) + .WillOnce(DoAll( + Invoke(ValidateTraceErrorCode<Error::Code::kNone>), + Invoke(ValidateTraceIdHierarchyOnSyncTrace<kEmptyId, current, root, + kParentAndRoot>))) + .WillOnce( + DoAll(Invoke(ValidateTraceErrorCode<Error::Code::kNone>), + Invoke(ValidateTraceIdHierarchyOnSyncTrace<current, parent, + root, kAllParts>))); { TraceIdHierarchy h = {current, parent, root}; @@ -206,8 +165,11 @@ TEST(TraceLoggingTest, ExpectHirearchyToEndAfterScopeWhenSetWithSetter) { constexpr TraceId root = 0x84; MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); - EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, current, root, _)) - .WillOnce(Invoke(ValidateTraceErrorCode<Error::Code::kNone>)); + EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) + .WillOnce(DoAll( + Invoke(ValidateTraceErrorCode<Error::Code::kNone>), + Invoke(ValidateTraceIdHierarchyOnSyncTrace<kEmptyId, current, root, + kParentAndRoot>))); { TraceIdHierarchy ids = {current, parent, root}; @@ -233,10 +195,15 @@ TEST(TraceLoggingTest, ExpectHirearchyToEndAfterScope) { constexpr TraceId root = 0x84; MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); - EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, current, root, _)) - .WillOnce(Invoke(ValidateTraceErrorCode<Error::Code::kNone>)); - EXPECT_CALL(platform, LogTrace(_, _, _, _, _, current, parent, root, _)) - .WillOnce(Invoke(ValidateTraceErrorCode<Error::Code::kNone>)); + EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) + .WillOnce(DoAll( + Invoke(ValidateTraceErrorCode<Error::Code::kNone>), + Invoke(ValidateTraceIdHierarchyOnSyncTrace<kEmptyId, current, root, + kParentAndRoot>))) + .WillOnce( + DoAll(Invoke(ValidateTraceErrorCode<Error::Code::kNone>), + Invoke(ValidateTraceIdHierarchyOnSyncTrace<current, parent, + root, kAllParts>))); { TraceIdHierarchy ids = {current, parent, root}; @@ -262,8 +229,11 @@ TEST(TraceLoggingTest, ExpectSetHierarchyToApply) { constexpr TraceId root = 0x84; MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); - EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, current, root, _)) - .WillOnce(Invoke(ValidateTraceErrorCode<Error::Code::kNone>)); + EXPECT_CALL(platform, LogTrace(_, _, _, _, _, _, _)) + .WillOnce(DoAll( + Invoke(ValidateTraceErrorCode<Error::Code::kNone>), + Invoke(ValidateTraceIdHierarchyOnSyncTrace<kEmptyId, current, root, + kParentAndRoot>))); { TraceIdHierarchy ids = {current, parent, root}; @@ -284,7 +254,7 @@ TEST(TraceLoggingTest, ExpectSetHierarchyToApply) { TEST(TraceLoggingTest, CheckTraceAsyncStartLogsCorrectly) { MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); - EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _, _, _)).Times(1); + EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _)).Times(1); { TRACE_ASYNC_START(TraceCategory::Value::Any, "Name"); } } @@ -295,7 +265,10 @@ TEST(TraceLoggingTest, CheckTraceAsyncStartSetsHierarchy) { constexpr TraceId root = 84; MockLoggingPlatform platform; TRACE_SET_DEFAULT_PLATFORM(&platform); - EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _, current, root)).Times(1); + EXPECT_CALL(platform, LogAsyncStart(_, _, _, _, _)) + .WillOnce( + Invoke(ValidateTraceIdHierarchyOnAsyncTrace<kEmptyId, current, root, + kParentAndRoot>)); { TraceIdHierarchy ids = {current, parent, root}; diff --git a/chromium/third_party/openscreen/src/platform/api/udp_packet.h b/chromium/third_party/openscreen/src/platform/api/udp_packet.h index ad3b18f5124..41877ac61f0 100644 --- a/chromium/third_party/openscreen/src/platform/api/udp_packet.h +++ b/chromium/third_party/openscreen/src/platform/api/udp_packet.h @@ -23,8 +23,9 @@ class UdpPacket : public std::vector<uint8_t> { explicit UdpPacket(size_t size) : std::vector<uint8_t>(size) { OSP_DCHECK(size <= kUdpMaxPacketSize); } - UdpPacket() : UdpPacket(0) {} + UdpPacket(UdpPacket&& other) = default; + UdpPacket& operator=(UdpPacket&& other) = default; const IPEndpoint& source() const { return source_; } void set_source(IPEndpoint endpoint) { source_ = std::move(endpoint); } @@ -41,9 +42,11 @@ class UdpPacket : public std::vector<uint8_t> { IPEndpoint source_ = {}; IPEndpoint destination_ = {}; UdpSocket* socket_ = nullptr; + + OSP_DISALLOW_COPY_AND_ASSIGN(UdpPacket); }; } // namespace platform } // namespace openscreen -#endif // PLATFORM_API_UDP_PACKET_H_
\ No newline at end of file +#endif // PLATFORM_API_UDP_PACKET_H_ diff --git a/chromium/third_party/openscreen/src/platform/api/udp_socket.cc b/chromium/third_party/openscreen/src/platform/api/udp_socket.cc index 1bd96fd1b23..ed026836ccf 100644 --- a/chromium/third_party/openscreen/src/platform/api/udp_socket.cc +++ b/chromium/third_party/openscreen/src/platform/api/udp_socket.cc @@ -4,10 +4,14 @@ #include "platform/api/udp_socket.h" +#include "platform/api/task_runner.h" + namespace openscreen { namespace platform { -UdpSocket::UdpSocket() { +UdpSocket::UdpSocket(TaskRunner* task_runner, Client* client) + : client_(client), task_runner_(task_runner) { + OSP_CHECK(task_runner_); deletion_callback_ = [](UdpSocket* socket) {}; } @@ -19,5 +23,33 @@ void UdpSocket::SetDeletionCallback(std::function<void(UdpSocket*)> callback) { deletion_callback_ = callback; } +void UdpSocket::OnError(Error error) { + if (!client_) { + return; + } + + task_runner_->PostTask([e = std::move(error), this]() mutable { + this->client_->OnError(this, std::move(e)); + }); +} +void UdpSocket::OnSendError(Error error) { + if (!client_) { + return; + } + + task_runner_->PostTask([e = std::move(error), this]() mutable { + this->client_->OnSendError(this, std::move(e)); + }); +} +void UdpSocket::OnRead(ErrorOr<UdpPacket> read_data) { + if (!client_) { + return; + } + + task_runner_->PostTask([data = std::move(read_data), this]() mutable { + this->client_->OnRead(this, std::move(data)); + }); +} + } // namespace platform } // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/api/udp_socket.h b/chromium/third_party/openscreen/src/platform/api/udp_socket.h index e060d419d9d..14f0d651347 100644 --- a/chromium/third_party/openscreen/src/platform/api/udp_socket.h +++ b/chromium/third_party/openscreen/src/platform/api/udp_socket.h @@ -18,6 +18,7 @@ namespace openscreen { namespace platform { +class TaskRunner; class UdpSocket; using UdpSocketUniquePtr = std::unique_ptr<UdpSocket>; @@ -39,6 +40,26 @@ class UdpSocket { public: virtual ~UdpSocket(); + // Client for the UdpSocket class. + class Client { + public: + virtual ~Client() = default; + + // Method called on socket configuration operations when an error occurs. + // These specific APIs are: + // UdpSocket::Bind() + // UdpSocket::SetMulticastOutboundInterface(...) + // UdpSocket::JoinMulticastGroup(...) + // UdpSocket::SetDscp(...) + virtual void OnError(UdpSocket* socket, Error error) = 0; + + // Method called when an error occurs during a SendMessage call. + virtual void OnSendError(UdpSocket* socket, Error error) = 0; + + // Method called when a packet is read. + virtual void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) = 0; + }; + // Constants used to specify how we want packets sent from this socket. enum class DscpMode : uint8_t { // Default value set by the system on creation of a new socket. @@ -56,16 +77,31 @@ class UdpSocket { using Version = IPAddress::Version; - // Creates a new, scoped UdpSocket within the IPv4 or IPv6 family. This method - // must be defined in the platform-level implementation. - static ErrorOr<UdpSocketUniquePtr> Create(Version version); + // Creates a new, scoped UdpSocket within the IPv4 or IPv6 family. + // |local_endpoint| may be zero (see comments for Bind()). This method must be + // defined in the platform-level implementation. All client_ methods called + // will be queued on the provided task_runner. For this reason, the provided + // task_runner and client must exist for the duration of the created socket's + // lifetime. + static ErrorOr<UdpSocketUniquePtr> Create(TaskRunner* task_runner, + Client* client, + const IPEndpoint& local_endpoint); // Returns true if |socket| belongs to the IPv4/IPv6 address family. virtual bool IsIPv4() const = 0; virtual bool IsIPv6() const = 0; - // Sets the socket for address reuse, binds to the address/port. - virtual Error Bind(const IPEndpoint& local_endpoint) = 0; + // Returns the current local endpoint's address and port. Initially, this will + // be the same as the value that was passed into Create(). However, it can + // later change after certain operations, such as Bind(), are executed. + virtual IPEndpoint GetLocalEndpoint() const = 0; + + // Binds to the address specified in the constructor. If the local endpoint's + // address is zero, the operating system will bind to all interfaces. If the + // local endpoint's port is zero, the operating system will automatically find + // a free local port and bind to it. Future calls to local_endpoint() will + // reflect the resolved port. + virtual Error Bind() = 0; // Sets the device to use for outgoing multicast packets on the socket. virtual Error SetMulticastOutboundInterface( @@ -76,13 +112,6 @@ class UdpSocket { virtual Error JoinMulticastGroup(const IPAddress& address, NetworkInterfaceIndex ifindex) = 0; - // Performs a non-blocking read on the socket, returning the number of bytes - // received. Note that a non-Error return value of 0 is a valid result, - // indicating an empty message has been received. Also note that - // Error::Code::kAgain might be returned if there is no message currently - // ready for receive, which can be expected during normal operation. - virtual ErrorOr<UdpPacket> ReceiveMessage() = 0; - // Sends a message and returns the number of bytes sent, on success. // Error::Code::kAgain might be returned to indicate the operation would // block, which can be expected during normal operation. @@ -99,13 +128,29 @@ class UdpSocket { void SetDeletionCallback(std::function<void(UdpSocket*)> callback); protected: - UdpSocket(); + // Creates a new UdpSocket. The provided client and task_runner must exist for + // the duration of this socket's lifetime. + UdpSocket(TaskRunner* task_runner, Client* client); + + // Methods to take care of posting UdpSocket::Client callbacks for client_ to + // task_runner_. + void OnError(Error error); + void OnSendError(Error error); + void OnRead(ErrorOr<UdpPacket> read_data); private: // This callback allows other objects to observe the socket's destructor and // act when it is called. std::function<void(UdpSocket*)> deletion_callback_; + // Client to use for callbacks. + // NOTE: client_ can be nullptr if the user does not want any callbacks (for + // example, in the send-only case). + Client* const client_; + + // Task runner to use for queuing client_ callbacks. + TaskRunner* const task_runner_; + OSP_DISALLOW_COPY_AND_ASSIGN(UdpSocket); }; diff --git a/chromium/third_party/openscreen/src/platform/api/udp_socket_unittest.cc b/chromium/third_party/openscreen/src/platform/api/udp_socket_unittest.cc index 78140325f52..38e624af0ac 100644 --- a/chromium/third_party/openscreen/src/platform/api/udp_socket_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/api/udp_socket_unittest.cc @@ -5,6 +5,9 @@ #include "platform/api/udp_socket.h" #include "gtest/gtest.h" +#include "platform/api/time.h" +#include "platform/test/fake_clock.h" +#include "platform/test/fake_task_runner.h" #include "platform/test/mock_udp_socket.h" namespace openscreen { @@ -15,12 +18,20 @@ namespace platform { // which will then crash the running code. This test ensures that deleting a // new, unmodified UDP Socket object doesn't hit this edge case. TEST(UdpSocketTest, TestDeletionWithoutCallbackSet) { - UdpSocket* socket = new MockUdpSocket(UdpSocket::Version::kV4); + FakeClock clock(Clock::now()); + FakeTaskRunner task_runner(&clock); + MockUdpSocket::MockClient client; + UdpSocket* socket = + new MockUdpSocket(&task_runner, &client, UdpSocket::Version::kV4); delete socket; } TEST(UdpSocketTest, TestCallbackCalledOnDeletion) { - UdpSocket* socket = new MockUdpSocket(UdpSocket::Version::kV4); + FakeClock clock(Clock::now()); + FakeTaskRunner task_runner(&clock); + MockUdpSocket::MockClient client; + UdpSocket* socket = + new MockUdpSocket(&task_runner, &client, UdpSocket::Version::kV4); int call_count = 0; std::function<void(UdpSocket*)> callback = [&call_count](UdpSocket* socket) { call_count++; diff --git a/chromium/third_party/openscreen/src/platform/base/error.cc b/chromium/third_party/openscreen/src/platform/base/error.cc index 0edb6bcaf6e..248189f1b3e 100644 --- a/chromium/third_party/openscreen/src/platform/base/error.cc +++ b/chromium/third_party/openscreen/src/platform/base/error.cc @@ -60,14 +60,10 @@ std::ostream& operator<<(std::ostream& os, const Error::Code& code) { return os << "Failure: UnknownRequestId"; case Error::Code::kAddressInUse: return os << "Failure: AddressInUse"; - case Error::Code::kAlreadyListening: - return os << "Failure: AlreadyListening"; case Error::Code::kDomainNameTooLong: return os << "Failure: DomainNameTooLong"; case Error::Code::kDomainNameLabelTooLong: return os << "Failure: DomainNameLabelTooLong"; - case Error::Code::kGenericPlatformError: - return os << "Failure: GenericPlatformError"; case Error::Code::kIOFailure: return os << "Failure: IOFailure"; case Error::Code::kInitializationFailure: @@ -90,12 +86,6 @@ std::ostream& operator<<(std::ostream& os, const Error::Code& code) { return os << "Failure: SocketSendFailure"; case Error::Code::kMdnsRegisterFailure: return os << "Failure: MdnsRegisterFailure"; - case Error::Code::kNoItemFound: - return os << "Failure: NoItemFound"; - case Error::Code::kNotImplemented: - return os << "Failure: NotImplemented"; - case Error::Code::kNotRunning: - return os << "Failure: NotRunning"; case Error::Code::kParseError: return os << "Failure: ParseError"; case Error::Code::kUnknownMessageType: @@ -116,8 +106,44 @@ std::ostream& operator<<(std::ostream& os, const Error::Code& code) { return os << "Failure: JsonWriteError"; case Error::Code::kFileLoadFailure: return os << "Failure: FileLoadFailure"; + case Error::Code::kErrCertsMissing: + return os << "Failure: ErrCertsMissing"; + case Error::Code::kErrCertsParse: + return os << "Failure: ErrCertsParse"; + case Error::Code::kErrCertsRestrictions: + return os << "Failure: ErrCertsRestrictions"; + case Error::Code::kErrCertsDateInvalid: + return os << "Failure: ErrCertsDateInvalid"; + case Error::Code::kErrCertsVerifyGeneric: + return os << "Failure: ErrCertsVerifyGeneric"; + case Error::Code::kErrCrlInvalid: + return os << "Failure: ErrCrlInvalid"; + case Error::Code::kErrCertsRevoked: + return os << "Failure: ErrCertsRevoked"; + case Error::Code::kErrCertsPathlen: + return os << "Failure: ErrCertsPathlen"; + case Error::Code::kUnknownError: + return os << "Failure: UnknownError"; + case Error::Code::kNotImplemented: + return os << "Failure: NotImplemented"; case Error::Code::kInsufficientBuffer: return os << "Failure: InsufficientBuffer"; + case Error::Code::kParameterInvalid: + return os << "Failure: ParameterInvalid"; + case Error::Code::kParameterOutOfRange: + return os << "Failure: ParameterOutOfRange"; + case Error::Code::kParameterNullPointer: + return os << "Failure: ParameterNullPointer"; + case Error::Code::kIndexOutOfBounds: + return os << "Failure: IndexOutOfBounds"; + case Error::Code::kItemAlreadyExists: + return os << "Failure: ItemAlreadyExists"; + case Error::Code::kItemNotFound: + return os << "Failure: ItemNotFound"; + case Error::Code::kOperationInvalid: + return os << "Failure: OperationInvalid"; + case Error::Code::kOperationCancelled: + return os << "Failure: OperationCancelled"; } // Unused 'return' to get around failure on GCC. diff --git a/chromium/third_party/openscreen/src/platform/base/error.h b/chromium/third_party/openscreen/src/platform/base/error.h index 4b2f52fef20..f5faaa5afb9 100644 --- a/chromium/third_party/openscreen/src/platform/base/error.h +++ b/chromium/third_party/openscreen/src/platform/base/error.h @@ -17,6 +17,7 @@ namespace openscreen { // code and an optional message. class Error { public: + // TODO(issue/65): Group/rename OSP-specific errors enum class Code : int8_t { // No error occurred. kNone = 0, @@ -42,12 +43,9 @@ class Error { kUnknownRequestId, kAddressInUse, - kAlreadyListening, kDomainNameTooLong, kDomainNameLabelTooLong, - kGenericPlatformError, - kIOFailure, kInitializationFailure, kInvalidIPV4Address, @@ -62,10 +60,6 @@ class Error { kMdnsRegisterFailure, - kNoItemFound, - kNotImplemented, - kNotRunning, - kParseError, kUnknownMessageType, @@ -78,10 +72,48 @@ class Error { kJsonParseError, kJsonWriteError, - // OpenSSL errors + // OpenSSL errors. kFileLoadFailure, + // Cast certificate errors. + + // Certificates were not provided for verification. + kErrCertsMissing, + + // The certificates provided could not be parsed. + kErrCertsParse, + + // Key usage is missing or is not set to Digital Signature. + // This error could also be thrown if the CN is missing. + kErrCertsRestrictions, + + // The current date is before the notBefore date or after the notAfter date. + kErrCertsDateInvalid, + + // The certificate failed to chain to a trusted root. + kErrCertsVerifyGeneric, + + // The CRL is missing or failed to verify. + kErrCrlInvalid, + + // One of the certificates in the chain is revoked. + kErrCertsRevoked, + + // The pathlen constraint of the root certificate was exceeded. + kErrCertsPathlen, + + // Generic errors. + kUnknownError, + kNotImplemented, kInsufficientBuffer, + kParameterInvalid, + kParameterOutOfRange, + kParameterNullPointer, + kIndexOutOfBounds, + kItemAlreadyExists, + kItemNotFound, + kOperationInvalid, + kOperationCancelled, }; Error(); diff --git a/chromium/third_party/openscreen/src/platform/base/ip_address.h b/chromium/third_party/openscreen/src/platform/base/ip_address.h index 32cd2ddace3..536ce9bec6b 100644 --- a/chromium/third_party/openscreen/src/platform/base/ip_address.h +++ b/chromium/third_party/openscreen/src/platform/base/ip_address.h @@ -72,6 +72,10 @@ class IPAddress { void CopyToV4(uint8_t* x) const; void CopyToV6(uint8_t* x) const; + // In some instances, we want direct access to the underlying byte storage, + // in order to avoid making multiple copies. + const uint8_t* bytes() const { return bytes_.data(); } + private: static ErrorOr<IPAddress> ParseV4(const std::string& s); static ErrorOr<IPAddress> ParseV6(const std::string& s); diff --git a/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc b/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc index 85fa7abdcf2..1bbbe8ec0c7 100644 --- a/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/base/ip_address_unittest.cc @@ -23,6 +23,10 @@ TEST(IPAddressTest, V4Constructors) { address2.CopyToV4(bytes); EXPECT_THAT(bytes, ElementsAreArray(x)); + const auto b = address2.bytes(); + const uint8_t raw_bytes[4]{b[0], b[1], b[2], b[3]}; + EXPECT_THAT(raw_bytes, ElementsAreArray(x)); + IPAddress address3(IPAddress::Version::kV4, &x[0]); address3.CopyToV4(bytes); EXPECT_THAT(bytes, ElementsAreArray(x)); diff --git a/chromium/third_party/openscreen/src/platform/base/macros.h b/chromium/third_party/openscreen/src/platform/base/macros.h index fa1dad701dd..d0e491e33c3 100644 --- a/chromium/third_party/openscreen/src/platform/base/macros.h +++ b/chromium/third_party/openscreen/src/platform/base/macros.h @@ -18,6 +18,16 @@ #endif #endif +// TODO(issues/40): Delete this macro once the g++ version is upgraded on the +// bots. +#ifndef MAYBE_NODISCARD +#if defined(__GNUC__) && __GNUC__ < 6 +#define MAYBE_NODISCARD +#else +#define MAYBE_NODISCARD [[nodiscard]] +#endif +#endif + #ifdef DISALLOW_COPY #define OSP_DISALLOW_COPY DISALLOW_COPY #else diff --git a/chromium/third_party/openscreen/src/platform/impl/DEPS b/chromium/third_party/openscreen/src/platform/impl/DEPS index 82d9363c2ba..441e4d80a19 100644 --- a/chromium/third_party/openscreen/src/platform/impl/DEPS +++ b/chromium/third_party/openscreen/src/platform/impl/DEPS @@ -7,6 +7,5 @@ include_rules = [ '+json', # BoringSSL includes - '-third_party/boringssl', '+openssl' ] diff --git a/chromium/third_party/openscreen/src/platform/impl/event_loop.cc b/chromium/third_party/openscreen/src/platform/impl/event_loop.cc deleted file mode 100644 index 05454c9adaf..00000000000 --- a/chromium/third_party/openscreen/src/platform/impl/event_loop.cc +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "platform/impl/event_loop.h" - -#include <utility> - -#include "platform/api/logging.h" -#include "platform/api/udp_socket.h" - -namespace openscreen { -namespace platform { - -std::vector<UdpPacket> HandleUdpSocketReadEvents(const Events& events) { - std::vector<UdpPacket> packets(events.udp_readable_events.size()); - for (const auto& read_event : events.udp_readable_events) { - ErrorOr<UdpPacket> result = read_event.socket->ReceiveMessage(); - if (result) { - packets.emplace_back(result.MoveValue()); - } else { - OSP_LOG_ERROR << "ReceiveMessage() on socket failed: " - << result.error().message(); - } - } - return packets; -} - -std::vector<UdpPacket> OnePlatformLoopIteration(EventWaiterPtr waiter) { - ErrorOr<Events> events = WaitForEvents(waiter); - if (!events) - return {}; - - return HandleUdpSocketReadEvents(events.value()); -} - -} // namespace platform -} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/event_loop.h b/chromium/third_party/openscreen/src/platform/impl/event_loop.h deleted file mode 100644 index 05bf458b918..00000000000 --- a/chromium/third_party/openscreen/src/platform/impl/event_loop.h +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef PLATFORM_IMPL_EVENT_LOOP_H_ -#define PLATFORM_IMPL_EVENT_LOOP_H_ - -#include <sys/types.h> - -#include <vector> - -#include "platform/api/event_waiter.h" -#include "platform/api/network_runner.h" -#include "platform/base/error.h" - -namespace openscreen { -namespace platform { - -std::vector<UdpPacket> HandleUdpSocketReadEvents(const Events& events); -std::vector<UdpPacket> OnePlatformLoopIteration(EventWaiterPtr waiter); - -} // namespace platform -} // namespace openscreen - -#endif // PLATFORM_IMPL_EVENT_LOOP_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/event_waiter_posix.cc b/chromium/third_party/openscreen/src/platform/impl/event_waiter_posix.cc deleted file mode 100644 index 5988d49625c..00000000000 --- a/chromium/third_party/openscreen/src/platform/impl/event_waiter_posix.cc +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include <sys/select.h> - -#include <algorithm> -#include <vector> - -#include "platform/api/event_waiter.h" -#include "platform/api/logging.h" -#include "platform/base/error.h" -#include "platform/impl/udp_socket_posix.h" - -namespace openscreen { -namespace platform { -namespace { - -Error AddToVectorIfMissing(UdpSocketPosix* socket, - std::vector<UdpSocketPosix*>* watched_sockets) { - for (const auto* s : *watched_sockets) { - if (s->GetFd() == socket->GetFd()) - return Error::Code::kAlreadyListening; - } - watched_sockets->push_back(socket); - return Error::None(); -} - -Error RemoveFromVectorIfPresent(UdpSocketPosix* socket, - std::vector<UdpSocketPosix*>* watched_sockets) { - const auto it = std::find_if( - watched_sockets->begin(), watched_sockets->end(), - [socket](UdpSocketPosix* s) { return s->GetFd() == socket->GetFd(); }); - if (it == watched_sockets->end()) - return Error::Code::kNoItemFound; - - watched_sockets->erase(it); - return Error::None(); -} - -} // namespace - -struct EventWaiterPrivate { - std::vector<UdpSocketPosix*> read_sockets; - std::vector<UdpSocketPosix*> write_sockets; -}; - -EventWaiterPtr CreateEventWaiter() { - return new EventWaiterPrivate; -} - -void DestroyEventWaiter(EventWaiterPtr waiter) { - delete waiter; -} - -Error WatchUdpSocketReadable(EventWaiterPtr waiter, UdpSocket* socket) { - return AddToVectorIfMissing(static_cast<UdpSocketPosix*>(socket), - &waiter->read_sockets); -} - -Error StopWatchingUdpSocketReadable(EventWaiterPtr waiter, UdpSocket* socket) { - return RemoveFromVectorIfPresent(static_cast<UdpSocketPosix*>(socket), - &waiter->read_sockets); -} - -Error WatchUdpSocketWritable(EventWaiterPtr waiter, UdpSocket* socket) { - return AddToVectorIfMissing(static_cast<UdpSocketPosix*>(socket), - &waiter->write_sockets); -} - -Error StopWatchingUdpSocketWritable(EventWaiterPtr waiter, UdpSocket* socket) { - return RemoveFromVectorIfPresent(static_cast<UdpSocketPosix*>(socket), - &waiter->write_sockets); -} - -Error WatchNetworkChange(EventWaiterPtr waiter) { - // TODO(btolsch): Implement network change watching. - OSP_UNIMPLEMENTED(); - return Error::Code::kNotImplemented; -} - -Error StopWatchingNetworkChange(EventWaiterPtr waiter) { - // TODO(btolsch): Implement stop network change watching. - OSP_UNIMPLEMENTED(); - return Error::Code::kNotImplemented; -} - -ErrorOr<Events> WaitForEvents(EventWaiterPtr waiter) { - int max_fd = -1; - fd_set readfds; - fd_set writefds; - FD_ZERO(&readfds); - FD_ZERO(&writefds); - for (const auto* read_socket : waiter->read_sockets) { - FD_SET(read_socket->GetFd(), &readfds); - max_fd = std::max(max_fd, read_socket->GetFd()); - } - for (const auto* write_socket : waiter->write_sockets) { - FD_SET(write_socket->GetFd(), &writefds); - max_fd = std::max(max_fd, write_socket->GetFd()); - } - if (max_fd == -1) - return Error::Code::kIOFailure; - - struct timeval tv = {}; - const int rv = select(max_fd + 1, &readfds, &writefds, nullptr, &tv); - if (rv == -1 || rv == 0) - return Error::Code::kIOFailure; - - Events events; - for (auto* read_socket : waiter->read_sockets) { - if (FD_ISSET(read_socket->GetFd(), &readfds)) - events.udp_readable_events.push_back({read_socket}); - } - for (auto* write_socket : waiter->write_sockets) { - if (FD_ISSET(write_socket->GetFd(), &writefds)) - events.udp_writable_events.push_back({write_socket}); - } - - return std::move(events); -} - -} // namespace platform -} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/network_reader.cc b/chromium/third_party/openscreen/src/platform/impl/network_reader.cc index 4a9470555f6..c805e6523c4 100644 --- a/chromium/third_party/openscreen/src/platform/impl/network_reader.cc +++ b/chromium/third_party/openscreen/src/platform/impl/network_reader.cc @@ -8,6 +8,7 @@ #include <condition_variable> #include "platform/api/logging.h" +#include "platform/impl/udp_socket_posix.h" namespace openscreen { namespace platform { @@ -34,9 +35,10 @@ Error NetworkReader::ReadRepeatedly(UdpSocket* socket, Callback callback) { : Error::None(); } -bool NetworkReader::CancelRead(UdpSocket* socket) { +Error NetworkReader::CancelRead(UdpSocket* socket) { std::lock_guard<std::mutex> lock(mutex_); - return read_callbacks_.erase(socket) != 0; + return read_callbacks_.erase(socket) != 0 ? Error::Code::kNone + : Error::Code::kOperationInvalid; } Error NetworkReader::WaitAndRead(Clock::duration timeout) { @@ -67,7 +69,9 @@ Error NetworkReader::WaitAndRead(Clock::duration timeout) { continue; } - ErrorOr<UdpPacket> read_packet = mapped_socket->first->ReceiveMessage(); + // TODO(rwkeane): Remove this unsafe cast. + UdpSocketPosix* read_socket = static_cast<UdpSocketPosix*>(read); + ErrorOr<UdpPacket> read_packet = read_socket->ReceiveMessage(); if (read_packet.is_error()) { error = read_packet.error(); continue; diff --git a/chromium/third_party/openscreen/src/platform/impl/network_reader.h b/chromium/third_party/openscreen/src/platform/impl/network_reader.h index 4d714b33117..c2a2f0f8c48 100644 --- a/chromium/third_party/openscreen/src/platform/impl/network_reader.h +++ b/chromium/third_party/openscreen/src/platform/impl/network_reader.h @@ -41,11 +41,10 @@ class NetworkReader { Error ReadRepeatedly(UdpSocket* socket, Callback callback); // Cancels any pending wait on reading |socket|. Following this call, any - // pending reads will proceed but their associated callbacks will not fire. - // This function returns false only if the socket was not yet being watched, - // and true if the operation is successful and the socket is no longer - // watched. - bool CancelRead(UdpSocket* socket); + // pending reads will proceed but their associated callbacks will not fire. + // This function returns Error::Code::kNone if the operation is successful and + // the socket is no longer watched and returns an error on failure. + Error CancelRead(UdpSocket* socket); // Runs the Wait function in a loop until the below RequestStopSoon function // is called. diff --git a/chromium/third_party/openscreen/src/platform/impl/network_reader_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/network_reader_unittest.cc index e1dcf8e41d8..13129926c45 100644 --- a/chromium/third_party/openscreen/src/platform/impl/network_reader_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/impl/network_reader_unittest.cc @@ -6,7 +6,10 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "platform/api/event_waiter.h" +#include "platform/api/time.h" +#include "platform/impl/udp_socket_posix.h" +#include "platform/test/fake_clock.h" +#include "platform/test/fake_task_runner.h" #include "platform/test/mock_udp_socket.h" namespace openscreen { @@ -16,6 +19,31 @@ using namespace ::testing; using ::testing::_; using ::testing::Invoke; +class MockUdpSocketPosix : public UdpSocketPosix { + public: + explicit MockUdpSocketPosix(TaskRunner* task_runner, + Client* client, + Version version = Version::kV4) + : UdpSocketPosix(task_runner, client, 0, IPEndpoint()), + version_(version) {} + ~MockUdpSocketPosix() override = default; + + bool IsIPv4() const override { return version_ == UdpSocket::Version::kV4; } + + bool IsIPv6() const override { return version_ == UdpSocket::Version::kV6; } + + MOCK_METHOD0(Bind, Error()); + MOCK_METHOD1(SetMulticastOutboundInterface, Error(NetworkInterfaceIndex)); + MOCK_METHOD2(JoinMulticastGroup, + Error(const IPAddress&, NetworkInterfaceIndex)); + MOCK_METHOD0(ReceiveMessage, ErrorOr<UdpPacket>()); + MOCK_METHOD3(SendMessage, Error(const void*, size_t, const IPEndpoint&)); + MOCK_METHOD1(SetDscp, Error(DscpMode)); + + private: + Version version_; +}; + // Mock event waiter class MockNetworkWaiter final : public NetworkWaiter { public: @@ -83,8 +111,10 @@ TEST(NetworkReaderTest, WatchReadableSucceeds) { std::unique_ptr<NetworkWaiter>(new MockNetworkWaiter()); std::unique_ptr<TaskRunner> task_runner = std::unique_ptr<TaskRunner>(new MockTaskRunner()); - std::unique_ptr<MockUdpSocket> socket = - std::make_unique<MockUdpSocket>(UdpSocket::Version::kV4); + MockUdpSocket::MockClient client; + std::unique_ptr<MockUdpSocketPosix> socket = + std::make_unique<MockUdpSocketPosix>(task_runner.get(), &client, + UdpSocket::Version::kV4); TestingNetworkWaiter network_waiter(std::move(mock_waiter), task_runner.get()); MockCallbacks callbacks; @@ -113,23 +143,27 @@ TEST(NetworkReaderTest, UnwatchReadableSucceeds) { std::unique_ptr<NetworkWaiter>(new MockNetworkWaiter()); std::unique_ptr<TaskRunner> task_runner = std::unique_ptr<TaskRunner>(new MockTaskRunner()); - std::unique_ptr<MockUdpSocket> socket = - std::make_unique<MockUdpSocket>(UdpSocket::Version::kV4); + MockUdpSocket::MockClient client; + std::unique_ptr<MockUdpSocketPosix> socket = + std::make_unique<MockUdpSocketPosix>(task_runner.get(), &client, + UdpSocket::Version::kV4); TestingNetworkWaiter network_waiter(std::move(mock_waiter), task_runner.get()); MockCallbacks callbacks; auto callback = callbacks.GetReadCallback(); - EXPECT_FALSE(network_waiter.CancelRead(socket.get())); + EXPECT_EQ(network_waiter.CancelRead(socket.get()), + Error::Code::kOperationInvalid); EXPECT_FALSE(network_waiter.IsMappedRead(socket.get())); EXPECT_EQ(network_waiter.ReadRepeatedly(socket.get(), callback).code(), Error::Code::kNone); - EXPECT_TRUE(network_waiter.CancelRead(socket.get())); + EXPECT_EQ(network_waiter.CancelRead(socket.get()), Error::Code::kNone); EXPECT_FALSE(network_waiter.IsMappedRead(socket.get())); - EXPECT_FALSE(network_waiter.CancelRead(socket.get())); + EXPECT_EQ(network_waiter.CancelRead(socket.get()), + Error::Code::kOperationInvalid); // Set deletion callback because otherwise the destructor tries to call a // callback on the deleted object when it goes out of scope. @@ -152,7 +186,7 @@ TEST(NetworkReaderTest, WaitBubblesUpWaitForEventsErrors) { auto result = network_waiter.WaitTesting(timeout); EXPECT_EQ(result.code(), response_code); - response_code = Error::Code::kAlreadyListening; + response_code = Error::Code::kOperationInvalid; EXPECT_CALL(*mock_waiter_ptr, AwaitSocketsReadable(_, timeout)) .WillOnce(Return(ByMove(std::move(response_code)))); result = network_waiter.WaitTesting(timeout); @@ -180,8 +214,10 @@ TEST(NetworkReaderTest, WaitSuccessfullyCalledOnAllWatchedSockets) { std::unique_ptr<NetworkWaiter>(mock_waiter_ptr); std::unique_ptr<TaskRunner> task_runner = std::unique_ptr<TaskRunner>(new MockTaskRunner()); - std::unique_ptr<MockUdpSocket> socket = - std::make_unique<MockUdpSocket>(UdpSocket::Version::kV4); + MockUdpSocket::MockClient client; + std::unique_ptr<MockUdpSocketPosix> socket = + std::make_unique<MockUdpSocketPosix>(task_runner.get(), &client, + UdpSocket::Version::kV4); TestingNetworkWaiter network_waiter(std::move(mock_waiter), task_runner.get()); auto timeout = Clock::duration(0); @@ -208,7 +244,9 @@ TEST(NetworkReaderTest, WaitSuccessfulReadAndCallCallback) { std::unique_ptr<NetworkWaiter>(mock_waiter_ptr); std::unique_ptr<TaskRunner> task_runner = std::unique_ptr<TaskRunner>(task_runner_ptr); - MockUdpSocket socket(UdpSocket::Version::kV4); + MockUdpSocket::MockClient client; + MockUdpSocketPosix socket(task_runner.get(), &client, + UdpSocket::Version::kV4); TestingNetworkWaiter network_waiter(std::move(mock_waiter), task_runner.get()); auto timeout = Clock::duration(0); @@ -236,7 +274,9 @@ TEST(NetworkReaderTest, WaitFailsIfReadingSocketFails) { std::unique_ptr<NetworkWaiter>(mock_waiter_ptr); std::unique_ptr<TaskRunner> task_runner = std::unique_ptr<TaskRunner>(new MockTaskRunner()); - MockUdpSocket socket(UdpSocket::Version::kV4); + MockUdpSocket::MockClient client; + MockUdpSocketPosix socket(task_runner.get(), &client, + UdpSocket::Version::kV4); TestingNetworkWaiter network_waiter(std::move(mock_waiter), task_runner.get()); auto timeout = Clock::duration(0); @@ -248,9 +288,8 @@ TEST(NetworkReaderTest, WaitFailsIfReadingSocketFails) { .WillOnce(Return(ByMove(std::vector<UdpSocket*>{&socket}))); EXPECT_CALL(callbacks, ReadCallbackInternal()).Times(0); EXPECT_CALL(socket, ReceiveMessage()) - .WillOnce(Return(ByMove(Error::Code::kGenericPlatformError))); - EXPECT_EQ(network_waiter.WaitTesting(timeout), - Error::Code::kGenericPlatformError); + .WillOnce(Return(ByMove(Error::Code::kUnknownError))); + EXPECT_EQ(network_waiter.WaitTesting(timeout), Error::Code::kUnknownError); // Set deletion callback because otherwise the destructor tries to call a // callback on the deleted object when it goes out of scope. diff --git a/chromium/third_party/openscreen/src/platform/impl/network_runner.cc b/chromium/third_party/openscreen/src/platform/impl/network_runner.cc index c090c6b1380..398c31a5f50 100644 --- a/chromium/third_party/openscreen/src/platform/impl/network_runner.cc +++ b/chromium/third_party/openscreen/src/platform/impl/network_runner.cc @@ -13,14 +13,8 @@ namespace openscreen { namespace platform { NetworkRunnerImpl::NetworkRunnerImpl(std::unique_ptr<TaskRunner> task_runner) - : NetworkRunnerImpl(std::move(task_runner), - std::make_unique<NetworkReader>(task_runner.get())) {} - -NetworkRunnerImpl::NetworkRunnerImpl( - std::unique_ptr<TaskRunner> task_runner, - std::unique_ptr<NetworkReader> network_loop) - : network_loop_(std::move(network_loop)), - task_runner_(std::move(task_runner)){}; + : network_loop_(std::make_unique<NetworkReader>(task_runner.get())), + task_runner_(std::move(task_runner)) {} Error NetworkRunnerImpl::ReadRepeatedly(UdpSocket* socket, UdpReadCallback* callback) { @@ -30,7 +24,7 @@ Error NetworkRunnerImpl::ReadRepeatedly(UdpSocket* socket, return network_loop_->ReadRepeatedly(socket, func); } -bool NetworkRunnerImpl::CancelRead(UdpSocket* socket) { +Error NetworkRunnerImpl::CancelRead(UdpSocket* socket) { return network_loop_->CancelRead(socket); } @@ -44,15 +38,10 @@ void NetworkRunnerImpl::PostPackagedTaskWithDelay(Task task, } void NetworkRunnerImpl::RunUntilStopped() { - const bool was_running = is_running_.exchange(true); - OSP_CHECK(!was_running); - network_loop_->RunUntilStopped(); } void NetworkRunnerImpl::RequestStopSoon() { - is_running_.exchange(false); - network_loop_->RequestStopSoon(); } diff --git a/chromium/third_party/openscreen/src/platform/impl/network_runner.h b/chromium/third_party/openscreen/src/platform/impl/network_runner.h index 230de6bed31..66572c3e439 100644 --- a/chromium/third_party/openscreen/src/platform/impl/network_runner.h +++ b/chromium/third_party/openscreen/src/platform/impl/network_runner.h @@ -33,7 +33,7 @@ class NetworkRunnerImpl final : public NetworkRunner { Error ReadRepeatedly(UdpSocket* socket, UdpReadCallback* callback); - bool CancelRead(UdpSocket* socket); + Error CancelRead(UdpSocket* socket); void PostPackagedTask(Task task); @@ -48,20 +48,11 @@ class NetworkRunnerImpl final : public NetworkRunner { void RequestStopSoon(); protected: - // Creates a new NetworkRunnerImpl with the provided NetworkLoop and - // TaskRunner. Note that the Task Runner is expected to be running at the time - // it is provided. - NetworkRunnerImpl(std::unique_ptr<TaskRunner> task_runner, - std::unique_ptr<NetworkReader> network_loop); - // Objects handling actual processing of this instance's calls. std::unique_ptr<NetworkReader> network_loop_; std::unique_ptr<TaskRunner> task_runner_; private: - // Atomic so that we can perform atomic exchanges. - std::atomic_bool is_running_; - OSP_DISALLOW_COPY_AND_ASSIGN(NetworkRunnerImpl); }; diff --git a/chromium/third_party/openscreen/src/platform/impl/network_runner_lifetime_manager.cc b/chromium/third_party/openscreen/src/platform/impl/network_runner_lifetime_manager.cc new file mode 100644 index 00000000000..4f4c092bb22 --- /dev/null +++ b/chromium/third_party/openscreen/src/platform/impl/network_runner_lifetime_manager.cc @@ -0,0 +1,63 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file + +#include "platform/impl/network_runner_lifetime_manager.h" + +#include <thread> + +#include "platform/api/time.h" +#include "platform/impl/network_runner.h" +#include "platform/impl/task_runner.h" + +namespace openscreen { +namespace platform { + +NetworkRunnerLifetimeManagerImpl::NetworkRunnerLifetimeManagerImpl() = default; + +NetworkRunnerLifetimeManagerImpl::NetworkRunnerLifetimeManagerImpl( + std::unique_ptr<TaskRunner> task_runner) + : task_runner_(std::move(task_runner)) {} + +void NetworkRunnerLifetimeManagerImpl::CreateNetworkRunner() { + OSP_CHECK(!network_runner_.get()) << "NetworkRunner already created"; + + if (!task_runner_.get()) { + auto task_runner = std::make_unique<TaskRunnerImpl>(Clock::now); + task_runner_thread_ = std::make_unique<std::thread>( + [ptr = task_runner.get()]() { ptr->RunUntilStopped(); }); + cleanup_tasks_.emplace( + [ptr = task_runner.get()]() { ptr->RequestStopSoon(); }); + task_runner_ = std::move(task_runner); + } + + network_runner_ = + std::make_unique<NetworkRunnerImpl>(std::move(task_runner_)); + network_runner_thread_ = std::make_unique<std::thread>( + [ptr = network_runner_.get()]() { ptr->RunUntilStopped(); }); + cleanup_tasks_.emplace( + [ptr = network_runner_.get()]() { ptr->RequestStopSoon(); }); +} + +NetworkRunnerLifetimeManagerImpl::~NetworkRunnerLifetimeManagerImpl() { + while (!cleanup_tasks_.empty()) { + cleanup_tasks_.front()(); + cleanup_tasks_.pop(); + } + + if (task_runner_thread_.get()) { + task_runner_thread_->join(); + } + if (network_runner_thread_.get()) { + network_runner_thread_->join(); + } +} + +// static +std::unique_ptr<NetworkRunnerLifetimeManager> +NetworkRunnerLifetimeManager::Create() { + return std::make_unique<NetworkRunnerLifetimeManagerImpl>(); +} + +} // namespace platform +} // namespace openscreen diff --git a/chromium/third_party/openscreen/src/platform/impl/network_runner_lifetime_manager.h b/chromium/third_party/openscreen/src/platform/impl/network_runner_lifetime_manager.h new file mode 100644 index 00000000000..d980516a893 --- /dev/null +++ b/chromium/third_party/openscreen/src/platform/impl/network_runner_lifetime_manager.h @@ -0,0 +1,45 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file + +#ifndef PLATFORM_IMPL_NETWORK_RUNNER_LIFETIME_MANAGER_H_ +#define PLATFORM_IMPL_NETWORK_RUNNER_LIFETIME_MANAGER_H_ + +#include <queue> +#include <thread> + +#include "platform/api/network_runner_lifetime_manager.h" +#include "platform/api/task_runner.h" +#include "platform/impl/network_runner.h" + +namespace openscreen { +namespace platform { + +class NetworkRunnerLifetimeManagerImpl final + : public NetworkRunnerLifetimeManager { + public: + NetworkRunnerLifetimeManagerImpl(); + explicit NetworkRunnerLifetimeManagerImpl( + std::unique_ptr<TaskRunner> task_runner); + ~NetworkRunnerLifetimeManagerImpl() override; + + void CreateNetworkRunner() override; + NetworkRunner* Get() override { + OSP_CHECK(network_runner_.get()) << "NetworkRunner not yet created"; + return network_runner_.get(); + } + + private: + std::unique_ptr<NetworkRunnerImpl> network_runner_; + std::unique_ptr<std::thread> network_runner_thread_; + std::unique_ptr<TaskRunner> task_runner_; + std::unique_ptr<std::thread> task_runner_thread_; + std::queue<TaskRunner::Task> cleanup_tasks_; + + OSP_DISALLOW_COPY_AND_ASSIGN(NetworkRunnerLifetimeManagerImpl); +}; + +} // namespace platform +} // namespace openscreen + +#endif // PLATFORM_IMPL_NETWORK_RUNNER_LIFETIME_MANAGER_H_ diff --git a/chromium/third_party/openscreen/src/platform/impl/task_runner.cc b/chromium/third_party/openscreen/src/platform/impl/task_runner.cc index 651e1fe01ef..ebc8877d53f 100644 --- a/chromium/third_party/openscreen/src/platform/impl/task_runner.cc +++ b/chromium/third_party/openscreen/src/platform/impl/task_runner.cc @@ -4,19 +4,13 @@ #include "platform/impl/task_runner.h" +#include <thread> + #include "platform/api/logging.h" namespace openscreen { namespace platform { -TaskRunnerImpl::TaskWithMetadata::TaskWithMetadata(Task task) - : task_(std::move(task)), trace_ids_(TRACE_HIERARCHY) {} - -void TaskRunnerImpl::TaskWithMetadata::operator()() { - TRACE_SET_HIERARCHY(trace_ids_); - std::move(task_)(); -} - TaskRunnerImpl::TaskRunnerImpl(platform::ClockNowFunctionPtr now_function, TaskWaiter* event_waiter, Clock::duration waiter_timeout) diff --git a/chromium/third_party/openscreen/src/platform/impl/task_runner.h b/chromium/third_party/openscreen/src/platform/impl/task_runner.h index 084d628599c..88e0bed9ce4 100644 --- a/chromium/third_party/openscreen/src/platform/impl/task_runner.h +++ b/chromium/third_party/openscreen/src/platform/impl/task_runner.h @@ -9,8 +9,7 @@ #include <condition_variable> // NOLINT #include <map> #include <memory> -#include <mutex> // NOLINT -#include <thread> // NOLINT +#include <mutex> // NOLINT #include <utility> #include <vector> @@ -75,20 +74,30 @@ class TaskRunnerImpl final : public TaskRunner { void RunUntilIdleForTesting(); private: +#ifndef TRACE_FORCE_DISABLE // Wrapper around a Task used to store the TraceId Metadata along with the // task itself, and to set the current TraceIdHierarchy before executing the // task. class TaskWithMetadata { public: - // NOTE: Conversion constructor required due to condition_variable library. - TaskWithMetadata(Task task); + // NOTE: 'explicit' keyword omitted so that conversion construtor can be + // used. This simplifies switching between 'Task' and 'TaskWithMetadata' + // based on the compilation flag. + TaskWithMetadata(Task task) + : task_(std::move(task)), trace_ids_(TRACE_HIERARCHY){}; - void operator()(); + void operator()() { + TRACE_SET_HIERARCHY(trace_ids_); + std::move(task_)(); + } private: Task task_; TraceIdHierarchy trace_ids_; }; +#else // TRACE_FORCE_DISABLE defined + using TaskWithMetadata = Task; +#endif // TRACE_FORCE_DISABLE // Run all tasks already in the task queue. If the queue is empty, wait for // either (1) a delayed task to become available, or (2) a task to be added diff --git a/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc b/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc index 71403a0d591..b8d582cbdb3 100644 --- a/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc +++ b/chromium/third_party/openscreen/src/platform/impl/task_runner_unittest.cc @@ -235,7 +235,7 @@ TEST(TaskRunnerImplTest, TaskRunnerUsesEventWaiter) { std::unique_ptr<TaskRunnerImpl> runner = TaskRunnerWithWaiterFactory::Create(Clock::now); - int x = 0; + std::atomic<int> x{0}; std::thread t([&runner, &x] { runner.get()->RunUntilStopped(); x = 1; @@ -262,7 +262,7 @@ TEST(TaskRunnerImplTest, WakesEventWaiterOnPostTask) { std::unique_ptr<TaskRunnerImpl> runner = TaskRunnerWithWaiterFactory::Create(Clock::now); - int x = 0; + std::atomic<int> x{0}; std::thread t([&runner] { runner.get()->RunUntilStopped(); }); const Clock::time_point start1 = Clock::now(); @@ -285,8 +285,6 @@ TEST(TaskRunnerImplTest, WakesEventWaiterOnPostTask) { class RepeatedClass { public: - RepeatedClass() { execution_count = 0; } - MOCK_METHOD0(Repeat, absl::optional<Clock::duration>()); absl::optional<Clock::duration> DoCall() { @@ -295,7 +293,7 @@ class RepeatedClass { return result; } - int execution_count; + std::atomic<int> execution_count{0}; }; TEST(TaskRunnerImplTest, RepeatingFunctionCalledRepeatedly) { diff --git a/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc b/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc index b214b0ca852..09a71ff7142 100644 --- a/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc +++ b/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.cc @@ -4,6 +4,8 @@ #include "platform/impl/text_trace_logging_platform.h" +#include <sstream> + #include "platform/api/logging.h" namespace openscreen { @@ -27,30 +29,31 @@ void TextTraceLoggingPlatform::LogTrace(const char* name, const char* file, Clock::time_point start_time, Clock::time_point end_time, - TraceId trace_id, - TraceId parent_id, - TraceId root_id, + TraceIdHierarchy ids, Error::Code error) { auto total_runtime = std::chrono::duration_cast<std::chrono::microseconds>( end_time - start_time) .count(); constexpr auto microseconds_symbol = "\u03BCs"; // Greek Mu + 's' - OSP_LOG << "TRACE [" << std::hex << root_id << ":" << parent_id << ":" - << trace_id << "] (" << std::dec << total_runtime - << microseconds_symbol << ") " << name << "<" << file << ":" << line - << "> " << error; + std::stringstream ss; + ss << "TRACE [" << std::hex << ids.root << ":" << ids.parent << ":" + << ids.current << "] (" << std::dec << total_runtime << microseconds_symbol + << ") " << name << "<" << file << ":" << line << "> " << error; + + OSP_LOG << ss.str(); } void TextTraceLoggingPlatform::LogAsyncStart(const char* name, const uint32_t line, const char* file, Clock::time_point timestamp, - TraceId trace_id, - TraceId parent_id, - TraceId root_id) { - OSP_LOG << "ASYNC TRACE START [" << std::hex << root_id << ":" << parent_id - << ":" << trace_id << std::dec << "] (" << timestamp << ") " << name - << "<" << file << ":" << line << ">"; + TraceIdHierarchy ids) { + std::stringstream ss; + ss << "ASYNC TRACE START [" << std::hex << ids.root << ":" << ids.parent + << ":" << ids.current << std::dec << "] (" << timestamp << ") " << name + << "<" << file << ":" << line << ">"; + + OSP_LOG << ss.str(); } void TextTraceLoggingPlatform::LogAsyncEnd(const uint32_t line, diff --git a/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.h b/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.h index b7f98204f73..b564a47a8d8 100644 --- a/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.h +++ b/chromium/third_party/openscreen/src/platform/impl/text_trace_logging_platform.h @@ -21,18 +21,14 @@ class TextTraceLoggingPlatform : public TraceLoggingPlatform { const char* file, Clock::time_point start_time, Clock::time_point end_time, - TraceId trace_id, - TraceId parent_id, - TraceId root_id, + TraceIdHierarchy ids, Error::Code error) override; void LogAsyncStart(const char* name, const uint32_t line, const char* file, Clock::time_point timestamp, - TraceId trace_id, - TraceId parent_id, - TraceId root_id) override; + TraceIdHierarchy ids) override; void LogAsyncEnd(const uint32_t line, const char* file, diff --git a/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.cc b/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.cc index db08e20c438..d361c530e1c 100644 --- a/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.cc +++ b/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.cc @@ -50,17 +50,24 @@ ErrorOr<int> CreateNonBlockingUdpSocket(int domain) { } // namespace -UdpSocketPosix::UdpSocketPosix(int fd, UdpSocket::Version version) - : fd_(fd), version_(version) {} +UdpSocketPosix::UdpSocketPosix(TaskRunner* task_runner, + Client* client, + int fd, + const IPEndpoint& local_endpoint) + : UdpSocket(task_runner, client), fd_(fd), local_endpoint_(local_endpoint) { + OSP_DCHECK(local_endpoint_.address.IsV4() || local_endpoint_.address.IsV6()); +} UdpSocketPosix::~UdpSocketPosix() { close(fd_); } // static -ErrorOr<UdpSocketUniquePtr> UdpSocket::Create(UdpSocket::Version version) { +ErrorOr<UdpSocketUniquePtr> UdpSocket::Create(TaskRunner* task_runner, + Client* client, + const IPEndpoint& endpoint) { int domain; - switch (version) { + switch (endpoint.address.version()) { case Version::kV4: domain = AF_INET; break; @@ -72,19 +79,57 @@ ErrorOr<UdpSocketUniquePtr> UdpSocket::Create(UdpSocket::Version version) { if (!fd) { return fd.error(); } - return UdpSocketUniquePtr( - static_cast<UdpSocket*>(new UdpSocketPosix(fd.value(), version))); + return UdpSocketUniquePtr(static_cast<UdpSocket*>( + new UdpSocketPosix(task_runner, client, fd.value(), endpoint))); } bool UdpSocketPosix::IsIPv4() const { - return version_ == UdpSocket::Version::kV4; + return local_endpoint_.address.IsV4(); } bool UdpSocketPosix::IsIPv6() const { - return version_ == UdpSocket::Version::kV6; + return local_endpoint_.address.IsV6(); +} + +IPEndpoint UdpSocketPosix::GetLocalEndpoint() const { + if (local_endpoint_.port == 0) { + // Note: If the getsockname() call fails, just assume that's because the + // socket isn't bound yet. In this case, leave the original value in-place. + switch (local_endpoint_.address.version()) { + case UdpSocket::Version::kV4: { + struct sockaddr_in address; + socklen_t address_len = sizeof(address); + if (getsockname(fd_, reinterpret_cast<struct sockaddr*>(&address), + &address_len) == 0) { + OSP_DCHECK_EQ(address.sin_family, AF_INET); + local_endpoint_.address = + IPAddress(IPAddress::Version::kV4, + reinterpret_cast<uint8_t*>(&address.sin_addr.s_addr)); + local_endpoint_.port = ntohs(address.sin_port); + } + break; + } + + case UdpSocket::Version::kV6: { + struct sockaddr_in6 address; + socklen_t address_len = sizeof(address); + if (getsockname(fd_, reinterpret_cast<struct sockaddr*>(&address), + &address_len) == 0) { + OSP_DCHECK_EQ(address.sin6_family, AF_INET6); + local_endpoint_.address = + IPAddress(IPAddress::Version::kV6, + reinterpret_cast<uint8_t*>(&address.sin6_addr)); + local_endpoint_.port = ntohs(address.sin6_port); + } + break; + } + } + } + + return local_endpoint_; } -Error UdpSocketPosix::Bind(const IPEndpoint& endpoint) { +Error UdpSocketPosix::Bind() { // This is effectively a boolean passed to setsockopt() to allow a future // bind() on the same socket to succeed, even if the address is already in // use. This is pretty much universally the desired behavior. @@ -94,12 +139,12 @@ Error UdpSocketPosix::Bind(const IPEndpoint& endpoint) { return Error(Error::Code::kSocketOptionSettingFailure, strerror(errno)); } - switch (version_) { + switch (local_endpoint_.address.version()) { case UdpSocket::Version::kV4: { struct sockaddr_in address; address.sin_family = AF_INET; - address.sin_port = htons(endpoint.port); - endpoint.address.CopyToV4( + address.sin_port = htons(local_endpoint_.port); + local_endpoint_.address.CopyToV4( reinterpret_cast<uint8_t*>(&address.sin_addr.s_addr)); if (bind(fd_, reinterpret_cast<struct sockaddr*>(&address), sizeof(address)) == -1) { @@ -112,8 +157,9 @@ Error UdpSocketPosix::Bind(const IPEndpoint& endpoint) { struct sockaddr_in6 address; address.sin6_family = AF_INET6; address.sin6_flowinfo = 0; - address.sin6_port = htons(endpoint.port); - endpoint.address.CopyToV6(reinterpret_cast<uint8_t*>(&address.sin6_addr)); + address.sin6_port = htons(local_endpoint_.port); + local_endpoint_.address.CopyToV6( + reinterpret_cast<uint8_t*>(&address.sin6_addr)); address.sin6_scope_id = 0; if (bind(fd_, reinterpret_cast<struct sockaddr*>(&address), sizeof(address)) == -1) { @@ -124,12 +170,12 @@ Error UdpSocketPosix::Bind(const IPEndpoint& endpoint) { } OSP_NOTREACHED(); - return Error::Code::kGenericPlatformError; + return Error::Code::kUnknownError; } Error UdpSocketPosix::SetMulticastOutboundInterface( NetworkInterfaceIndex ifindex) { - switch (version_) { + switch (local_endpoint_.address.version()) { case UdpSocket::Version::kV4: { struct ip_mreqn multicast_properties; // Appropriate address is set based on |imr_ifindex| when set. @@ -155,12 +201,12 @@ Error UdpSocketPosix::SetMulticastOutboundInterface( } OSP_NOTREACHED(); - return Error::Code::kGenericPlatformError; + return Error::Code::kUnknownError; } Error UdpSocketPosix::JoinMulticastGroup(const IPAddress& address, NetworkInterfaceIndex ifindex) { - switch (version_) { + switch (local_endpoint_.address.version()) { case UdpSocket::Version::kV4: { // Passed as data to setsockopt(). 1 means return IP_PKTINFO control data // in recvmsg() calls. @@ -212,7 +258,7 @@ Error UdpSocketPosix::JoinMulticastGroup(const IPAddress& address, } OSP_NOTREACHED(); - return Error::Code::kGenericPlatformError; + return Error::Code::kUnknownError; } namespace { @@ -327,8 +373,8 @@ ErrorOr<UdpPacket> UdpSocketPosix::ReceiveMessage() { } UdpPacket packet(bytes_available); packet.set_socket(this); - Error result = Error::Code::kGenericPlatformError; - switch (version_) { + Error result = Error::Code::kUnknownError; + switch (local_endpoint_.address.version()) { case UdpSocket::Version::kV4: { result = ReceiveMessageInternal<sockaddr_in, in_pktinfo>(fd_, &packet); break; @@ -359,7 +405,7 @@ Error UdpSocketPosix::SendMessage(const void* data, msg.msg_flags = 0; ssize_t num_bytes_sent = -2; - switch (version_) { + switch (local_endpoint_.address.version()) { case UdpSocket::Version::kV4: { struct sockaddr_in sa = { .sin_family = AF_INET, diff --git a/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.h b/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.h index 17983b1f8fd..1a0555995dc 100644 --- a/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.h +++ b/chromium/third_party/openscreen/src/platform/impl/udp_socket_posix.h @@ -12,27 +12,45 @@ namespace platform { struct UdpSocketPosix : public UdpSocket { public: - UdpSocketPosix(int fd, Version version); - ~UdpSocketPosix() final; + // Creates a new UdpSocketPosix. The provided client and task_runner must + // exist for the duration of this socket's lifetime. + UdpSocketPosix(TaskRunner* task_runner, + Client* client, + int fd, + const IPEndpoint& local_endpoint); + + ~UdpSocketPosix() override; + + // Performs a non-blocking read on the socket, returning the number of bytes + // received. Note that a non-Error return value of 0 is a valid result, + // indicating an empty message has been received. Also note that + // Error::Code::kAgain might be returned if there is no message currently + // ready for receive, which can be expected during normal operation. + virtual ErrorOr<UdpPacket> ReceiveMessage(); // Implementations of UdpSocket methods. - bool IsIPv4() const final; - bool IsIPv6() const final; - Error Bind(const IPEndpoint& local_endpoint) final; - Error SetMulticastOutboundInterface(NetworkInterfaceIndex ifindex) final; + bool IsIPv4() const override; + bool IsIPv6() const override; + IPEndpoint GetLocalEndpoint() const override; + Error Bind() override; + Error SetMulticastOutboundInterface(NetworkInterfaceIndex ifindex) override; Error JoinMulticastGroup(const IPAddress& address, - NetworkInterfaceIndex ifindex) final; - ErrorOr<UdpPacket> ReceiveMessage() final; + NetworkInterfaceIndex ifindex) override; Error SendMessage(const void* data, size_t length, - const IPEndpoint& dest) final; - Error SetDscp(DscpMode state) final; + const IPEndpoint& dest) override; + Error SetDscp(DscpMode state) override; int GetFd() const { return fd_; } private: const int fd_; - const UdpSocket::Version version_; + + // Cached value of current local endpoint. This can change (e.g., when the + // operating system auto-assigns a free local port when Bind() is called). If + // the port is zero, getsockname() is called to try to resolve it. Once the + // port is non-zero, it is assumed never to change again. + mutable IPEndpoint local_endpoint_; }; } // namespace platform diff --git a/chromium/third_party/openscreen/src/streaming/cast/DEPS b/chromium/third_party/openscreen/src/streaming/cast/DEPS index 77ab80d771b..7ad157857ef 100644 --- a/chromium/third_party/openscreen/src/streaming/cast/DEPS +++ b/chromium/third_party/openscreen/src/streaming/cast/DEPS @@ -4,6 +4,5 @@ include_rules = [ # BoringSSL includes - '-third_party/boringssl', '+openssl' ] diff --git a/chromium/third_party/openscreen/src/streaming/cast/rtcp_common.h b/chromium/third_party/openscreen/src/streaming/cast/rtcp_common.h index 0b8e712b965..9e3e6266084 100644 --- a/chromium/third_party/openscreen/src/streaming/cast/rtcp_common.h +++ b/chromium/third_party/openscreen/src/streaming/cast/rtcp_common.h @@ -92,7 +92,7 @@ struct RtcpReportBlock { RtpTimeDelta jitter; // The last Status Report received. - StatusReportId last_status_report_id; + StatusReportId last_status_report_id{}; // The delay between when the peer received the most-recent Status Report and // when this report was sent. The timebase is 65536 ticks per second and, diff --git a/chromium/third_party/openscreen/src/streaming/cast/rtcp_session.cc b/chromium/third_party/openscreen/src/streaming/cast/rtcp_session.cc index 127c7c0fd24..b04cbf9f9ef 100644 --- a/chromium/third_party/openscreen/src/streaming/cast/rtcp_session.cc +++ b/chromium/third_party/openscreen/src/streaming/cast/rtcp_session.cc @@ -11,6 +11,8 @@ namespace cast_streaming { RtcpSession::RtcpSession(Ssrc sender_ssrc, Ssrc receiver_ssrc) : sender_ssrc_(sender_ssrc), receiver_ssrc_(receiver_ssrc) { + OSP_DCHECK_NE(sender_ssrc_, kNullSsrc); + OSP_DCHECK_NE(receiver_ssrc_, kNullSsrc); OSP_DCHECK_NE(sender_ssrc_, receiver_ssrc_); } diff --git a/chromium/third_party/openscreen/src/streaming/cast/rtp_packet_parser.cc b/chromium/third_party/openscreen/src/streaming/cast/rtp_packet_parser.cc index fcc1f4d3bd0..c06066d3ee7 100644 --- a/chromium/third_party/openscreen/src/streaming/cast/rtp_packet_parser.cc +++ b/chromium/third_party/openscreen/src/streaming/cast/rtp_packet_parser.cc @@ -55,6 +55,9 @@ absl::optional<RtpPacketParser::ParseResult> RtpPacketParser::Parse( highest_rtp_frame_id_.Expand(ConsumeField<uint8_t>(&buffer)); result.packet_id = ConsumeField<uint16_t>(&buffer); result.max_packet_id = ConsumeField<uint16_t>(&buffer); + if (result.max_packet_id == kAllPacketsLost) { + return absl::nullopt; // Packet ID cannot be the special value. + } if (result.packet_id > result.max_packet_id) { return absl::nullopt; } diff --git a/chromium/third_party/openscreen/src/streaming/cast/rtp_packet_parser_unittest.cc b/chromium/third_party/openscreen/src/streaming/cast/rtp_packet_parser_unittest.cc index 03ccec91efb..e35c47a6aae 100644 --- a/chromium/third_party/openscreen/src/streaming/cast/rtp_packet_parser_unittest.cc +++ b/chromium/third_party/openscreen/src/streaming/cast/rtp_packet_parser_unittest.cc @@ -6,6 +6,7 @@ #include "gtest/gtest.h" #include "streaming/cast/rtp_defines.h" +#include "util/big_endian.h" namespace openscreen { namespace cast_streaming { @@ -269,7 +270,7 @@ TEST(RtpPacketParserTest, RejectsTruncatedPackets) { } // Tests that the parser rejects invalid packet ID values. -TEST(RtpPacketParserTest, RejectsPacketWithBadFramePacketId) { +TEST(RtpPacketParserTest, RejectsPacketWithBadFramePacketIds) { // clang-format off const uint8_t kInput[] = { 0b10000000, // Version/Padding byte. @@ -286,8 +287,22 @@ TEST(RtpPacketParserTest, RejectsPacketWithBadFramePacketId) { // clang-format on const Ssrc kSenderSsrc = 0x01020304; + // The parser should reject the packet because its packet ID field is greater + // than the max packet ID. RtpPacketParser parser(kSenderSsrc); ASSERT_FALSE(parser.Parse(kInput)); + + // Now, modify the packet such that its "max packet ID" field is set to the + // special "all packets lost" value. This makes the "packet ID" field valid, + // because it is less than the "max packet ID", but the "max packet ID" value + // itself is invalid. + uint8_t input_with_bad_max_packet_id[sizeof(kInput)]; + memcpy(input_with_bad_max_packet_id, kInput, sizeof(kInput)); + WriteBigEndian<uint16_t>(kAllPacketsLost, &input_with_bad_max_packet_id[16]); + const uint16_t packet_id = + ReadBigEndian<uint16_t>(&input_with_bad_max_packet_id[14]); + ASSERT_LE(packet_id, kAllPacketsLost); + ASSERT_FALSE(parser.Parse(input_with_bad_max_packet_id)); } } // namespace diff --git a/chromium/third_party/openscreen/src/streaming/cast/rtp_packetizer.cc b/chromium/third_party/openscreen/src/streaming/cast/rtp_packetizer.cc index 0dd607ec428..c5a75e73691 100644 --- a/chromium/third_party/openscreen/src/streaming/cast/rtp_packetizer.cc +++ b/chromium/third_party/openscreen/src/streaming/cast/rtp_packetizer.cc @@ -55,6 +55,7 @@ absl::Span<uint8_t> RtpPacketizer::GeneratePacket(const EncryptedFrame& frame, OSP_CHECK_GE(static_cast<int>(buffer.size()), max_packet_size_); const int num_packets = ComputeNumberOfPackets(frame); + OSP_DCHECK_GT(num_packets, 0); OSP_DCHECK_LT(int{packet_id}, num_packets); const bool is_last_packet = int{packet_id} == (num_packets - 1); @@ -128,8 +129,7 @@ int RtpPacketizer::ComputeNumberOfPackets(const EncryptedFrame& frame) const { num_packets = std::max(1, num_packets); // Ensure that the entire range of FramePacketIds can be represented. - OSP_DCHECK_LE(num_packets, int{kMaxAllowedFramePacketId}); - return num_packets; + return num_packets <= int{kMaxAllowedFramePacketId} ? num_packets : -1; } } // namespace cast_streaming diff --git a/chromium/third_party/openscreen/src/streaming/cast/rtp_packetizer.h b/chromium/third_party/openscreen/src/streaming/cast/rtp_packetizer.h index 272702e60d9..f3bc45712c0 100644 --- a/chromium/third_party/openscreen/src/streaming/cast/rtp_packetizer.h +++ b/chromium/third_party/openscreen/src/streaming/cast/rtp_packetizer.h @@ -43,7 +43,8 @@ class RtpPacketizer { absl::Span<uint8_t> buffer); // Given |frame|, compute the total number of packets over which the whole - // frame will be split-up. + // frame will be split-up. Returns -1 if the frame is too large and cannot be + // packetized. int ComputeNumberOfPackets(const EncryptedFrame& frame) const; private: diff --git a/chromium/third_party/openscreen/src/streaming/cast/ssrc.h b/chromium/third_party/openscreen/src/streaming/cast/ssrc.h index 849b4c10398..e1d7803cd00 100644 --- a/chromium/third_party/openscreen/src/streaming/cast/ssrc.h +++ b/chromium/third_party/openscreen/src/streaming/cast/ssrc.h @@ -16,6 +16,9 @@ namespace cast_streaming { // and a video stream will have a different sender SSRC. using Ssrc = uint32_t; +// The "not set" or "null" value for the Ssrc type. +constexpr Ssrc kNullSsrc{0}; + // Computes a new SSRC that will be used to uniquely identify an RTP stream. The // |higher_priority| argument, if true, will generate an SSRC that causes the // system to use a higher priority when scheduling data transmission. Generally, diff --git a/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn b/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn index facb89a5806..f229f78b88d 100644 --- a/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn +++ b/chromium/third_party/openscreen/src/third_party/abseil/BUILD.gn @@ -14,6 +14,9 @@ if (build_with_chromium) { } else { config("abseil_config") { include_dirs = [ "//third_party/abseil/src" ] + } + + config("private_abseil_config") { cflags = [ "-Wno-sign-compare", "-Wno-extra-semi", @@ -84,6 +87,8 @@ if (build_with_chromium) { configs -= [ "//build/config:symbol_visibility_hidden" ] configs += [ "//build/config:symbol_visibility_default" ] configs -= [ "//build:default_include_dirs" ] + + configs += [ ":private_abseil_config" ] public_configs = [ ":abseil_config" ] } } diff --git a/chromium/third_party/openscreen/src/third_party/boringssl/BUILD.gn b/chromium/third_party/openscreen/src/third_party/boringssl/BUILD.gn index 5e9cca25f90..ff5c2aba268 100644 --- a/chromium/third_party/openscreen/src/third_party/boringssl/BUILD.gn +++ b/chromium/third_party/openscreen/src/third_party/boringssl/BUILD.gn @@ -17,6 +17,7 @@ if (build_with_chromium) { # Config for us and everybody else depending on BoringSSL. config("external_config") { include_dirs = [ "src/include" ] + cflags = [ "-Wno-extra-semi" ] } # Config internal to this build file, shared by boringssl and boringssl_fuzzer. diff --git a/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn b/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn index 77112f74f07..7a3c96d22fa 100644 --- a/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn +++ b/chromium/third_party/openscreen/src/third_party/protobuf/BUILD.gn @@ -160,111 +160,113 @@ static_library("protobuf_full") { ] } -# protoc compiler is separated into protoc library and executable targets to -# support protoc plugins that need to link libprotoc, but not the main() -# itself. See src/google/protobuf/compiler/plugin.h -static_library("protoc_lib") { - sources = [ - "src/src/google/protobuf/compiler/code_generator.cc", - "src/src/google/protobuf/compiler/command_line_interface.cc", - "src/src/google/protobuf/compiler/cpp/cpp_enum.cc", - "src/src/google/protobuf/compiler/cpp/cpp_enum_field.cc", - "src/src/google/protobuf/compiler/cpp/cpp_extension.cc", - "src/src/google/protobuf/compiler/cpp/cpp_field.cc", - "src/src/google/protobuf/compiler/cpp/cpp_file.cc", - "src/src/google/protobuf/compiler/cpp/cpp_generator.cc", - "src/src/google/protobuf/compiler/cpp/cpp_helpers.cc", - "src/src/google/protobuf/compiler/cpp/cpp_map_field.cc", - "src/src/google/protobuf/compiler/cpp/cpp_message.cc", - "src/src/google/protobuf/compiler/cpp/cpp_message_field.cc", - "src/src/google/protobuf/compiler/cpp/cpp_padding_optimizer.cc", - "src/src/google/protobuf/compiler/cpp/cpp_primitive_field.cc", - "src/src/google/protobuf/compiler/cpp/cpp_service.cc", - "src/src/google/protobuf/compiler/cpp/cpp_string_field.cc", - "src/src/google/protobuf/compiler/csharp/csharp_doc_comment.cc", - "src/src/google/protobuf/compiler/csharp/csharp_enum.cc", - "src/src/google/protobuf/compiler/csharp/csharp_enum_field.cc", - "src/src/google/protobuf/compiler/csharp/csharp_field_base.cc", - "src/src/google/protobuf/compiler/csharp/csharp_generator.cc", - "src/src/google/protobuf/compiler/csharp/csharp_helpers.cc", - "src/src/google/protobuf/compiler/csharp/csharp_map_field.cc", - "src/src/google/protobuf/compiler/csharp/csharp_message.cc", - "src/src/google/protobuf/compiler/csharp/csharp_message_field.cc", - "src/src/google/protobuf/compiler/csharp/csharp_primitive_field.cc", - "src/src/google/protobuf/compiler/csharp/csharp_reflection_class.cc", - "src/src/google/protobuf/compiler/csharp/csharp_repeated_enum_field.cc", - "src/src/google/protobuf/compiler/csharp/csharp_repeated_message_field.cc", - "src/src/google/protobuf/compiler/csharp/csharp_repeated_primitive_field.cc", - "src/src/google/protobuf/compiler/csharp/csharp_source_generator_base.cc", - "src/src/google/protobuf/compiler/csharp/csharp_wrapper_field.cc", - "src/src/google/protobuf/compiler/java/java_context.cc", - "src/src/google/protobuf/compiler/java/java_doc_comment.cc", - "src/src/google/protobuf/compiler/java/java_enum.cc", - "src/src/google/protobuf/compiler/java/java_enum_field.cc", - "src/src/google/protobuf/compiler/java/java_enum_field_lite.cc", - "src/src/google/protobuf/compiler/java/java_enum_lite.cc", - "src/src/google/protobuf/compiler/java/java_extension.cc", - "src/src/google/protobuf/compiler/java/java_extension_lite.cc", - "src/src/google/protobuf/compiler/java/java_field.cc", - "src/src/google/protobuf/compiler/java/java_file.cc", - "src/src/google/protobuf/compiler/java/java_generator.cc", - "src/src/google/protobuf/compiler/java/java_generator_factory.cc", - "src/src/google/protobuf/compiler/java/java_helpers.cc", - "src/src/google/protobuf/compiler/java/java_map_field.cc", - "src/src/google/protobuf/compiler/java/java_map_field_lite.cc", - "src/src/google/protobuf/compiler/java/java_message.cc", - "src/src/google/protobuf/compiler/java/java_message_builder.cc", - "src/src/google/protobuf/compiler/java/java_message_builder_lite.cc", - "src/src/google/protobuf/compiler/java/java_message_field.cc", - "src/src/google/protobuf/compiler/java/java_message_field_lite.cc", - "src/src/google/protobuf/compiler/java/java_message_lite.cc", - "src/src/google/protobuf/compiler/java/java_name_resolver.cc", - "src/src/google/protobuf/compiler/java/java_primitive_field.cc", - "src/src/google/protobuf/compiler/java/java_primitive_field_lite.cc", - "src/src/google/protobuf/compiler/java/java_service.cc", - "src/src/google/protobuf/compiler/java/java_shared_code_generator.cc", - "src/src/google/protobuf/compiler/java/java_string_field.cc", - "src/src/google/protobuf/compiler/java/java_string_field_lite.cc", - "src/src/google/protobuf/compiler/js/js_generator.cc", - "src/src/google/protobuf/compiler/js/well_known_types_embed.cc", - "src/src/google/protobuf/compiler/objectivec/objectivec_enum.cc", - "src/src/google/protobuf/compiler/objectivec/objectivec_enum_field.cc", - "src/src/google/protobuf/compiler/objectivec/objectivec_extension.cc", - "src/src/google/protobuf/compiler/objectivec/objectivec_field.cc", - "src/src/google/protobuf/compiler/objectivec/objectivec_file.cc", - "src/src/google/protobuf/compiler/objectivec/objectivec_generator.cc", - "src/src/google/protobuf/compiler/objectivec/objectivec_helpers.cc", - "src/src/google/protobuf/compiler/objectivec/objectivec_map_field.cc", - "src/src/google/protobuf/compiler/objectivec/objectivec_message.cc", - "src/src/google/protobuf/compiler/objectivec/objectivec_message_field.cc", - "src/src/google/protobuf/compiler/objectivec/objectivec_oneof.cc", - "src/src/google/protobuf/compiler/objectivec/objectivec_primitive_field.cc", - "src/src/google/protobuf/compiler/php/php_generator.cc", - "src/src/google/protobuf/compiler/plugin.cc", - "src/src/google/protobuf/compiler/plugin.pb.cc", - "src/src/google/protobuf/compiler/python/python_generator.cc", - "src/src/google/protobuf/compiler/ruby/ruby_generator.cc", - "src/src/google/protobuf/compiler/subprocess.cc", - "src/src/google/protobuf/compiler/zip_writer.cc", - ] +if (current_toolchain == host_toolchain) { + # protoc compiler is separated into protoc library and executable targets to + # support protoc plugins that need to link libprotoc, but not the main() + # itself. See src/google/protobuf/compiler/plugin.h + static_library("protoc_lib") { + sources = [ + "src/src/google/protobuf/compiler/code_generator.cc", + "src/src/google/protobuf/compiler/command_line_interface.cc", + "src/src/google/protobuf/compiler/cpp/cpp_enum.cc", + "src/src/google/protobuf/compiler/cpp/cpp_enum_field.cc", + "src/src/google/protobuf/compiler/cpp/cpp_extension.cc", + "src/src/google/protobuf/compiler/cpp/cpp_field.cc", + "src/src/google/protobuf/compiler/cpp/cpp_file.cc", + "src/src/google/protobuf/compiler/cpp/cpp_generator.cc", + "src/src/google/protobuf/compiler/cpp/cpp_helpers.cc", + "src/src/google/protobuf/compiler/cpp/cpp_map_field.cc", + "src/src/google/protobuf/compiler/cpp/cpp_message.cc", + "src/src/google/protobuf/compiler/cpp/cpp_message_field.cc", + "src/src/google/protobuf/compiler/cpp/cpp_padding_optimizer.cc", + "src/src/google/protobuf/compiler/cpp/cpp_primitive_field.cc", + "src/src/google/protobuf/compiler/cpp/cpp_service.cc", + "src/src/google/protobuf/compiler/cpp/cpp_string_field.cc", + "src/src/google/protobuf/compiler/csharp/csharp_doc_comment.cc", + "src/src/google/protobuf/compiler/csharp/csharp_enum.cc", + "src/src/google/protobuf/compiler/csharp/csharp_enum_field.cc", + "src/src/google/protobuf/compiler/csharp/csharp_field_base.cc", + "src/src/google/protobuf/compiler/csharp/csharp_generator.cc", + "src/src/google/protobuf/compiler/csharp/csharp_helpers.cc", + "src/src/google/protobuf/compiler/csharp/csharp_map_field.cc", + "src/src/google/protobuf/compiler/csharp/csharp_message.cc", + "src/src/google/protobuf/compiler/csharp/csharp_message_field.cc", + "src/src/google/protobuf/compiler/csharp/csharp_primitive_field.cc", + "src/src/google/protobuf/compiler/csharp/csharp_reflection_class.cc", + "src/src/google/protobuf/compiler/csharp/csharp_repeated_enum_field.cc", + "src/src/google/protobuf/compiler/csharp/csharp_repeated_message_field.cc", + "src/src/google/protobuf/compiler/csharp/csharp_repeated_primitive_field.cc", + "src/src/google/protobuf/compiler/csharp/csharp_source_generator_base.cc", + "src/src/google/protobuf/compiler/csharp/csharp_wrapper_field.cc", + "src/src/google/protobuf/compiler/java/java_context.cc", + "src/src/google/protobuf/compiler/java/java_doc_comment.cc", + "src/src/google/protobuf/compiler/java/java_enum.cc", + "src/src/google/protobuf/compiler/java/java_enum_field.cc", + "src/src/google/protobuf/compiler/java/java_enum_field_lite.cc", + "src/src/google/protobuf/compiler/java/java_enum_lite.cc", + "src/src/google/protobuf/compiler/java/java_extension.cc", + "src/src/google/protobuf/compiler/java/java_extension_lite.cc", + "src/src/google/protobuf/compiler/java/java_field.cc", + "src/src/google/protobuf/compiler/java/java_file.cc", + "src/src/google/protobuf/compiler/java/java_generator.cc", + "src/src/google/protobuf/compiler/java/java_generator_factory.cc", + "src/src/google/protobuf/compiler/java/java_helpers.cc", + "src/src/google/protobuf/compiler/java/java_map_field.cc", + "src/src/google/protobuf/compiler/java/java_map_field_lite.cc", + "src/src/google/protobuf/compiler/java/java_message.cc", + "src/src/google/protobuf/compiler/java/java_message_builder.cc", + "src/src/google/protobuf/compiler/java/java_message_builder_lite.cc", + "src/src/google/protobuf/compiler/java/java_message_field.cc", + "src/src/google/protobuf/compiler/java/java_message_field_lite.cc", + "src/src/google/protobuf/compiler/java/java_message_lite.cc", + "src/src/google/protobuf/compiler/java/java_name_resolver.cc", + "src/src/google/protobuf/compiler/java/java_primitive_field.cc", + "src/src/google/protobuf/compiler/java/java_primitive_field_lite.cc", + "src/src/google/protobuf/compiler/java/java_service.cc", + "src/src/google/protobuf/compiler/java/java_shared_code_generator.cc", + "src/src/google/protobuf/compiler/java/java_string_field.cc", + "src/src/google/protobuf/compiler/java/java_string_field_lite.cc", + "src/src/google/protobuf/compiler/js/js_generator.cc", + "src/src/google/protobuf/compiler/js/well_known_types_embed.cc", + "src/src/google/protobuf/compiler/objectivec/objectivec_enum.cc", + "src/src/google/protobuf/compiler/objectivec/objectivec_enum_field.cc", + "src/src/google/protobuf/compiler/objectivec/objectivec_extension.cc", + "src/src/google/protobuf/compiler/objectivec/objectivec_field.cc", + "src/src/google/protobuf/compiler/objectivec/objectivec_file.cc", + "src/src/google/protobuf/compiler/objectivec/objectivec_generator.cc", + "src/src/google/protobuf/compiler/objectivec/objectivec_helpers.cc", + "src/src/google/protobuf/compiler/objectivec/objectivec_map_field.cc", + "src/src/google/protobuf/compiler/objectivec/objectivec_message.cc", + "src/src/google/protobuf/compiler/objectivec/objectivec_message_field.cc", + "src/src/google/protobuf/compiler/objectivec/objectivec_oneof.cc", + "src/src/google/protobuf/compiler/objectivec/objectivec_primitive_field.cc", + "src/src/google/protobuf/compiler/php/php_generator.cc", + "src/src/google/protobuf/compiler/plugin.cc", + "src/src/google/protobuf/compiler/plugin.pb.cc", + "src/src/google/protobuf/compiler/python/python_generator.cc", + "src/src/google/protobuf/compiler/ruby/ruby_generator.cc", + "src/src/google/protobuf/compiler/subprocess.cc", + "src/src/google/protobuf/compiler/zip_writer.cc", + ] - configs += [ ":protobuf_warnings" ] - public_configs = [ ":protobuf_config" ] + configs += [ ":protobuf_warnings" ] + public_configs = [ ":protobuf_config" ] - public_deps = [ - ":protobuf_full", - ] -} + public_deps = [ + ":protobuf_full", + ] + } -executable("protoc") { - sources = [ - "src/src/google/protobuf/compiler/main.cc", - ] + executable("protoc") { + sources = [ + "src/src/google/protobuf/compiler/main.cc", + ] - deps = [ - ":protoc_lib", - ] + deps = [ + ":protoc_lib", + ] - configs += [ ":protobuf_warnings" ] - public_configs = [ ":protobuf_config" ] + configs += [ ":protobuf_warnings" ] + public_configs = [ ":protobuf_config" ] + } } diff --git a/chromium/third_party/openscreen/src/third_party/protobuf/proto_library.gni b/chromium/third_party/openscreen/src/third_party/protobuf/proto_library.gni index 89867ca592a..222d6503b01 100644 --- a/chromium/third_party/openscreen/src/third_party/protobuf/proto_library.gni +++ b/chromium/third_party/openscreen/src/third_party/protobuf/proto_library.gni @@ -50,7 +50,7 @@ template("proto_library") { outputs = get_path_info(protogens_cc, "abspath") args = protos - protoc_label = "//third_party/protobuf:protoc" + protoc_label = "//third_party/protobuf:protoc($host_toolchain)" protoc_path = get_label_info(protoc_label, "root_out_dir") + "/protoc" args += [ # Wrapper should never pick a system protoc. diff --git a/chromium/third_party/openscreen/src/tools/cddl/BUILD.gn b/chromium/third_party/openscreen/src/tools/cddl/BUILD.gn index 403cab3b553..a50dfa07851 100644 --- a/chromium/third_party/openscreen/src/tools/cddl/BUILD.gn +++ b/chromium/third_party/openscreen/src/tools/cddl/BUILD.gn @@ -4,28 +4,30 @@ import("//build_overrides/build.gni") -executable("cddl") { - sources = [ - "codegen.cc", - "codegen.h", - "logging.cc", - "logging.h", - "main.cc", - "parse.cc", - "parse.h", - "sema.cc", - "sema.h", - ] +if (current_toolchain == host_toolchain) { + executable("cddl") { + sources = [ + "codegen.cc", + "codegen.h", + "logging.cc", + "logging.h", + "main.cc", + "parse.cc", + "parse.h", + "sema.cc", + "sema.h", + ] - if (build_with_chromium) { - sources += [ "stubs.cc" ] - } + if (build_with_chromium) { + sources += [ "stubs.cc" ] + } - deps = [ - # CDDL always uses the default logger, even when embedded. - "../../platform", - "../../platform:default_logger", - "../../third_party/abseil", - ] - configs += [ "../../build:allow_build_from_embedder" ] + deps = [ + # CDDL always uses the default logger, even when embedded. + "../../platform", + "../../platform:default_logger", + "../../third_party/abseil", + ] + configs += [ "../../build:allow_build_from_embedder" ] + } } diff --git a/chromium/third_party/openscreen/src/tools/cddl/cddl.py b/chromium/third_party/openscreen/src/tools/cddl/cddl.py index 5ad9c90e578..28436f27056 100644 --- a/chromium/third_party/openscreen/src/tools/cddl/cddl.py +++ b/chromium/third_party/openscreen/src/tools/cddl/cddl.py @@ -32,7 +32,7 @@ def main(): if (args.verbose): print('Creating C++ files from provided CDDL file...') - echoAndRunCommand(['./cddl', "--header", args.header, "--cc", args.cc, + echoAndRunCommand([args.cddl, "--header", args.header, "--cc", args.cc, "--gen-dir", args.gen_dir, args.file], False, log, args.verbose) @@ -49,6 +49,7 @@ def main(): def parseInput(): parser = argparse.ArgumentParser() + parser.add_argument("--cddl", help="path to the cddl executable to use") parser.add_argument("--header", help="Specify the filename of the output \ header file. This is also the name that will be used for the include \ guard and as the include path in the source file.") diff --git a/chromium/third_party/openscreen/src/tools/cddl/logging.cc b/chromium/third_party/openscreen/src/tools/cddl/logging.cc index a984c409bb7..477eac5d520 100644 --- a/chromium/third_party/openscreen/src/tools/cddl/logging.cc +++ b/chromium/third_party/openscreen/src/tools/cddl/logging.cc @@ -15,7 +15,7 @@ #include <string> #include <utility> -const char* Logger::MakePrintable(const std::string data) { +const char* Logger::MakePrintable(const std::string& data) { return data.c_str(); } diff --git a/chromium/third_party/openscreen/src/tools/cddl/logging.h b/chromium/third_party/openscreen/src/tools/cddl/logging.h index a4fce443305..a076f097687 100644 --- a/chromium/third_party/openscreen/src/tools/cddl/logging.h +++ b/chromium/third_party/openscreen/src/tools/cddl/logging.h @@ -59,7 +59,7 @@ class Logger { return data; } - const char* MakePrintable(const std::string data); + const char* MakePrintable(const std::string& data); // Writes a log message to this instance of Logger's text file. template <typename... Args> diff --git a/chromium/third_party/openscreen/src/util/crypto/DEPS b/chromium/third_party/openscreen/src/util/crypto/DEPS index 77ab80d771b..7ad157857ef 100644 --- a/chromium/third_party/openscreen/src/util/crypto/DEPS +++ b/chromium/third_party/openscreen/src/util/crypto/DEPS @@ -4,6 +4,5 @@ include_rules = [ # BoringSSL includes - '-third_party/boringssl', '+openssl' ] diff --git a/chromium/third_party/openscreen/src/util/saturate_cast.h b/chromium/third_party/openscreen/src/util/saturate_cast.h index 904e50958ee..db40660eeda 100644 --- a/chromium/third_party/openscreen/src/util/saturate_cast.h +++ b/chromium/third_party/openscreen/src/util/saturate_cast.h @@ -10,55 +10,59 @@ namespace openscreen { -namespace { -template <typename T> -constexpr auto unsigned_cast(T from) { - return static_cast<typename std::make_unsigned<T>::type>(from); -} -} // namespace +// Because of the way C++ signed versus unsigned comparison works (i.e., the +// type promotion strategy employed), extra care must be taken to range-check +// the input value. For example, if the current architecture is 32-bits, then +// any int32_t compared with a uint32_t will NOT promote to a int64_t↔int64_t +// comparison. Instead, it will become a uint32_t↔uint32_t comparison (!), +// which will sometimes produce invalid results. -// Convert from one value type to another, clamping to the min/max of the new -// value type's range if necessary. +// Case 1: "From" and "To" are either both signed, or are both unsigned. In +// this case, the smaller of the two types will be promoted to match the +// larger's size, and a valid comparison will be made. template <typename To, typename From> -constexpr To saturate_cast(From from) { - static_assert(std::numeric_limits<From>::is_integer && - std::numeric_limits<To>::is_integer, - "Non-integral saturate_cast is not implemented."); +constexpr typename std::enable_if_t< + std::is_integral<From>::value && std::is_integral<To>::value && + (std::is_signed<From>::value == std::is_signed<To>::value), + To> +saturate_cast(From from) { + if (from <= std::numeric_limits<To>::min()) { + return std::numeric_limits<To>::min(); + } + if (from >= std::numeric_limits<To>::max()) { + return std::numeric_limits<To>::max(); + } + return static_cast<To>(from); +} - // Because of the way C++ signed versus unsigned comparison works (i.e., the - // type promotion strategy employed), extra care must be taken to range-check - // the input value. For example, if the current architecture is 32-bits, then - // any int32_t compared with a uint32_t will NOT promote to a int64_t↔int64_t - // comparison. Instead, it will become a uint32_t↔uint32_t comparison (!), - // which will sometimes produce invalid results. - if (std::numeric_limits<From>::is_signed == - std::numeric_limits<To>::is_signed) { - // Case 1: "From" and "To" are either both signed, or are both unsigned. In - // this case, the smaller of the two types will be promoted to match the - // larger's size, and a valid comparison will be made. - if (from <= std::numeric_limits<To>::min()) { - return std::numeric_limits<To>::min(); - } - if (from >= std::numeric_limits<To>::max()) { - return std::numeric_limits<To>::max(); - } - } else if (std::numeric_limits<From>::is_signed) { - // Case 2: "From" is signed, but "To" is unsigned. - if (from <= From{0}) { - return To{0}; - } - if (unsigned_cast(from) >= std::numeric_limits<To>::max()) { - return std::numeric_limits<To>::max(); - } - } else { - // Case 3: "From" is unsigned, but "To" is signed. - if (from >= unsigned_cast(std::numeric_limits<To>::max())) { - return std::numeric_limits<To>::max(); - } - // Note: "From" can never be less than "To's" minimum value. +// Case 2: "From" is signed, but "To" is unsigned. +template <typename To, typename From> +constexpr typename std::enable_if_t< + std::is_integral<From>::value && std::is_integral<To>::value && + std::is_signed<From>::value && !std::is_signed<To>::value, + To> +saturate_cast(From from) { + if (from <= From{0}) { + return To{0}; } + if (static_cast<typename std::make_unsigned_t<From>>(from) >= + std::numeric_limits<To>::max()) { + return std::numeric_limits<To>::max(); + } + return static_cast<To>(from); +} - // No clamping is needed: |from| is within the representable value range. +// Case 3: "From" is unsigned, but "To" is signed. +template <typename To, typename From> +constexpr typename std::enable_if_t< + std::is_integral<From>::value && std::is_integral<To>::value && + !std::is_signed<From>::value && std::is_signed<To>::value, + To> +saturate_cast(From from) { + if (from >= static_cast<typename std::make_unsigned_t<To>>( + std::numeric_limits<To>::max())) { + return std::numeric_limits<To>::max(); + } return static_cast<To>(from); } diff --git a/chromium/third_party/openscreen/src/util/yet_another_bit_vector.cc b/chromium/third_party/openscreen/src/util/yet_another_bit_vector.cc index 4af66118034..ff4b1316243 100644 --- a/chromium/third_party/openscreen/src/util/yet_another_bit_vector.cc +++ b/chromium/third_party/openscreen/src/util/yet_another_bit_vector.cc @@ -14,11 +14,16 @@ namespace openscreen { namespace { // Returns a bitmask where all the bits whose positions are in the range -// [begin,end) are set, and all other bits are cleared. -constexpr uint64_t MakeBitmask(int begin, int end) { - const int num_consecutive_bits_to_set = end - begin; - const uint64_t some_power_of_two = uint64_t{1} << num_consecutive_bits_to_set; - const uint64_t bits_in_wrong_position = some_power_of_two - 1; +// [begin,begin+count) are set, and all other bits are cleared. +constexpr uint64_t MakeBitmask(int begin, int count) { + // Form a contiguous sequence of bits by subtracting one from the appropriate + // power of 2. Set all the bits if count >= 64. + const uint64_t bits_in_wrong_position = + (count >= std::numeric_limits<uint64_t>::digits) + ? std::numeric_limits<uint64_t>::max() + : ((uint64_t{1} << count) - 1); + + // Now shift the contiguous sequence of bits into the correct position. return bits_in_wrong_position << begin; } @@ -90,14 +95,10 @@ void YetAnotherBitVector::SetAll() { // valid range are not set. if (using_array_storage()) { - if (int end_bit_offset = (size_ % kBitsPerInteger)) { - uint64_t* last = &bits_.as_array[array_size() - 1]; - std::fill(&bits_.as_array[0], last, kAllBitsSet); - *last = MakeBitmask(0, end_bit_offset); - } else { - // The size is an exact multiple of 64. So, just set all the bits. - std::fill(&bits_.as_array[0], &bits_.as_array[array_size()], kAllBitsSet); - } + const int last_index = array_size() - 1; + uint64_t* const last = &bits_.as_array[last_index]; + std::fill(&bits_.as_array[0], last, kAllBitsSet); + *last = MakeBitmask(0, size_ - (last_index * kBitsPerInteger)); } else { bits_.as_integer = MakeBitmask(0, size_); } @@ -144,7 +145,11 @@ void YetAnotherBitVector::ShiftRight(int steps) { incoming_carry_bits = outgoing_carry_bits; } } else { - bits_.as_integer >>= steps; + if (steps < kBitsPerInteger) { + bits_.as_integer >>= steps; + } else { + bits_.as_integer = 0; + } } } @@ -217,10 +222,9 @@ int YetAnotherBitVector::CountBitsSet(int begin, int end) const { const int first = begin / kBitsPerInteger; const int last = (end - 1) / kBitsPerInteger; if (first == last) { - count = - PopCount(bits_.as_array[first] & - MakeBitmask(begin % kBitsPerInteger, end % kBitsPerInteger)); - } else { + count = PopCount(bits_.as_array[first] & + MakeBitmask(begin % kBitsPerInteger, end - begin)); + } else if (first < last) { // Count a subset of the bits in the first and last integers (according to // |begin| and |end|), and all of the bits in the integers in-between. const uint64_t* p = &bits_.as_array[first]; @@ -229,10 +233,12 @@ int YetAnotherBitVector::CountBitsSet(int begin, int end) const { for (++p; p != &bits_.as_array[last]; ++p) { count += PopCount(*p); } - count += PopCount((*p) & MakeBitmask(0, end % kBitsPerInteger)); + count += PopCount((*p) & MakeBitmask(0, end - (last * kBitsPerInteger))); + } else { + count = 0; } } else { - count = PopCount(bits_.as_integer & MakeBitmask(begin, end)); + count = PopCount(bits_.as_integer & MakeBitmask(begin, end - begin)); } return count; } diff --git a/chromium/third_party/openscreen/src/util/yet_another_bit_vector_unittest.cc b/chromium/third_party/openscreen/src/util/yet_another_bit_vector_unittest.cc index c1a20be5237..98951fe0d38 100644 --- a/chromium/third_party/openscreen/src/util/yet_another_bit_vector_unittest.cc +++ b/chromium/third_party/openscreen/src/util/yet_another_bit_vector_unittest.cc @@ -19,18 +19,19 @@ constexpr uint8_t kBitPatterns[] = {0b00000000, 0b11111111, 0b01010101, // These are used for testing various vector sizes, begins/ends of ranges, etc. // They will exercise both the "inlined storage" (size <= 64 case) and -// "heap-allocated storage" cases. -const int kPrimeNumbers[] = {1, 2, 3, 5, 7, 11, 13, 17, 19, 23, - 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, - 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, - 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, - 173, 179, 181, 191, 193, 197, 199}; - -// Returns a subspan of |kPrimeNumbers| that contains all values in the range +// "heap-allocated storage" cases. These are all of the prime numbers less than +// 200, and also any non-negative multiples of 64 less than 200. +const int kTestSizes[] = {0, 1, 2, 3, 5, 7, 11, 13, 17, 19, 23, + 29, 31, 37, 41, 43, 47, 53, 59, 61, 64, 67, + 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, + 127, 128, 131, 137, 139, 149, 151, 157, 163, 167, 173, + 179, 181, 191, 192, 193, 197, 199}; + +// Returns a subspan of |kTestSizes| that contains all values in the range // [first,last]. -absl::Span<const int> GetPrimeNumbersInRange(int first, int last) { - const auto begin = absl::c_lower_bound(kPrimeNumbers, first); - const auto end = absl::c_upper_bound(kPrimeNumbers, last); +absl::Span<const int> GetTestSizesInRange(int first, int last) { + const auto begin = absl::c_lower_bound(kTestSizes, first); + const auto end = absl::c_upper_bound(kTestSizes, last); return absl::Span<const int>(&*begin, std::distance(begin, end)); } @@ -60,7 +61,7 @@ TEST(YetAnotherBitVectorTest, ConstructsAndResizes) { ASSERT_EQ(v.size(), 0); for (int fill_set = 0; fill_set <= 1; ++fill_set) { - for (int size : kPrimeNumbers) { + for (int size : kTestSizes) { const bool all_bits_should_be_set = !!fill_set; v.Resize(size, all_bits_should_be_set ? YetAnotherBitVector::SET : YetAnotherBitVector::CLEARED); @@ -77,7 +78,7 @@ TEST(YetAnotherBitVectorTest, ConstructsAndResizes) { TEST(YetAnotherBitVectorTest, SetsAndClearsIndividualBits) { YetAnotherBitVector v; for (int fill_set = 0; fill_set <= 1; ++fill_set) { - for (int size : kPrimeNumbers) { + for (int size : kTestSizes) { v.Resize(size, fill_set ? YetAnotherBitVector::SET : YetAnotherBitVector::CLEARED); @@ -95,23 +96,31 @@ TEST(YetAnotherBitVectorTest, SetsAndClearsIndividualBits) { // vector sizes and bit patterns. TEST(YetAnotherBitVectorTest, ShiftsRight) { YetAnotherBitVector v; - for (int size : kPrimeNumbers) { + for (int size : kTestSizes) { v.Resize(size, YetAnotherBitVector::CLEARED); - for (int steps_per_shift : GetPrimeNumbersInRange(0, size)) { + for (int steps_per_shift : GetTestSizesInRange(0, size)) { for (uint8_t pattern : kBitPatterns) { FillWithPattern(pattern, 0, &v); - const int num_shifts = 2 * size / steps_per_shift; - for (int iteration = 1; iteration <= num_shifts; ++iteration) { - v.ShiftRight(steps_per_shift); - const int total_shift_amount = iteration * steps_per_shift; + if (size == 0 || steps_per_shift == 0) { + v.ShiftRight(0); for (int i = 0; i < size; ++i) { - const int original_position = i + total_shift_amount; - if (original_position >= size) { - ASSERT_FALSE(v.IsSet(i)); - } else { - ASSERT_EQ(IsSetInPattern(pattern, original_position), v.IsSet(i)); + ASSERT_EQ(IsSetInPattern(pattern, i), v.IsSet(i)); + } + } else { + const int num_shifts = 2 * size / steps_per_shift; + for (int iteration = 1; iteration <= num_shifts; ++iteration) { + v.ShiftRight(steps_per_shift); + const int total_shift_amount = iteration * steps_per_shift; + for (int i = 0; i < size; ++i) { + const int original_position = i + total_shift_amount; + if (original_position >= size) { + ASSERT_FALSE(v.IsSet(i)); + } else { + ASSERT_EQ(IsSetInPattern(pattern, original_position), + v.IsSet(i)); + } } } } @@ -125,22 +134,19 @@ TEST(YetAnotherBitVectorTest, ShiftsRight) { TEST(YetAnotherBitVectorTest, FindsTheFirstBitSet) { YetAnotherBitVector v; - // In an empty vector, the first bit set is the size(), which is zero. - ASSERT_EQ(0, v.FindFirstSet()); - // For various sizes of vector where no bits are set, the FFS operation should // always return size(). - for (int size : kPrimeNumbers) { + for (int size : kTestSizes) { v.Resize(size, YetAnotherBitVector::CLEARED); ASSERT_EQ(size, v.FindFirstSet()); } // For various sizes of vector where only one bit is set, the FFS operation // should always return the position of that bit. - for (int size : kPrimeNumbers) { + for (int size : kTestSizes) { v.Resize(size, YetAnotherBitVector::CLEARED); - for (int position_plus_one : GetPrimeNumbersInRange(0, size)) { + for (int position_plus_one : GetTestSizesInRange(1, size)) { const int position = position_plus_one - 1; v.Set(position); ASSERT_EQ(position, v.FindFirstSet()); @@ -150,10 +156,10 @@ TEST(YetAnotherBitVectorTest, FindsTheFirstBitSet) { // For various sizes of vector where a pattern of bits are set, the FFS // operation should always return the first one set. - for (int size : kPrimeNumbers) { + for (int size : kTestSizes) { v.Resize(size, YetAnotherBitVector::CLEARED); - for (int position_plus_one : GetPrimeNumbersInRange(0, size)) { + for (int position_plus_one : GetTestSizesInRange(1, size)) { const int position = position_plus_one - 1; v.ClearAll(); v.Set(position); @@ -171,15 +177,12 @@ TEST(YetAnotherBitVectorTest, FindsTheFirstBitSet) { TEST(YetAnotherBitVector, CountsTheNumberOfBitsSet) { YetAnotherBitVector v; - // There are zero bits set in an empty vector. - ASSERT_EQ(0, v.CountBitsSet(0, v.size())); - // For various sizes of vector where no bits are set, the operation should // always return zero for any range. - for (int size : kPrimeNumbers) { + for (int size : kTestSizes) { v.Resize(size, YetAnotherBitVector::CLEARED); - for (int begin : GetPrimeNumbersInRange(0, size)) { - for (int end : GetPrimeNumbersInRange(begin, size)) { + for (int begin : GetTestSizesInRange(0, size)) { + for (int end : GetTestSizesInRange(begin, size)) { ASSERT_EQ(0, v.CountBitsSet(begin, end)); } } @@ -187,23 +190,23 @@ TEST(YetAnotherBitVector, CountsTheNumberOfBitsSet) { // For various sizes of vector where all bits are set, the operation should // always return the length of the range (or zero for invalid ranges). - for (int size : kPrimeNumbers) { + for (int size : kTestSizes) { v.Resize(size, YetAnotherBitVector::SET); - for (int begin : GetPrimeNumbersInRange(0, size)) { - for (int end : GetPrimeNumbersInRange(begin, size)) { + for (int begin : GetTestSizesInRange(0, size)) { + for (int end : GetTestSizesInRange(begin, size)) { ASSERT_EQ(end - begin, v.CountBitsSet(begin, end)); } } } // Test various sizes of vector where various patterns of bits are set. - for (int size : kPrimeNumbers) { + for (int size : kTestSizes) { v.Resize(size, YetAnotherBitVector::CLEARED); for (uint8_t pattern : kBitPatterns) { FillWithPattern(pattern, 0, &v); - for (int begin : GetPrimeNumbersInRange(0, size)) { - for (int end : GetPrimeNumbersInRange(begin, size)) { + for (int begin : GetTestSizesInRange(0, size)) { + for (int end : GetTestSizesInRange(begin, size)) { // Note: The expected value being manually computed by examining each // bit individually by calling IsSet(). Thus, this value is only good // if IsSet() is working (which is tested by a different unit test). |